"...git@developer.sourcefind.cn:bw-bestperf/resnet-cbam.git" did not exist on "8de662233d4c740afc2e4408a51050d37c6fa8a0"
Unverified commit d68028c8, authored by Przemyslaw Tredak, committed via GitHub
Browse files

[common][pyTorch]Add zero_centered_gamma option to RMSNorm (#631)



* Add zero_centered_gamma option to RMSNorm
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Improving tests
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* More improvements to tests
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Tweaking the tolerances
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix LayerNormMLP test
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Update transformer_engine/common/rmsnorm/rmsnorm_api.cpp
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/common/rmsnorm/rmsnorm_api.cpp
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* docs suggestions
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Tweak tolerances with bfloat16
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 5b155fb3
......@@ -43,13 +43,17 @@ void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, con
// Reference (CPU) implementation of the RMSNorm forward output:
//   out[i][j] = scale * data[i][j] * rsigma[i] * g[j]
// where g[j] = gamma[j] (or 1 + gamma[j] when zero_centered_gamma is set,
// matching the nvte_rmsnorm1p_* kernels). Also records the pre-scale
// absolute maximum in *amax for FP8 amax bookkeeping.
//
// data:   [N, H] input
// gamma:  [H] scale parameter
// output: [N, H] result (cast to OutputType)
// rsigma: [N] precomputed reciprocal RMS per row
// amax:   out-param, max |unscaled output| over all elements
// scale:  FP8 scaling factor (1.0 for non-FP8 outputs)
template <typename InputType, typename OutputType>
void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output,
                        const float *rsigma, const size_t N, const size_t H, float *amax,
                        float scale, const bool zero_centered_gamma) {
  using compute_t = float;
  compute_t current_max = -1e100;
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < H; ++j) {
      compute_t current = static_cast<compute_t>(data[i * H + j]);
      compute_t g = static_cast<compute_t>(gamma[j]);
      if (zero_centered_gamma) {
        g += 1;
      }
      compute_t tmp = current * rsigma[i] * g;
      output[i * H + j] = static_cast<OutputType>(tmp * scale);
      // amax is tracked on the unscaled value, as the kernels do.
      current_max = fmaxf(current_max, fabsf(tmp));
    }
  }
  *amax = current_max;
}
......@@ -60,7 +64,7 @@ void compute_ref_output(const InputType *data, const InputType *gamma, OutputTyp
template <typename InputType, typename OutputType>
void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma,
const InputType *gamma, InputType *data_grad, InputType *gamma_grad,
const size_t N, const size_t H) {
const size_t N, const size_t H, const bool zero_centered_gamma) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
......@@ -70,7 +74,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
......@@ -82,7 +89,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = x * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y);
......@@ -97,7 +107,7 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
}
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType";
return;
......@@ -137,21 +147,23 @@ void performTest(const size_t N, const size_t H) {
// Forward kernel
float epsilon = 1e-5;
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
auto fwd_function = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
prop.multiProcessorCount, workspace.data(), barrier.data());
// Backward kernel
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
auto bwd_function = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd;
bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
barrier.data());
......@@ -162,10 +174,11 @@ void performTest(const size_t N, const size_t H) {
compute_ref_stats(input.cpu_dptr<InputType>(), ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(), gamma.cpu_dptr<WeightType>(), ref_output.get(),
rsigma.cpu_dptr<float>(), N, H, &ref_amax, ref_scale);
rsigma.cpu_dptr<float>(), N, H, &ref_amax, ref_scale,
zero_centered_gamma);
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
rsigma.cpu_dptr<float>(), gamma.cpu_dptr<WeightType>(), ref_dx.get(),
ref_dgamma.get(), N, H);
ref_dgamma.get(), N, H, zero_centered_gamma);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
......@@ -197,9 +210,10 @@ std::vector<std::pair<size_t, size_t>> test_cases = {
} // namespace
class RMSNormTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
class RMSNormTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool>> {};
TEST_P(RMSNormTestSuite, TestRMSNorm) {
using namespace transformer_engine;
......@@ -208,11 +222,11 @@ TEST_P(RMSNormTestSuite, TestRMSNorm) {
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
const bool zero_centered_gamma = std::get<3>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType, performTest<InputType, OutputType>(size.first, size.second);););
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma);););
}
INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite,
......@@ -220,11 +234,14 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite,
DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16,
DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases)),
::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
[](const testing::TestParamInfo<RMSNormTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
std::string name =
test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
std::to_string(std::get<2>(info.param).second) + "X" +
std::to_string(std::get<3>(info.param));
return name;
});
......@@ -21,7 +21,7 @@ from transformer_engine.pytorch.utils import (
)
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
MultiheadAttention, RMSNorm, TransformerLayer
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
......@@ -215,15 +215,41 @@ class TorchDotProductAttention(torch.nn.Module):
return context_layer
class TorchLayerNorm(nn.Module):
    """Pure-PyTorch LayerNorm reference used to validate te.LayerNorm.

    Computation is performed in FP32 regardless of the input dtype and the
    result is cast back to the input dtype, mirroring TE's internal behavior.

    Args:
        in_features: size of the normalized (last) dimension.
        eps: numerical-stability term added to the variance.
        zero_centered_gamma: if True, the stored weight is zero-centered and
            the effective scale is ``1 + weight`` (TE's parameterization).
    """

    def __init__(self, in_features: int,
                 eps: float,
                 zero_centered_gamma: bool):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma
        # Identity initialization: the effective scale is weight (or
        # 1 + weight when zero-centered), so start from zeros in the
        # zero-centered case and from ones otherwise.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning nn.Parameter attributes registers them automatically;
        # no explicit register_parameter calls are needed.
        self.weight = nn.Parameter(initial_value)
        self.bias = nn.Parameter(torch.zeros(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight if not self.zero_centered_gamma else 1 + self.weight
        w = w.to(torch.float32)
        b = self.bias.to(torch.float32)
        inp = x.to(torch.float32)
        out = torch.nn.functional.layer_norm(inp, (self.in_features,), weight=w,
                                             bias=b, eps=self.eps)
        return out.to(x.dtype)
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
    """Pure-PyTorch RMSNorm reference used to validate te.RMSNorm.

    Normalization is computed in FP32 and cast back to the input dtype.

    Args:
        in_features: size of the normalized (last) dimension.
        zero_centered_gamma: if True, the stored weight is zero-centered and
            the effective scale is ``1 + weight`` (TE's parameterization).
        eps: numerical-stability term added to the mean of squares.
    """

    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma
        # Identity initialization: zeros when zero-centered (1 + 0 == 1),
        # ones otherwise.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        self.weight = nn.Parameter(initial_value)

    def forward(self, x):
        # RMS_eps(x) = sqrt(mean(x^2) + eps), computed over the last dim.
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        rms_x2 = norm_x2 / self.in_features + self.eps
        r_rms_x = rms_x2 ** (-1. / 2)
        x_normed = x * r_rms_x

        w = self.weight.float()
        if self.zero_centered_gamma:
            w = 1 + w
        return (w * x_normed).to(x.dtype)
class TorchLayerNormLinear(nn.Module):
def __init__(self, in_features: int, out_features: int,
eps: float, bias: bool = True,
normalization: str = "LayerNorm"):
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False):
super().__init__()
if normalization == "LayerNorm":
self.layernorm = nn.LayerNorm(in_features, eps=eps)
self.layernorm = TorchLayerNorm(in_features, eps=eps,
zero_centered_gamma=zero_centered_gamma)
elif normalization == "RMSNorm":
self.layernorm = TorchRMSNorm(in_features, eps=eps)
self.layernorm = TorchRMSNorm(in_features, eps=eps,
zero_centered_gamma=zero_centered_gamma)
else:
raise RuntimeError("Unsupported normalization")
......@@ -299,9 +331,11 @@ class TorchLayerNormMLP(nn.Module):
normalization: str = "LayerNorm"):
super().__init__()
if normalization == "LayerNorm":
self.ln = nn.LayerNorm(hidden_size, eps=eps)
self.ln = TorchLayerNorm(hidden_size, eps=eps,
zero_centered_gamma=False)
elif normalization == "RMSNorm":
self.ln = TorchRMSNorm(hidden_size, eps=eps)
self.ln = TorchRMSNorm(hidden_size, eps=eps,
zero_centered_gamma=False)
else:
raise RuntimeError("Unsupported normalization")
if 'glu' in activation:
......@@ -893,13 +927,15 @@ def test_linear_accuracy(dtype, bs, model):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
def test_rmsnorm_accuracy(dtype, bs, model, eps):
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]
te_rmsnorm = (
RMSNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
......@@ -910,6 +946,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps):
TorchRMSNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
......@@ -924,17 +961,64 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps):
torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)
# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 1e-7)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 2e-2)
atol = {torch.float32 : 1e-7,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]
te_layernorm = (
LayerNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
.eval()
)
torch_layernorm = (
TorchLayerNorm(
config.hidden_size,
eps=eps,
zero_centered_gamma=zero_centered_gamma
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_layernorm.weight = Parameter(te_layernorm.weight.clone())
torch_layernorm.bias = Parameter(te_layernorm.bias.clone())
te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)
# Check output.
atol = {torch.float32 : 1e-7,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
config = model_configs[model]
te_ln_linear = (
......@@ -944,6 +1028,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
config.eps,
bias=True,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
......@@ -957,6 +1042,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
config.eps,
bias=True,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
......@@ -975,10 +1061,11 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization):
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
atol = {torch.float32 : 2e-4,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@pytest.mark.parametrize("dtype", param_types)
......
......@@ -625,7 +625,7 @@ def test_export_layernorm(
eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision
self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype,
zero_centered_gamma=False).eval().cuda()
zero_centered_gamma=zero_centered_gamma).eval().cuda()
def forward(self, inp):
ret = self.ln(inp)
......@@ -679,6 +679,7 @@ def test_export_layernorm(
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize(
"use_fp8, precision, atol", [
[False, torch.float32, 1e-7],
......@@ -695,6 +696,7 @@ def test_export_rmsnorm(
use_fp8: bool,
scale_factor: float,
precision: torch.dtype,
zero_centered_gamma: bool,
atol: float
):
fake_bf16_io = precision == "fake-torch.bfloat16"
......@@ -713,7 +715,8 @@ def test_export_rmsnorm(
super().__init__()
eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision
self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype).eval().cuda()
self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma).eval().cuda()
def forward(self, inp):
ret = self.ln(inp)
......@@ -739,7 +742,7 @@ def test_export_rmsnorm(
self.meta,
self.fp8_tensor,
self.fp8_type,
False)
zero_centered_gamma)
ret = cast_from_fp8(
ret,
......@@ -875,9 +878,6 @@ def test_export_layernorm_linear(
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
......@@ -946,8 +946,6 @@ def test_export_layernorm_mlp(
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
# Set dimensions (these are arbitrary).
in_features = 64
......
......@@ -361,9 +361,6 @@ def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -427,9 +424,6 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
......@@ -472,9 +466,6 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
......@@ -544,9 +535,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
......@@ -608,9 +596,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
......@@ -824,9 +809,6 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if normalization == "RMSNorm" and zero_centered_gamma:
pytest.skip("RMSNorm does not support zero_centered_gamma yet!")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
......
......@@ -18,6 +18,15 @@ extern "C" {
#endif
/*! \brief Compute RMSNorm on the input.
*
* The formula used:
* @f[
* y = \frac{x}{RMS_\varepsilon(x)}\gamma
* @f]
* where
* @f[
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* @f]
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
......@@ -44,7 +53,53 @@ void nvte_rmsnorm_fwd(const NVTETensor x,
NVTETensor workspace,
NVTETensor barrier);
/*! \brief Compute RMSNorm with zero-centered gamma on the input.
 *
 * The formula used:
 * @f[
 * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma)
 * @f]
 * where
 * @f[
 * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
 * @f]
 *
 * The gamma parameter is stored zero-centered: this call applies an
 * effective scale of (1 + gamma), unlike nvte_rmsnorm_fwd which uses
 * gamma directly.
 *
 * Calling this function with workspace and barrier set to empty tensor will not
 * perform the operation, but instead set the shape and type of the workspace
 * and barrier tensors to the required values.
 *
 * \param[in] x Input tensor of shape [N, H].
 * \param[in] gamma Gamma tensor of shape [H].
 * \param[in] epsilon Value added to denominator for numerical stability.
 * \param[in,out] z Output tensor of shape [N, H].
 * \param[out] rsigma Reciprocal of the root mean square of the input
 * calculated over the last dimension. Shape: [N].
 * \param[in] stream CUDA stream used for the operation.
 * \param[in] multiprocessorCount Number of SMs in the device.
 * \param[out] workspace Workspace tensor.
 * \param[out] barrier Barrier tensor.
 */
void nvte_rmsnorm1p_fwd(const NVTETensor x,
const NVTETensor gamma,
const float epsilon,
NVTETensor z,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier);
/*! \brief Compute backward of RMSNorm.
*
* This function computes the gradient of function:
* @f[
* y = \frac{x}{RMS_\varepsilon(x)}\gamma
* @f]
* where
* @f[
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* @f]
* with respect to \f$x\f$ and \f$gamma\f$.
*
* Calling this function with workspace, barrier, dgamma_part set
* to empty tensor will not perform the operation, but instead set the shape and type
......@@ -76,6 +131,48 @@ void nvte_rmsnorm_bwd(const NVTETensor dz,
NVTETensor barrier
);
/*! \brief Compute backward of RMSNorm with zero-centered gamma.
 *
 * This function computes the gradient of function:
 * @f[
 * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma)
 * @f]
 * where
 * @f[
 * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
 * @f]
 * with respect to \f$x\f$ and \f$gamma\f$.
 *
 * The gamma parameter is stored zero-centered (effective scale 1 + gamma);
 * this is the backward counterpart of nvte_rmsnorm1p_fwd.
 *
 * Calling this function with workspace, barrier, dgamma_part set
 * to empty tensor will not perform the operation, but instead set the shape and type
 * of these tensors to the required values.
 *
 * \param[in] dz Incoming gradient tensor of shape [N, H].
 * \param[in] x Forward input tensor of shape [N, H].
 * \param[in] rsigma Reciprocal of the root mean square of the input
 * calculated over the last dimension. Shape: [N].
 * \param[in] gamma Gamma tensor of shape [H].
 * \param[out] dx Output gradient of shape [N, H].
 * \param[out] dgamma Gradient for gamma tensor of shape [H].
 * \param[out] dgamma_part Storage for partial gamma gradient.
 * \param[in] stream CUDA stream used for the operation.
 * \param[in] multiprocessorCount Number of SMs in the device.
 * \param[out] workspace Workspace tensor.
 * \param[out] barrier Barrier tensor.
 */
void nvte_rmsnorm1p_bwd(const NVTETensor dz,
const NVTETensor x,
const NVTETensor rsigma,
const NVTETensor gamma,
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dgamma_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier
);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -106,7 +106,7 @@ inline size_t product(const std::vector<size_t> &shape) {
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount,
Tensor *workspace, Tensor *barrier) {
Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) {
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = z->data.dtype;
......@@ -149,6 +149,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
......@@ -199,7 +200,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma,
Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream,
const int multiprocessorCount, Tensor *workspace, Tensor *barrier) {
const int multiprocessorCount, Tensor *workspace, Tensor *barrier,
const bool zero_centered_gamma) {
using namespace transformer_engine;
auto itype = x.data.dtype;
......@@ -245,6 +247,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
params.dgamma = dgamma->data.dptr;
params.dbeta_part = nullptr;
params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
......@@ -295,20 +298,53 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
reinterpret_cast<Tensor *>(barrier));
reinterpret_cast<Tensor *>(barrier), false);
}
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTETensor dx, NVTETensor dgamma,
NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier));
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier),
false);
}
// C API entry point for RMSNorm forward with zero-centered gamma: the
// internal implementation applies (1 + gamma) as the scale.
void nvte_rmsnorm1p_fwd(const NVTETensor x,      // Nxhidden_size
                        const NVTETensor gamma,  // hidden_size
                        const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
                        const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
  NVTE_API_CALL(nvte_rmsnorm1p_fwd);
  using namespace transformer_engine;
  // Unwrap the opaque handles once, then forward to the implementation.
  const auto *x_t = reinterpret_cast<const Tensor *>(x);
  const auto *gamma_t = reinterpret_cast<const Tensor *>(gamma);
  auto *z_t = reinterpret_cast<Tensor *>(z);
  auto *rsigma_t = reinterpret_cast<Tensor *>(rsigma);
  auto *workspace_t = reinterpret_cast<Tensor *>(workspace);
  auto *barrier_t = reinterpret_cast<Tensor *>(barrier);
  rmsnorm_fwd(*x_t, *gamma_t, epsilon, z_t, rsigma_t, stream, multiprocessorCount,
              workspace_t, barrier_t, /*zero_centered_gamma=*/true);
}
void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor x, // Nxhidden_size
const NVTETensor rsigma, // N, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma,
NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm1p_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier),
true);
}
......@@ -97,7 +97,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
......@@ -356,7 +357,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_ij = x.data.elt[jt];
compute_t y_ij = rs * (x_ij);
compute_t g_ij = gamma[it].data.elt[jt];
const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
compute_t dz_ij = dz.data.elt[jt];
compute_t dy_ij = g_ij * dz_ij;
......
......@@ -106,6 +106,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (xf[it * NUM_ELTS + jt]);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
compute_t temp_output = g_ij * y_ij;
if (params.fp8_out) {
......@@ -236,6 +239,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (x[it].data.elt[jt]);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
z.data.elt[jt] = g_ij * y_ij;
}
......
......@@ -218,8 +218,6 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz,
const int sm_margin,
const bool zero_centered_gamma
) {
NVTE_CHECK(zero_centered_gamma == false,
"Zero-centered gamma is not supported yet for RMSNorm.");
auto dx = at::empty_like(x);
auto dgamma = at::empty_like(gamma);
transformer_engine::TensorWrapper workspace, barrier, dgamma_part;
......@@ -232,7 +230,7 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz,
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
// This call populates tensors with the required config.
const auto bwd_fun = nvte_rmsnorm_bwd;
const auto bwd_fun = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dgamma_part.data(),
at::cuda::getCurrentCUDAStream(),
......@@ -295,8 +293,6 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
NVTE_CHECK(zero_centered_gamma == false,
"Zero-centered gamma is not supported yet for RMSNorm.");
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
......@@ -313,7 +309,7 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = nvte_rmsnorm_fwd;
const auto func = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
......
......@@ -114,7 +114,7 @@ class RMSNorm(torch.nn.Module):
the RMSNorm formula changes to
.. math::
y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma)
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
......@@ -155,7 +155,7 @@ class RMSNorm(torch.nn.Module):
def reset_rms_norm_parameters(self) -> None:
"""Init RMSNorm params"""
warnings.warn(
("This method will be deprecated in an upcoming release. "
("This method is deprecated and will be removed in an upcoming release. "
"Update your code to use RMSNorm.reset_parameters() instead."),
DeprecationWarning,
stacklevel=2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment