Unverified Commit d68028c8 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

[common][pyTorch]Add zero_centered_gamma option to RMSNorm (#631)



* Add zero_centered_gamma option to RMSNorm
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Improving tests
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* More improvements to tests
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Tweaking the tolerances
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix LayerNormMLP test
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Update transformer_engine/common/rmsnorm/rmsnorm_api.cpp
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/common/rmsnorm/rmsnorm_api.cpp
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* docs suggestions
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Tweak tolerances with bfloat16
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 5b155fb3
...@@ -43,13 +43,17 @@ void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, con ...@@ -43,13 +43,17 @@ void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, con
template <typename InputType, typename OutputType> template <typename InputType, typename OutputType>
void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output, void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output,
const float *rsigma, const size_t N, const size_t H, float *amax, const float *rsigma, const size_t N, const size_t H, float *amax,
float scale) { float scale, const bool zero_centered_gamma) {
using compute_t = float; using compute_t = float;
compute_t current_max = -1e100; compute_t current_max = -1e100;
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < H; ++j) { for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]); compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t tmp = current * rsigma[i] * static_cast<compute_t>(gamma[j]); compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
compute_t tmp = current * rsigma[i] * g;
output[i * H + j] = static_cast<OutputType>(tmp * scale); output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp)); current_max = fmaxf(current_max, fabsf(tmp));
} }
...@@ -60,7 +64,7 @@ void compute_ref_output(const InputType *data, const InputType *gamma, OutputTyp ...@@ -60,7 +64,7 @@ void compute_ref_output(const InputType *data, const InputType *gamma, OutputTyp
template <typename InputType, typename OutputType> template <typename InputType, typename OutputType>
void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma, void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma,
const InputType *gamma, InputType *data_grad, InputType *gamma_grad, const InputType *gamma, InputType *data_grad, InputType *gamma_grad,
const size_t N, const size_t H) { const size_t N, const size_t H, const bool zero_centered_gamma) {
using compute_t = float; using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f); std::vector<compute_t> dgamma(H, 0.f);
...@@ -70,7 +74,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data, ...@@ -70,7 +74,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
for (size_t j = 0; j < H; ++j) { for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]); const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i]; const compute_t y = x * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]); compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]); const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz; const compute_t dy = g * dz;
dgamma[j] += y * dz; dgamma[j] += y * dz;
...@@ -82,7 +89,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data, ...@@ -82,7 +89,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
for (size_t j = 0; j < H; ++j) { for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]); const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i]; const compute_t y = x * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]); compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]); const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz; const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y); const compute_t dx = rsigma[i] * (dy - mdyy * y);
...@@ -97,7 +107,7 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data, ...@@ -97,7 +107,7 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
} }
template <typename InputType, typename OutputType> template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) { void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) {
if (sizeof(InputType) < sizeof(OutputType)) { if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType"; GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType";
return; return;
...@@ -137,21 +147,23 @@ void performTest(const size_t N, const size_t H) { ...@@ -137,21 +147,23 @@ void performTest(const size_t N, const size_t H) {
// Forward kernel // Forward kernel
float epsilon = 1e-5; float epsilon = 1e-5;
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0, auto fwd_function = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data()); prop.multiProcessorCount, workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype()); workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype()); barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0, fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data()); prop.multiProcessorCount, workspace.data(), barrier.data());
// Backward kernel // Backward kernel
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), auto bwd_function = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd;
bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(), dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data()); barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype()); workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype()); barrier = Tensor(barrier.shape(), barrier.dtype());
dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype()); dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(), dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data()); barrier.data());
...@@ -162,10 +174,11 @@ void performTest(const size_t N, const size_t H) { ...@@ -162,10 +174,11 @@ void performTest(const size_t N, const size_t H) {
compute_ref_stats(input.cpu_dptr<InputType>(), ref_rsigma.get(), N, H, epsilon); compute_ref_stats(input.cpu_dptr<InputType>(), ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? z.scale() : 1.f; float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(), gamma.cpu_dptr<WeightType>(), ref_output.get(), compute_ref_output(input.cpu_dptr<InputType>(), gamma.cpu_dptr<WeightType>(), ref_output.get(),
rsigma.cpu_dptr<float>(), N, H, &ref_amax, ref_scale); rsigma.cpu_dptr<float>(), N, H, &ref_amax, ref_scale,
zero_centered_gamma);
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(), compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
rsigma.cpu_dptr<float>(), gamma.cpu_dptr<WeightType>(), ref_dx.get(), rsigma.cpu_dptr<float>(), gamma.cpu_dptr<WeightType>(), ref_dx.get(),
ref_dgamma.get(), N, H); ref_dgamma.get(), N, H, zero_centered_gamma);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto err = cudaGetLastError(); auto err = cudaGetLastError();
...@@ -197,9 +210,10 @@ std::vector<std::pair<size_t, size_t>> test_cases = { ...@@ -197,9 +210,10 @@ std::vector<std::pair<size_t, size_t>> test_cases = {
} // namespace } // namespace
class RMSNormTestSuite class RMSNormTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
: public ::testing::TestWithParam<std::tuple< transformer_engine::DType,
transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {}; std::pair<size_t, size_t>,
bool>> {};
TEST_P(RMSNormTestSuite, TestRMSNorm) { TEST_P(RMSNormTestSuite, TestRMSNorm) {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -208,11 +222,11 @@ TEST_P(RMSNormTestSuite, TestRMSNorm) { ...@@ -208,11 +222,11 @@ TEST_P(RMSNormTestSuite, TestRMSNorm) {
const DType input_type = std::get<0>(GetParam()); const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam()); const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam()); const auto size = std::get<2>(GetParam());
const bool zero_centered_gamma = std::get<3>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma);););
output_type, OutputType, performTest<InputType, OutputType>(size.first, size.second);););
} }
INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite, INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite,
...@@ -220,11 +234,14 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite, ...@@ -220,11 +234,14 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite,
DType::kFloat16), DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, ::testing::Values(DType::kFloat32, DType::kBFloat16,
DType::kFloat16, DType::kFloat8E4M3), DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases)), ::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
[](const testing::TestParamInfo<RMSNormTestSuite::ParamType> &info) { [](const testing::TestParamInfo<RMSNormTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" + std::string name =
test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" + test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" + std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second); std::to_string(std::get<2>(info.param).second) + "X" +
std::to_string(std::get<3>(info.param));
return name; return name;
}); });
...@@ -21,7 +21,7 @@ from transformer_engine.pytorch.utils import ( ...@@ -21,7 +21,7 @@ from transformer_engine.pytorch.utils import (
) )
from transformer_engine.pytorch import ( from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
MultiheadAttention, RMSNorm, TransformerLayer MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm
) )
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
...@@ -215,15 +215,41 @@ class TorchDotProductAttention(torch.nn.Module): ...@@ -215,15 +215,41 @@ class TorchDotProductAttention(torch.nn.Module):
return context_layer return context_layer
class TorchLayerNorm(nn.Module):
def __init__(self, in_features: int,
eps: float,
zero_centered_gamma: bool):
super().__init__()
self.eps = eps
self.in_features = in_features
self.zero_centered_gamma = zero_centered_gamma
initial_value = torch.ones(in_features) if zero_centered_gamma else torch.zeros(in_features)
self.weight = nn.Parameter(initial_value)
self.bias = nn.Parameter(torch.zeros(in_features))
self.register_parameter("weight", self.weight)
self.register_parameter("bias", self.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight if not self.zero_centered_gamma else 1 + self.weight
w = w.to(torch.float32)
b = self.bias.to(torch.float32)
inp = x.to(torch.float32)
out = torch.nn.functional.layer_norm(inp, (self.in_features,), weight=w,
bias=b, eps=self.eps)
return out.to(x.dtype)
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py # Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module): class TorchRMSNorm(nn.Module):
def __init__(self, in_features, eps=1e-5): def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.in_features = in_features self.in_features = in_features
self.zero_centered_gamma = zero_centered_gamma
self.weight = nn.Parameter(torch.ones(in_features)) initial_value = torch.ones(in_features) if zero_centered_gamma else torch.zeros(in_features)
self.weight = nn.Parameter(initial_value)
self.register_parameter("weight", self.weight) self.register_parameter("weight", self.weight)
def forward(self, x): def forward(self, x):
...@@ -234,18 +260,24 @@ class TorchRMSNorm(nn.Module): ...@@ -234,18 +260,24 @@ class TorchRMSNorm(nn.Module):
r_rms_x = rms_x2 ** (-1. / 2) r_rms_x = rms_x2 ** (-1. / 2)
x_normed = x * r_rms_x x_normed = x * r_rms_x
return (self.weight.float() * x_normed).to(x.dtype) w = self.weight.float()
if self.zero_centered_gamma:
w = 1 + w
return (w * x_normed).to(x.dtype)
class TorchLayerNormLinear(nn.Module): class TorchLayerNormLinear(nn.Module):
def __init__(self, in_features: int, out_features: int, def __init__(self, in_features: int, out_features: int,
eps: float, bias: bool = True, eps: float, bias: bool = True,
normalization: str = "LayerNorm"): normalization: str = "LayerNorm",
zero_centered_gamma: bool = False):
super().__init__() super().__init__()
if normalization == "LayerNorm": if normalization == "LayerNorm":
self.layernorm = nn.LayerNorm(in_features, eps=eps) self.layernorm = TorchLayerNorm(in_features, eps=eps,
zero_centered_gamma=zero_centered_gamma)
elif normalization == "RMSNorm": elif normalization == "RMSNorm":
self.layernorm = TorchRMSNorm(in_features, eps=eps) self.layernorm = TorchRMSNorm(in_features, eps=eps,
zero_centered_gamma=zero_centered_gamma)
else: else:
raise RuntimeError("Unsupported normalization") raise RuntimeError("Unsupported normalization")
...@@ -299,9 +331,11 @@ class TorchLayerNormMLP(nn.Module): ...@@ -299,9 +331,11 @@ class TorchLayerNormMLP(nn.Module):
normalization: str = "LayerNorm"): normalization: str = "LayerNorm"):
super().__init__() super().__init__()
if normalization == "LayerNorm": if normalization == "LayerNorm":
self.ln = nn.LayerNorm(hidden_size, eps=eps) self.ln = TorchLayerNorm(hidden_size, eps=eps,
zero_centered_gamma=False)
elif normalization == "RMSNorm": elif normalization == "RMSNorm":
self.ln = TorchRMSNorm(hidden_size, eps=eps) self.ln = TorchRMSNorm(hidden_size, eps=eps,
zero_centered_gamma=False)
else: else:
raise RuntimeError("Unsupported normalization") raise RuntimeError("Unsupported normalization")
if 'glu' in activation: if 'glu' in activation:
...@@ -893,13 +927,15 @@ def test_linear_accuracy(dtype, bs, model): ...@@ -893,13 +927,15 @@ def test_linear_accuracy(dtype, bs, model):
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7]) @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
def test_rmsnorm_accuracy(dtype, bs, model, eps): @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model] config = model_configs[model]
te_rmsnorm = ( te_rmsnorm = (
RMSNorm( RMSNorm(
config.hidden_size, config.hidden_size,
eps=eps, eps=eps,
zero_centered_gamma=zero_centered_gamma
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -910,6 +946,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps): ...@@ -910,6 +946,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps):
TorchRMSNorm( TorchRMSNorm(
config.hidden_size, config.hidden_size,
eps=eps, eps=eps,
zero_centered_gamma=zero_centered_gamma
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -924,17 +961,64 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps): ...@@ -924,17 +961,64 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps):
torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)
# Check output. # Check output.
if dtype == torch.float32: atol = {torch.float32 : 1e-7,
assert_allclose(te_outputs[0], torch_outputs[0], 1e-7) torch.half : 2e-3,
else: torch.bfloat16: 2e-2,
assert_allclose(te_outputs[0], torch_outputs[0], 2e-2) }
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]
te_layernorm = (
LayerNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
.eval()
)
torch_layernorm = (
TorchLayerNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
torch_layernorm.bias = Parameter(te_layernorm.bias.clone())
te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)
# Check output.
atol = {torch.float32 : 1e-7,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization): @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
config = model_configs[model] config = model_configs[model]
te_ln_linear = ( te_ln_linear = (
...@@ -944,6 +1028,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization): ...@@ -944,6 +1028,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
config.eps, config.eps,
bias=True, bias=True,
normalization=normalization, normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -957,6 +1042,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization): ...@@ -957,6 +1042,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
config.eps, config.eps,
bias=True, bias=True,
normalization=normalization, normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -975,10 +1061,11 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization): ...@@ -975,10 +1061,11 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
# Check output. # Check output.
if dtype == torch.float32: atol = {torch.float32 : 2e-4,
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3) torch.half : 2e-3,
else: torch.bfloat16: 2e-2,
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) }
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
......
...@@ -625,7 +625,7 @@ def test_export_layernorm( ...@@ -625,7 +625,7 @@ def test_export_layernorm(
eps = 1e-6 # An arbitrary small value eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision dtype = torch.float if fake_bf16_io else precision
self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype, self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype,
zero_centered_gamma=False).eval().cuda() zero_centered_gamma=zero_centered_gamma).eval().cuda()
def forward(self, inp): def forward(self, inp):
ret = self.ln(inp) ret = self.ln(inp)
...@@ -679,6 +679,7 @@ def test_export_layernorm( ...@@ -679,6 +679,7 @@ def test_export_layernorm(
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs) fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
@pytest.mark.parametrize("scale_factor", [448, 112]) @pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_fp8, precision, atol", [ "use_fp8, precision, atol", [
[False, torch.float32, 1e-7], [False, torch.float32, 1e-7],
...@@ -695,6 +696,7 @@ def test_export_rmsnorm( ...@@ -695,6 +696,7 @@ def test_export_rmsnorm(
use_fp8: bool, use_fp8: bool,
scale_factor: float, scale_factor: float,
precision: torch.dtype, precision: torch.dtype,
zero_centered_gamma: bool,
atol: float atol: float
): ):
fake_bf16_io = precision == "fake-torch.bfloat16" fake_bf16_io = precision == "fake-torch.bfloat16"
...@@ -713,7 +715,8 @@ def test_export_rmsnorm( ...@@ -713,7 +715,8 @@ def test_export_rmsnorm(
super().__init__() super().__init__()
eps = 1e-6 # An arbitrary small value eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision dtype = torch.float if fake_bf16_io else precision
self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype).eval().cuda() self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma).eval().cuda()
def forward(self, inp): def forward(self, inp):
ret = self.ln(inp) ret = self.ln(inp)
...@@ -739,7 +742,7 @@ def test_export_rmsnorm( ...@@ -739,7 +742,7 @@ def test_export_rmsnorm(
self.meta, self.meta,
self.fp8_tensor, self.fp8_tensor,
self.fp8_type, self.fp8_type,
False) zero_centered_gamma)
ret = cast_from_fp8( ret = cast_from_fp8(
ret, ret,
...@@ -875,9 +878,6 @@ def test_export_layernorm_linear( ...@@ -875,9 +878,6 @@ def test_export_layernorm_linear(
if use_fp8 and not fp8_available: if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
# Set dimensions (these are arbitrary). # Set dimensions (these are arbitrary).
in_features = 64 in_features = 64
out_features = 256 out_features = 256
...@@ -946,8 +946,6 @@ def test_export_layernorm_mlp( ...@@ -946,8 +946,6 @@ def test_export_layernorm_mlp(
if use_fp8 and not fp8_available: if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
# Set dimensions (these are arbitrary). # Set dimensions (these are arbitrary).
in_features = 64 in_features = 64
......
...@@ -361,9 +361,6 @@ def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad, ...@@ -361,9 +361,6 @@ def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
if not config.is_fp8_supported(): if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
...@@ -427,9 +424,6 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, ...@@ -427,9 +424,6 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
if not config.is_fp8_supported(): if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
...@@ -472,9 +466,6 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, ...@@ -472,9 +466,6 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
if not config.is_fp8_supported(): if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
...@@ -544,9 +535,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -544,9 +535,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
if not config.is_fp8_supported(): if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
...@@ -608,9 +596,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -608,9 +596,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
if not config.is_fp8_supported(): if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
...@@ -824,9 +809,6 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm ...@@ -824,9 +809,6 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
if not config.is_fp8_supported(): if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
......
...@@ -18,6 +18,15 @@ extern "C" { ...@@ -18,6 +18,15 @@ extern "C" {
#endif #endif
/*! \brief Compute RMSNorm on the input. /*! \brief Compute RMSNorm on the input.
*
* The formula used:
* @f[
* y = \frac{x}{RMS_\varepsilon(x)}\gamma
* @f]
* where
* @f[
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* @f]
* *
* Calling this function with workspace and barrier set to empty tensor will not * Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace * perform the operation, but instead set the shape and type of the workspace
...@@ -44,7 +53,53 @@ void nvte_rmsnorm_fwd(const NVTETensor x, ...@@ -44,7 +53,53 @@ void nvte_rmsnorm_fwd(const NVTETensor x,
NVTETensor workspace, NVTETensor workspace,
NVTETensor barrier); NVTETensor barrier);
/*! \brief Compute RMSNorm with zero-centered gamma on the input.
 *
 * The formula used:
 * @f[
 * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma)
 * @f]
 * where
 * @f[
 * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
 * @f]
 *
 * This variant differs from nvte_rmsnorm_fwd only in the gamma shift: the
 * stored gamma is offset by 1, so a gamma tensor initialized to zero yields
 * an identity scaling of the normalized input.
 *
 * Calling this function with workspace and barrier set to empty tensor will not
 * perform the operation, but instead set the shape and type of the workspace
 * and barrier tensors to the required values.
 *
 * \param[in]     x                   Input tensor of shape [N, H].
 * \param[in]     gamma               Zero-centered gamma tensor of shape [H];
 *                                    the effective scale applied is (1 + gamma).
 * \param[in]     epsilon             Value added to denominator for numerical stability.
 * \param[in,out] z                   Output tensor of shape [N, H]. Marked in,out
 *                                    because its FP8 scale/amax metadata may be
 *                                    read and updated when producing FP8 output.
 * \param[out]    rsigma              Reciprocal of the root mean square of the input
 *                                    calculated over the last dimension. Shape: [N].
 * \param[in]     stream              CUDA stream used for the operation.
 * \param[in]     multiprocessorCount Number of SMs in the device.
 * \param[out]    workspace           Workspace tensor.
 * \param[out]    barrier             Barrier tensor.
 */
void nvte_rmsnorm1p_fwd(const NVTETensor x,
                        const NVTETensor gamma,
                        const float epsilon,
                        NVTETensor z,
                        NVTETensor rsigma,
                        cudaStream_t stream,
                        const int multiprocessorCount,
                        NVTETensor workspace,
                        NVTETensor barrier);
/*! \brief Compute backward of RMSNorm. /*! \brief Compute backward of RMSNorm.
*
* This function computes the gradient of function:
* @f[
* y = \frac{x}{RMS_\varepsilon(x)}\gamma
* @f]
* where
* @f[
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* @f]
* with respect to \f$x\f$ and \f$gamma\f$.
* *
* Calling this function with workspace, barrier, dgamma_part set * Calling this function with workspace, barrier, dgamma_part set
* to empty tensor will not perform the operation, but instead set the shape and type * to empty tensor will not perform the operation, but instead set the shape and type
...@@ -76,6 +131,48 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, ...@@ -76,6 +131,48 @@ void nvte_rmsnorm_bwd(const NVTETensor dz,
NVTETensor barrier NVTETensor barrier
); );
/*! \brief Compute backward of RMSNorm with zero-centered gamma.
 *
 * This function computes the gradient of function:
 * @f[
 * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma)
 * @f]
 * where
 * @f[
 * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
 * @f]
 * with respect to \f$x\f$ and \f$gamma\f$.
 *
 * This variant differs from nvte_rmsnorm_bwd only in the gamma shift: the
 * gamma passed in is the zero-centered parameter, and the gradient dgamma
 * is computed with respect to that zero-centered parameter.
 *
 * Calling this function with workspace, barrier, dgamma_part set
 * to empty tensor will not perform the operation, but instead set the shape and type
 * of these tensors to the required values.
 *
 * \param[in]     dz                  Incoming gradient tensor of shape [N, H].
 * \param[in]     x                   Forward input tensor of shape [N, H].
 * \param[in]     rsigma              Reciprocal of the root mean square of the input
 *                                    calculated over the last dimension. Shape: [N].
 * \param[in]     gamma               Zero-centered gamma tensor of shape [H];
 *                                    the effective scale in the forward pass
 *                                    was (1 + gamma).
 * \param[out]    dx                  Output gradient of shape [N, H].
 * \param[out]    dgamma              Gradient for gamma tensor of shape [H].
 * \param[out]    dgamma_part         Storage for partial gamma gradient.
 * \param[in]     stream              CUDA stream used for the operation.
 * \param[in]     multiprocessorCount Number of SMs in the device.
 * \param[out]    workspace           Workspace tensor.
 * \param[out]    barrier             Barrier tensor.
 */
void nvte_rmsnorm1p_bwd(const NVTETensor dz,
                        const NVTETensor x,
                        const NVTETensor rsigma,
                        const NVTETensor gamma,
                        NVTETensor dx,
                        NVTETensor dgamma,
                        NVTETensor dgamma_part,
                        cudaStream_t stream,
                        const int multiprocessorCount,
                        NVTETensor workspace,
                        NVTETensor barrier
);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -106,7 +106,7 @@ inline size_t product(const std::vector<size_t> &shape) { ...@@ -106,7 +106,7 @@ inline size_t product(const std::vector<size_t> &shape) {
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount, Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount,
Tensor *workspace, Tensor *barrier) { Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) {
auto itype = x.data.dtype; auto itype = x.data.dtype;
auto wtype = gamma.data.dtype; auto wtype = gamma.data.dtype;
auto otype = z->data.dtype; auto otype = z->data.dtype;
...@@ -149,6 +149,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -149,6 +149,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.amax = z->amax.dptr; params.amax = z->amax.dptr;
params.scale = z->scale.dptr; params.scale = z->scale.dptr;
params.fp8_out = fp8_out; params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
// Query the kernel-specific launch parameters. // Query the kernel-specific launch parameters.
launcher(launch_params, true); launcher(launch_params, true);
...@@ -199,7 +200,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -199,7 +200,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma, void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma,
Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream, Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream,
const int multiprocessorCount, Tensor *workspace, Tensor *barrier) { const int multiprocessorCount, Tensor *workspace, Tensor *barrier,
const bool zero_centered_gamma) {
using namespace transformer_engine; using namespace transformer_engine;
auto itype = x.data.dtype; auto itype = x.data.dtype;
...@@ -245,6 +247,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -245,6 +247,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
params.dgamma = dgamma->data.dptr; params.dgamma = dgamma->data.dptr;
params.dbeta_part = nullptr; params.dbeta_part = nullptr;
params.dgamma_part = dgamma_part->data.dptr; params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma;
// Query the kernel-specific launch parameters. // Query the kernel-specific launch parameters.
launcher(launch_params, true); launcher(launch_params, true);
...@@ -295,20 +298,53 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size ...@@ -295,20 +298,53 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma), rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream, epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor *>(workspace), multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
reinterpret_cast<Tensor *>(barrier)); reinterpret_cast<Tensor *>(barrier), false);
} }
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32! const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream, NVTETensor dx, NVTETensor dgamma,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm_bwd); NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine; using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x), rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma), *reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma), reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount, reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier)); reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier),
false);
}
void nvte_rmsnorm1p_fwd(const NVTETensor x, // Nxhidden_size
const NVTETensor gamma, // hidden_size
const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm1p_fwd);
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
reinterpret_cast<Tensor *>(barrier), true);
}
void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma,
NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm1p_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier),
true);
} }
...@@ -97,7 +97,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke ...@@ -97,7 +97,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_tmp = x[it].data.elt[jt]; compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp); compute_t y_tmp = rs_r * (x_tmp);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]); const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
dy_tmp *= compute_t(dz[it].data.elt[jt]); dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt]; compute_t dz_tmp = dz[it].data.elt[jt];
...@@ -356,7 +357,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_ ...@@ -356,7 +357,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_ij = x.data.elt[jt]; compute_t x_ij = x.data.elt[jt];
compute_t y_ij = rs * (x_ij); compute_t y_ij = rs * (x_ij);
compute_t g_ij = gamma[it].data.elt[jt]; const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
compute_t dz_ij = dz.data.elt[jt]; compute_t dz_ij = dz.data.elt[jt];
compute_t dy_ij = g_ij * dz_ij; compute_t dy_ij = g_ij * dz_ij;
......
...@@ -106,6 +106,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke ...@@ -106,6 +106,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (xf[it * NUM_ELTS + jt]); compute_t y_ij = rs * (xf[it * NUM_ELTS + jt]);
compute_t g_ij = gamma[it].data.elt[jt]; compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
compute_t temp_output = g_ij * y_ij; compute_t temp_output = g_ij * y_ij;
if (params.fp8_out) { if (params.fp8_out) {
...@@ -236,6 +239,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ ...@@ -236,6 +239,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (x[it].data.elt[jt]); compute_t y_ij = rs * (x[it].data.elt[jt]);
compute_t g_ij = gamma[it].data.elt[jt]; compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
z.data.elt[jt] = g_ij * y_ij; z.data.elt[jt] = g_ij * y_ij;
} }
......
...@@ -218,8 +218,6 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, ...@@ -218,8 +218,6 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma
) { ) {
NVTE_CHECK(zero_centered_gamma == false,
"Zero-centered gamma is not supported yet for RMSNorm.");
auto dx = at::empty_like(x); auto dx = at::empty_like(x);
auto dgamma = at::empty_like(gamma); auto dgamma = at::empty_like(gamma);
transformer_engine::TensorWrapper workspace, barrier, dgamma_part; transformer_engine::TensorWrapper workspace, barrier, dgamma_part;
...@@ -232,7 +230,7 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, ...@@ -232,7 +230,7 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz,
auto dgamma_cu = makeTransformerEngineTensor(dgamma); auto dgamma_cu = makeTransformerEngineTensor(dgamma);
// This call populates tensors with the required config. // This call populates tensors with the required config.
const auto bwd_fun = nvte_rmsnorm_bwd; const auto bwd_fun = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dgamma_part.data(), dx_cu.data(), dgamma_cu.data(), dgamma_part.data(),
at::cuda::getCurrentCUDAStream(), at::cuda::getCurrentCUDAStream(),
...@@ -295,8 +293,6 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, ...@@ -295,8 +293,6 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
const bool zero_centered_gamma const bool zero_centered_gamma
) { ) {
using namespace transformer_engine; using namespace transformer_engine;
NVTE_CHECK(zero_centered_gamma == false,
"Zero-centered gamma is not supported yet for RMSNorm.");
size_t N = static_cast<size_t>(input.size(0)); size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1)); size_t H = static_cast<size_t>(input.size(1));
...@@ -313,7 +309,7 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, ...@@ -313,7 +309,7 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
transformer_engine::TensorWrapper workspace, barrier; transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config // This call populates workspace and barrier tensors with the required config
const auto func = nvte_rmsnorm_fwd; const auto func = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
......
...@@ -114,7 +114,7 @@ class RMSNorm(torch.nn.Module): ...@@ -114,7 +114,7 @@ class RMSNorm(torch.nn.Module):
the RMSNorm formula changes to the RMSNorm formula changes to
.. math:: .. math::
y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma) y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma)
device : Union[torch.device, str], default = "cuda" device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the responsibility to ensure all parameters are moved to the GPU before running the
...@@ -155,7 +155,7 @@ class RMSNorm(torch.nn.Module): ...@@ -155,7 +155,7 @@ class RMSNorm(torch.nn.Module):
def reset_rms_norm_parameters(self) -> None: def reset_rms_norm_parameters(self) -> None:
"""Init RMSNorm params""" """Init RMSNorm params"""
warnings.warn( warnings.warn(
("This method will be deprecated in an upcoming release. " ("This method is deprecated and will be removed in an upcoming release. "
"Update your code to use RMSNorm.reset_parameters() instead."), "Update your code to use RMSNorm.reset_parameters() instead."),
DeprecationWarning, DeprecationWarning,
stacklevel=2 stacklevel=2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment