"tests/vscode:/vscode.git/clone" did not exist on "c6a4a4e0809ca5ecd99b147450845ef1ac0cb8b8"
Commit 7324fe2b authored by Przemyslaw Tredak, committed by GitHub

Zero-centered gamma support in LayerNorm (LayerNorm1p) (#67)



* C++ implementation of LayerNorm1P
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Expose zero-centered gamma to PyTorch
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix ONNX export
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix ONNX export and tests
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* Fix lint
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix backward handling - C++ part
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix for backward - Python side
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix FP8 path
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Reenable the pylint check
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix the NVTX marker
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Change in the bwd kernel
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>
parent 2f643ada
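
The feature in one formula: instead of scaling the normalized input by gamma, LayerNorm1p scales it by (1 + gamma), so a zero-initialized gamma starts out as the identity scale. A minimal PyTorch sketch of the reference math (illustrative only; the helper name layernorm1p_ref is ours, not part of this commit):

import torch

def layernorm1p_ref(x, gamma, beta, eps=1e-5):
    # Zero-centered gamma: the effective scale is (1 + gamma),
    # so gamma == 0 behaves like the usual gamma == 1.
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mu) * torch.rsqrt(var + eps) * (1 + gamma) + beta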
@@ -49,14 +49,17 @@ template <typename InputType, typename OutputType>
 void compute_ref_output(const InputType *data, const InputType *gamma, const InputType *beta,
                         OutputType *output, const float *mu, const float *rsigma,
                         const size_t N, const size_t H,
-                        float *amax, float scale) {
+                        float *amax, float scale, const bool zero_centered_gamma) {
   using compute_t = float;
   compute_t current_max = -1e100;
   for (size_t i = 0 ; i < N; ++i) {
     for (size_t j = 0; j < H; ++j) {
       compute_t current = static_cast<compute_t>(data[i * H + j]);
-      compute_t tmp = (current - mu[i]) * rsigma[i] * static_cast<compute_t>(gamma[j]) +
-                      static_cast<compute_t>(beta[j]);
+      compute_t g = static_cast<compute_t>(gamma[j]);
+      if (zero_centered_gamma) {
+        g += 1;
+      }
+      compute_t tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
       output[i * H + j] = static_cast<OutputType>(tmp * scale);
       current_max = fmaxf(current_max, fabsf(tmp));
     }
@@ -70,7 +73,8 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
                           const InputType *gamma,
                           InputType *data_grad,
                           InputType *gamma_grad, InputType *beta_grad,
-                          const size_t N, const size_t H) {
+                          const size_t N, const size_t H,
+                          const bool zero_centered_gamma) {
   using compute_t = float;
   std::vector<compute_t> dgamma(H, 0.f);
   std::vector<compute_t> dbeta(H, 0.f);
@@ -81,7 +85,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
     for (size_t j = 0; j < H; ++j) {
       const compute_t x = static_cast<compute_t>(data[i * H + j]);
       const compute_t y = (x - mu[i]) * rsigma[i];
-      const compute_t g = static_cast<compute_t>(gamma[j]);
+      compute_t g = static_cast<compute_t>(gamma[j]);
+      if (zero_centered_gamma) {
+        g += 1;
+      }
       const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
       const compute_t dy = g * dz;
       dgamma[j] += y * dz;
@@ -96,7 +103,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
     for (size_t j = 0; j < H; ++j) {
       const compute_t x = static_cast<compute_t>(data[i * H + j]);
       const compute_t y = (x - mu[i]) * rsigma[i];
-      const compute_t g = static_cast<compute_t>(gamma[j]);
+      compute_t g = static_cast<compute_t>(gamma[j]);
+      if (zero_centered_gamma) {
+        g += 1;
+      }
       const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
       const compute_t dy = g * dz;
       const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
@@ -112,7 +122,7 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
 }

 template <typename InputType, typename OutputType>
-void performTest(const size_t N, const size_t H) {
+void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) {
   if (sizeof(InputType) < sizeof(OutputType)) {
     GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
     return;
@@ -158,17 +168,19 @@ void performTest(const size_t N, const size_t H) {
   // Forward kernel
   float epsilon = 1e-5;
-  nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
+  auto fwd_function = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
+  fwd_function(input.data(), gamma.data(), beta.data(), epsilon,
                z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
                workspace.data(), barrier.data());
   workspace = Tensor(workspace.shape(), workspace.dtype());
   barrier = Tensor(barrier.shape(), barrier.dtype());
-  nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
+  fwd_function(input.data(), gamma.data(), beta.data(), epsilon,
                z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
                workspace.data(), barrier.data());

   // Backward kernel
-  nvte_layernorm_bwd(dz.data(), input.data(),
+  auto bwd_function = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
+  bwd_function(dz.data(), input.data(),
                mu.data(), rsigma.data(), gamma.data(),
                dx.data(), dgamma.data(), dbeta.data(),
                dgamma_part.data(), dbeta_part.data(),
@@ -178,7 +190,7 @@ void performTest(const size_t N, const size_t H) {
   barrier = Tensor(barrier.shape(), barrier.dtype());
   dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
   dbeta_part = Tensor(dbeta_part.shape(), dbeta_part.dtype());
-  nvte_layernorm_bwd(dz.data(), input.data(),
+  bwd_function(dz.data(), input.data(),
                mu.data(), rsigma.data(), gamma.data(),
                dx.data(), dgamma.data(), dbeta.data(),
                dgamma_part.data(), dbeta_part.data(),
@@ -201,12 +213,13 @@ void performTest(const size_t N, const size_t H) {
                      rsigma.cpu_dptr<float>(),
                      N, H,
                      &ref_amax,
-                     ref_scale);
+                     ref_scale,
+                     zero_centered_gamma);
   compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
                        mu.cpu_dptr<float>(), rsigma.cpu_dptr<float>(),
                        gamma.cpu_dptr<WeightType>(),
                        ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
-                       N, H);
+                       N, H, zero_centered_gamma);

   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
@@ -248,7 +261,8 @@ std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
 class LNTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                transformer_engine::DType,
-                                                               std::pair<size_t, size_t>>> {};
+                                                               std::pair<size_t, size_t>,
+                                                               bool>> {};

 TEST_P(LNTestSuite, TestLN) {
   using namespace transformer_engine;
@@ -257,10 +271,11 @@ TEST_P(LNTestSuite, TestLN) {
   const DType input_type = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());
   const auto size = std::get<2>(GetParam());
+  const bool zero_centered_gamma = std::get<3>(GetParam());

   TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
     TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
-      performTest<InputType, OutputType>(size.first, size.second);
+      performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma);
     );
   );
 }
@@ -271,11 +286,13 @@ INSTANTIATE_TEST_SUITE_P(
   ::testing::Combine(
       ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
       ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
-      ::testing::ValuesIn(test_cases)),
+      ::testing::ValuesIn(test_cases),
+      ::testing::Values(false, true)),
   [](const testing::TestParamInfo<LNTestSuite::ParamType>& info) {
     std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                        test::typeName(std::get<1>(info.param)) + "X" +
                        std::to_string(std::get<2>(info.param).first) + "X" +
-                       std::to_string(std::get<2>(info.param).second);
+                       std::to_string(std::get<2>(info.param).second) + "X" +
+                       std::to_string(std::get<3>(info.param));
     return name;
   });
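
A note on the reference backward above: gamma_grad accumulates y * dz in both modes, because d(1 + gamma)/dgamma = 1; only dy = g * dz sees the shifted gamma. A quick autograd check of that identity (a standalone sketch, not part of this commit's test suite):

import torch

x = torch.randn(4, 16, dtype=torch.float64)
gamma = torch.randn(16, dtype=torch.float64, requires_grad=True)
beta = torch.randn(16, dtype=torch.float64, requires_grad=True)

# LayerNorm1p expressed as standard layer_norm with weight = 1 + gamma.
y = torch.nn.functional.layer_norm(x, (16,), weight=1 + gamma, bias=beta, eps=1e-5)
y.backward(torch.ones_like(y))

# A standard LayerNorm evaluated at weight = 1 + gamma yields the same weight gradient.
w = (1 + gamma).detach().requires_grad_()
y2 = torch.nn.functional.layer_norm(x, (16,), weight=w, bias=beta, eps=1e-5)
y2.backward(torch.ones_like(y2))
assert torch.allclose(gamma.grad, w.grad)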
@@ -434,10 +434,12 @@ def test_export_gemm(
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.parametrize("scale_factor", [448, 112])
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16])
+@pytest.mark.parametrize("zero_centered_gamma", [False, True])
 def test_export_layernorm(
     use_fp8: bool,
     scale_factor: float,
-    precision: torch.dtype
+    precision: torch.dtype,
+    zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
     if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
@@ -459,7 +461,8 @@ def test_export_layernorm(
                 inp,
                 self.weight,
                 self.bias,
-                self.eps)
+                self.eps,
+                zero_centered_gamma)
             return ret

     class TestFP8_Layernorm(nn.Module):
@@ -482,7 +485,8 @@ def test_export_layernorm(
                 self.eps,
                 self.meta,
                 self.fp8_tensor,
-                self.fp8_type)
+                self.fp8_type,
+                zero_centered_gamma)
             ret = cast_from_fp8(
                 ret,
@@ -500,7 +504,7 @@ def test_export_layernorm(
     do_export(model, inp, fname, use_fp8=use_fp8)
     if precision not in (torch.bfloat16, ):
         # TODO: FP32 has a small threshold (1e-5)
-        validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)
+        validate_result(fname, inp, model, atol=4e-3, is_fp8=use_fp8)

 @skip_FP8
@@ -646,13 +650,15 @@ def test_export_linear(
     (torch.float16, True),
     (torch.float16, False),
 ])
+@pytest.mark.parametrize("zero_centered_gamma", [False, True])
 def test_export_layernorm_linear(
     scale_factor: float,
     use_fp8: bool,
     use_bias: bool,
     return_bias: bool,
     return_layernorm_output: bool,
-    precision: torch.dtype
+    precision: torch.dtype,
+    zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
     if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
@@ -676,6 +682,7 @@ def test_export_layernorm_linear(
         return_bias=return_bias,
         return_layernorm_output=return_layernorm_output,
         params_dtype=precision,
+        zero_centered_gamma=zero_centered_gamma,
     ).to(device='cuda')
     if use_fp8:
         set_layer_scale(model, scale_factor)
@@ -698,13 +705,15 @@ def test_export_layernorm_mlp(
     (torch.float16, True),
     (torch.float16, False),
 ])
+@pytest.mark.parametrize("zero_centered_gamma", [False, True])
 def test_export_layernorm_mlp(
     scale_factor: float,
     use_fp8: bool,
     use_bias: bool,
     return_bias: bool,
     return_layernorm_output: bool,
-    precision: torch.dtype
+    precision: torch.dtype,
+    zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
     if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
@@ -729,6 +738,7 @@ def test_export_layernorm_mlp(
         return_bias=return_bias,
         return_layernorm_output=return_layernorm_output,
         params_dtype=precision,
+        zero_centered_gamma=zero_centered_gamma,
     ).to(device='cuda')
     if use_fp8:
         set_layer_scale(model, scale_factor)
@@ -902,6 +912,7 @@ def test_export_multihead_attention(
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16])
 @pytest.mark.parametrize("fuse_qkv_params", [False, True])
 @pytest.mark.parametrize("apply_query_key_layer_scaling", [True, False])
+@pytest.mark.parametrize("zero_centered_gamma", [False, True])
 def test_export_transformer_layer(
     use_fp8: bool,
     use_mask: bool,
@@ -909,7 +920,8 @@ def test_export_transformer_layer(
     output_layernorm: bool,
     precision: torch.dtype,
     fuse_qkv_params: bool,
-    apply_query_key_layer_scaling: bool
+    apply_query_key_layer_scaling: bool,
+    zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
     if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
@@ -947,7 +959,8 @@ def test_export_transformer_layer(
         output_layernorm=output_layernorm,
         params_dtype=precision,
         fuse_qkv_params=fuse_qkv_params,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling).to(device='cuda')
+        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+        zero_centered_gamma=zero_centered_gamma).to(device='cuda')
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
         validate_result(fname, inp, model, atol=1e-3)
......
@@ -37,7 +37,8 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
 batch_sizes = [1, 2]

-skip_wgrad = [True, False]
+all_boolean = [True, False]

 def _disable_wgrads(block):
@@ -151,8 +152,9 @@ def _test_sanity_common(block, bs, dtype, config, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad, zero_centered_gamma):
     config = model_configs[model]

     sigma = 0.023
@@ -164,6 +166,7 @@ def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
             config.hidden_size * 3,
             eps=config.eps,
             init_method=init_method,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -174,7 +177,7 @@ def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_linear(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
@@ -195,8 +198,9 @@ def test_sanity_linear(dtype, bs, model, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad, zero_centered_gamma):
     config = model_configs[model]

     sigma = 0.023
@@ -210,6 +214,7 @@ def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
             eps=config.eps,
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -220,8 +225,9 @@ def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_gpt(dtype, bs, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_gpt(dtype, bs, model, skip_wgrad, zero_centered_gamma):
     config = model_configs[model]

     sigma = 0.023
@@ -241,6 +247,7 @@ def test_sanity_gpt(dtype, bs, model, skip_wgrad):
             kv_channels=config.embed,
             apply_residual_connection_post_layernorm=False,
             output_layernorm=False,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -252,8 +259,9 @@ def test_sanity_gpt(dtype, bs, model, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_bert(dtype, bs, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_bert(dtype, bs, model, skip_wgrad, zero_centered_gamma):
     config = model_configs[model]

     sigma = 0.023
@@ -273,6 +281,7 @@ def test_sanity_bert(dtype, bs, model, skip_wgrad):
             kv_channels=config.embed,
             apply_residual_connection_post_layernorm=True,
             output_layernorm=True,
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -284,8 +293,9 @@ def test_sanity_bert(dtype, bs, model, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
-def test_sanity_T5(dtype, bs, model, skip_wgrad):
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
+@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
+def test_sanity_T5(dtype, bs, model, skip_wgrad, zero_centered_gamma):
     config = model_configs[model]

     sigma = 0.023
@@ -306,6 +316,7 @@ def test_sanity_T5(dtype, bs, model, skip_wgrad):
             apply_residual_connection_post_layernorm=False,
             output_layernorm=False,
             layer_type="decoder",
+            zero_centered_gamma=zero_centered_gamma,
         )
         .to(dtype=dtype)
         .cuda()
@@ -317,7 +328,7 @@ def test_sanity_T5(dtype, bs, model, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
@@ -347,7 +358,7 @@ def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
@@ -380,7 +391,7 @@ def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+@pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_fused_qkv_params(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
......
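The sanity tests above drive the new flag through every layer type; from user code the switch is just a constructor argument. A minimal usage sketch (sizes illustrative, assuming the usual transformer_engine.pytorch import path and mirroring the test construction above):

import torch
import transformer_engine.pytorch as te

# zero_centered_gamma=True selects the LayerNorm1p kernels underneath.
layer = te.LayerNormLinear(
    1024, 3072,
    eps=1e-5,
    zero_centered_gamma=True,
).to(dtype=torch.float16).cuda()
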
@@ -18,6 +18,11 @@ extern "C" {
 #endif

 /*! \brief Compute LayerNorm on the input.
+ *
+ * The formula used:
+ * @f[
+ * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta
+ * @f]
  *
  * Calling this function with workspace and barrier set to empty tensor will not
  * perform the operation, but instead set the shape and type of the workspace
@@ -49,8 +54,51 @@ void nvte_layernorm_fwd(const NVTETensor x,
                         NVTETensor workspace,
                         NVTETensor barrier);

+/*! \brief Compute LayerNorm with zero-centered gamma on the input.
+ *
+ * The formula used:
+ * @f[
+ * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta
+ * @f]
+ *
+ * Calling this function with workspace and barrier set to empty tensor will not
+ * perform the operation, but instead set the shape and type of the workspace
+ * and barrier tensors to the required values.
+ *
+ * \param[in]     x                    Input tensor of shape [N, H].
+ * \param[in]     gamma                Gamma tensor of shape [H].
+ * \param[in]     beta                 Beta tensor of shape [H].
+ * \param[in]     epsilon              Value added to denominator for numerical stability.
+ * \param[in,out] z                    Output tensor of shape [N, H].
+ * \param[out]    mu                   Mean of the input calculated over the last dimension.
+ *                                     Shape: [N].
+ * \param[out]    rsigma               Inverse of the variance of the input calculated over
+ *                                     the last dimension. Shape: [N].
+ * \param[in]     stream               CUDA stream used for the operation.
+ * \param[in]     multiprocessorCount  Number of SMs in the device.
+ * \param[out]    workspace            Workspace tensor.
+ * \param[out]    barrier              Barrier tensor.
+ */
+void nvte_layernorm1p_fwd(const NVTETensor x,
+                          const NVTETensor gamma,
+                          const NVTETensor beta,
+                          const float epsilon,
+                          NVTETensor z,
+                          NVTETensor mu,
+                          NVTETensor rsigma,
+                          cudaStream_t stream,
+                          const int multiprocessorCount,
+                          NVTETensor workspace,
+                          NVTETensor barrier);
+
 /*! \brief Compute backward of LayerNorm.
+ *
+ * This function computes the gradient of function:
+ * @f[
+ * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta
+ * @f]
+ * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
  *
  * Calling this function with workspace, barrier, dgamma_part and dbeta_part set
  * to empty tensor will not perform the operation, but instead set the shape and type
@@ -88,6 +136,49 @@ void nvte_layernorm_bwd(const NVTETensor dz,      // BxSxhidden_size
                         NVTETensor workspace,
                         NVTETensor barrier);

+/*! \brief Compute backward of LayerNorm with zero-centered gamma.
+ *
+ * This function computes the gradient of function:
+ * @f[
+ * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta
+ * @f]
+ * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
+ *
+ * Calling this function with workspace, barrier, dgamma_part and dbeta_part set
+ * to empty tensor will not perform the operation, but instead set the shape and type
+ * of these tensors to the required values.
+ *
+ * \param[in]     dz                   Incoming gradient tensor of shape [N, H].
+ * \param[in]     x                    Forward input tensor of shape [N, H].
+ * \param[in]     mu                   Mean of the input calculated over the last dimension.
+ *                                     Shape: [N].
+ * \param[in]     rsigma               Inverse of the variance of the input calculated over
+ *                                     the last dimension. Shape: [N].
+ * \param[in]     gamma                Gamma tensor of shape [H].
+ * \param[out]    dx                   Output gradient of shape [N, H].
+ * \param[out]    dgamma               Gradient for gamma tensor of shape [H].
+ * \param[out]    dbeta                Gradient for beta tensor of shape [H].
+ * \param[out]    dgamma_part          Storage for partial gamma gradient.
+ * \param[out]    dbeta_part           Storage for partial bias gradient.
+ * \param[in]     stream               CUDA stream used for the operation.
+ * \param[in]     multiprocessorCount  Number of SMs in the device.
+ * \param[out]    workspace            Workspace tensor.
+ * \param[out]    barrier              Barrier tensor.
+ */
+void nvte_layernorm1p_bwd(const NVTETensor dz,      // BxSxhidden_size
+                          const NVTETensor x,       // BxSxhidden_size
+                          const NVTETensor mu,      // BxS, FP32!
+                          const NVTETensor rsigma,  // BxS, FP32!
+                          const NVTETensor gamma,   // hidden_size
+                          NVTETensor dx,
+                          NVTETensor dgamma,
+                          NVTETensor dbeta,
+                          NVTETensor dgamma_part,
+                          NVTETensor dbeta_part,
+                          cudaStream_t stream,
+                          const int multiprocessorCount,
+                          NVTETensor workspace,
+                          NVTETensor barrier);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
......
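The two doc formulas above differ only in the scale term, which gives a handy invariant for testing: layernorm1p(x, gamma, beta) must equal a standard LayerNorm evaluated with weight gamma + 1. A small numerical check of that invariant in plain PyTorch (a sketch against the documented math, not a call into the NVTE C API):

import torch

x = torch.randn(8, 32)
gamma = torch.randn(32)
beta = torch.randn(32)

ref = torch.nn.functional.layer_norm(x, (32,), weight=1 + gamma, bias=beta, eps=1e-5)

mu = x.mean(-1, keepdim=True)
rsigma = torch.rsqrt(x.var(-1, unbiased=False, keepdim=True) + 1e-5)
out1p = (x - mu) * rsigma * (1 + gamma) + beta  # the documented LayerNorm1p formula

assert torch.allclose(ref, out1p, atol=1e-5)
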
@@ -44,7 +44,8 @@ struct ParamsBase {
         , rs(nullptr)
         , gamma(nullptr)
         , workspace(nullptr)
-        , barrier(nullptr) {}
+        , barrier(nullptr)
+        , zero_centered_gamma(false) {}

     // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
@@ -67,6 +68,9 @@ struct ParamsBase {
     // Multi-CTA sync barriers in gmem.
     int *barrier;

+    // Whether gamma is centered around 0
+    bool zero_centered_gamma;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
......
@@ -151,7 +151,8 @@ void layernorm_fwd(const Tensor& x,        // BxSxhidden_size
                    cudaStream_t stream,
                    const int multiprocessorCount,
                    Tensor* workspace,
-                   Tensor* barrier) {
+                   Tensor* barrier,
+                   const bool zero_centered_gamma) {
   const auto itype = x.data.dtype;
   const auto wtype = gamma.data.dtype;
   const auto otype = z->data.dtype;
@@ -208,6 +209,7 @@ void layernorm_fwd(const Tensor& x,        // BxSxhidden_size
   params.amax = z->amax.dptr;
   params.scale = z->scale.dptr;
   params.fp8_out = fp8_out;
+  params.zero_centered_gamma = zero_centered_gamma;

   // Query the kernel-specific launch parameters.
   launcher(launch_params, true);
@@ -261,7 +263,8 @@ void layernorm_bwd(const Tensor& dz,
                    cudaStream_t stream,
                    const int multiprocessorCount,
                    Tensor* workspace,
-                   Tensor* barrier
+                   Tensor* barrier,
+                   const bool zero_centered_gamma
 ) {
   using namespace transformer_engine;
@@ -325,6 +328,7 @@ void layernorm_bwd(const Tensor& dz,
   params.dgamma = dgamma->data.dptr;
   params.dbeta_part = dbeta_part->data.dptr;
   params.dgamma_part = dgamma_part->data.dptr;
+  params.zero_centered_gamma = zero_centered_gamma;

   // Query the kernel-specific launch parameters.
   launcher(launch_params, true);
@@ -386,7 +390,8 @@ void nvte_layernorm_fwd(const NVTETensor x,       // BxSxhidden_size
                 stream,
                 multiprocessorCount,
                 reinterpret_cast<Tensor*>(workspace),
-                reinterpret_cast<Tensor*>(barrier));
+                reinterpret_cast<Tensor*>(barrier),
+                false);
 }

 void nvte_layernorm_bwd(const NVTETensor dz,      // BxSxhidden_size
@@ -418,5 +423,66 @@ void nvte_layernorm_bwd(const NVTETensor dz,      // BxSxhidden_size
                 stream,
                 multiprocessorCount,
                 reinterpret_cast<Tensor*>(workspace),
-                reinterpret_cast<Tensor*>(barrier));
+                reinterpret_cast<Tensor*>(barrier),
+                false);
+}
+
+void nvte_layernorm1p_fwd(const NVTETensor x,       // BxSxhidden_size
+                          const NVTETensor gamma,   // hidden_size
+                          const NVTETensor beta,    // hidden_size
+                          const float epsilon,
+                          NVTETensor z,
+                          NVTETensor mu,
+                          NVTETensor rsigma,
+                          cudaStream_t stream,
+                          const int multiprocessorCount,
+                          NVTETensor workspace,
+                          NVTETensor barrier) {
+  NVTE_API_CALL(nvte_layernorm1p_fwd);
+  using namespace transformer_engine;
+  layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
+                *reinterpret_cast<const Tensor*>(gamma),
+                *reinterpret_cast<const Tensor*>(beta),
+                epsilon,
+                reinterpret_cast<Tensor*>(z),
+                reinterpret_cast<Tensor*>(mu),
+                reinterpret_cast<Tensor*>(rsigma),
+                stream,
+                multiprocessorCount,
+                reinterpret_cast<Tensor*>(workspace),
+                reinterpret_cast<Tensor*>(barrier),
+                true);
+}
+
+void nvte_layernorm1p_bwd(const NVTETensor dz,      // BxSxhidden_size
+                          const NVTETensor x,       // BxSxhidden_size
+                          const NVTETensor mu,      // BxS, FP32!
+                          const NVTETensor rsigma,  // BxS, FP32!
+                          const NVTETensor gamma,   // hidden_size
+                          NVTETensor dx,
+                          NVTETensor dgamma,
+                          NVTETensor dbeta,
+                          NVTETensor dgamma_part,
+                          NVTETensor dbeta_part,
+                          cudaStream_t stream,
+                          const int multiprocessorCount,
+                          NVTETensor workspace,
+                          NVTETensor barrier) {
+  NVTE_API_CALL(nvte_layernorm1p_bwd);
+  using namespace transformer_engine;
+  layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
+                *reinterpret_cast<const Tensor*>(x),
+                *reinterpret_cast<const Tensor*>(mu),
+                *reinterpret_cast<const Tensor*>(rsigma),
+                *reinterpret_cast<const Tensor*>(gamma),
+                reinterpret_cast<Tensor*>(dx),
+                reinterpret_cast<Tensor*>(dgamma),
+                reinterpret_cast<Tensor*>(dbeta),
+                reinterpret_cast<Tensor*>(dgamma_part),
+                reinterpret_cast<Tensor*>(dbeta_part),
+                stream,
+                multiprocessorCount,
+                reinterpret_cast<Tensor*>(workspace),
+                reinterpret_cast<Tensor*>(barrier),
+                true);
 }
@@ -100,9 +100,10 @@ void ln_bwd_tuned_kernel(layer_norm::BwdParams params) {
     for ( int it = 0; it < LDGS; it++ ) {
 #pragma unroll
         for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
-            compute_t x_tmp = x[it].data.elt[jt];
-            compute_t y_tmp = rs_r * (x_tmp - mu_r);
-            compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
+            const compute_t x_tmp = x[it].data.elt[jt];
+            const compute_t y_tmp = rs_r * (x_tmp - mu_r);
+            const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
+            compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
             dy_tmp *= compute_t(dz[it].data.elt[jt]);
             compute_t dz_tmp = dz[it].data.elt[jt];
@@ -411,11 +412,12 @@ void ln_bwd_general_kernel(layer_norm::BwdParams params) {
         dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
 #pragma unroll
         for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
-            compute_t x_ij = x.data.elt[jt];
-            compute_t y_ij = rs * (x_ij - mu);
-            compute_t g_ij = gamma[it].data.elt[jt];
-            compute_t dz_ij = dz.data.elt[jt];
-            compute_t dy_ij = g_ij * dz_ij;
+            const compute_t x_ij = x.data.elt[jt];
+            const compute_t y_ij = rs * (x_ij - mu);
+            const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
+            const compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
+            const compute_t dz_ij = dz.data.elt[jt];
+            const compute_t dy_ij = g_ij * dz_ij;
             y[it].data.elt[jt] = y_ij;
             dy[it].data.elt[jt] = dy_ij;
......
@@ -61,7 +61,7 @@ void ln_fwd_tuned_kernel(FwdParams params) {
     Wvec beta[LDGS];
     index_t idx = c;
 #pragma unroll
-    for ( int it = 0; it < LDGS; it++ ) {
+    for ( int it = 0; it < LDGS; ++it ) {
         gamma[it].load_from(params.gamma, idx);
         beta[it].load_from(params.beta, idx);
         idx += VEC_COLS_PER_LDG;
@@ -113,6 +113,9 @@ void ln_fwd_tuned_kernel(FwdParams params) {
         for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
             compute_t y_ij = rs * (xf[it * NUM_ELTS + jt] - mu);
             compute_t g_ij = gamma[it].data.elt[jt];
+            if (params.zero_centered_gamma) {
+                g_ij += 1;
+            }
             compute_t b_ij = beta[it].data.elt[jt];
             compute_t temp_output = g_ij * y_ij + b_ij;
@@ -187,7 +190,7 @@ void ln_fwd_general_kernel(FwdParams params) {
 #pragma unroll
     for ( int it = 0, col = gidn * NUM_ELTS;
           it < LDGS && col < params.cols;
-          it++, col += gdimn * NUM_ELTS ) {
+          ++it, col += gdimn * NUM_ELTS ) {
         Wvec gamma_in, beta_in;
         gamma_in.load_from_elts(params.gamma, col, params.cols - col);
         beta_in.load_from_elts(params.beta, col, params.cols - col);
@@ -269,6 +272,9 @@ void ln_fwd_general_kernel(FwdParams params) {
         for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
             compute_t y_ij = rs * (x[it].data.elt[jt] - mu);
             compute_t g_ij = gamma[it].data.elt[jt];
+            if (params.zero_centered_gamma) {
+                g_ij += 1;
+            }
             compute_t b_ij = beta[it].data.elt[jt];
             z.data.elt[jt] = g_ij * y_ij + b_ij;
         }
......
@@ -255,6 +255,7 @@ def layernorm_fwd_fp8(
     fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
     otype: tex.DType,
     sm_margin: int,
+    zero_centered_gamma: bool
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """LayerNorm with FP8 output"""
     return tex.layernorm_fwd_fp8(
@@ -267,6 +268,7 @@ def layernorm_fwd_fp8(
         fp8_meta_tensor.scale_inv[fp8_tensor],
         otype,
         sm_margin,
+        zero_centered_gamma
     )
@@ -278,6 +280,7 @@ def layernorm_fwd_fp8_inf(
     fp8_meta_tensor: tex.FP8TensorMeta,
     fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
     otype: tex.DType,
+    zero_centered_gamma,
 ) -> torch.Tensor:
     """LayerNorm with FP8 output.
@@ -293,7 +296,8 @@ def layernorm_fwd_fp8_inf(
         fp8_meta_tensor.amax_history,
         fp8_meta_tensor.scale_inv,
         fp8_tensor,
-        otype)
+        otype,
+        zero_centered_gamma)
     return ret
@@ -302,6 +306,7 @@ def layernorm_fwd_inf(
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
+    zero_centered_gamma: bool,
 ) -> torch.Tensor:
     """LayerNorm with FP8 output"""
     return torch.ops.tex_ts.layernorm_fwd_inf_ts(
@@ -309,6 +314,7 @@ def layernorm_fwd_inf(
         weight,
         bias,
         eps,
+        zero_centered_gamma,
     )
......
@@ -378,7 +378,8 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                       const at::Tensor &mu,
                                       const at::Tensor &rsigma,
                                       const at::Tensor &gamma,
-                                      const int sm_margin
+                                      const int sm_margin,
+                                      const bool zero_centered_gamma
 ) {
     auto dx = at::empty_like(x);
     auto dgamma = at::empty_like(gamma);
@@ -395,7 +396,8 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
     auto dbeta_cu = makeTransformerEngineTensor(dbeta);

     // This call populates tensors with the required config.
-    nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
+    const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
+    bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
             dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
             dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
             at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
@@ -420,7 +422,7 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                        dbeta_part.dtype());

     // Actual call to bwd kernel.
-    nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
+    bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
             dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
             dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
             at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
@@ -438,7 +440,8 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                           at::Tensor amax,
                                           at::Tensor scale_inv,
                                           transformer_engine::DType otype,
-                                          const int sm_margin
+                                          const int sm_margin,
+                                          const bool zero_centered_gamma
 ) {
     using namespace transformer_engine;
@@ -461,7 +464,8 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
     transformer_engine::TensorWrapper workspace, barrier;

     // This call populates workspace and barrier tensors with the required config
-    nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
+    const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
+    func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
          mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
          at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
          workspace.data(), barrier.data());
@@ -480,7 +484,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                       barrier.dtype());

     // Actual call to fwd kernel
-    nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
+    func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
          mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
          at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
          workspace.data(), barrier.data());
@@ -496,12 +500,13 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
                                  at::Tensor scale,
                                  at::Tensor amax,
                                  at::Tensor scale_inv,
-                                 transformer_engine::DType otype
+                                 transformer_engine::DType otype,
+                                 const bool zero_centered_gamma
 ) {
     // This is a specialized version of layernorm_fwd_fp8, optimized for inference,
     // which only returns the normalized output.
     std::vector<at::Tensor> out = layernorm_fwd_fp8(
-        input, weight, bias, eps, scale, amax, scale_inv, otype, 0);
+        input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma);
     return out[0];
 }
@@ -510,7 +515,8 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       const at::Tensor &weight,
                                       const at::Tensor &bias,
                                       float eps,
-                                      const int sm_margin
+                                      const int sm_margin,
+                                      const bool zero_centered_gamma
 ) {
     using namespace transformer_engine;
@@ -531,7 +537,8 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
     transformer_engine::TensorWrapper workspace, barrier;

     // This call populates workspace and barrier tensors with the required config
-    nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
+    const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
+    func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
          mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
          at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
          workspace.data(), barrier.data());
@@ -550,7 +557,7 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       barrier.dtype());

     // Actual call to fwd kernel
-    nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
+    func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
          mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
          at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
          workspace.data(), barrier.data());
@@ -562,11 +569,12 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
 at::Tensor layernorm_fwd_inf(const at::Tensor &input,
                              const at::Tensor &weight,
                              const at::Tensor &bias,
-                             float eps
+                             float eps,
+                             const bool zero_centered_gamma
 ) {
     // This is a specialized version of layernorm_fwd, optimized for inference,
     // which only returns the normalized output.
-    std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0);
+    std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma);
     return out[0];
 }
......
@@ -85,7 +85,8 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                       const at::Tensor &mu,
                                       const at::Tensor &rsigma,
                                       const at::Tensor &gamma,
-                                      const int sm_margin
+                                      const int sm_margin,
+                                      const bool zero_centered_gamma
 );
@@ -97,7 +98,8 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                           at::Tensor amax,
                                           at::Tensor scale_inv,
                                           transformer_engine::DType otype,
-                                          const int sm_margin
+                                          const int sm_margin,
+                                          const bool zero_centered_gamma
 );

 at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
@@ -107,20 +109,23 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
                                  at::Tensor scale,
                                  at::Tensor amax,
                                  at::Tensor scale_inv,
-                                 transformer_engine::DType otype
+                                 transformer_engine::DType otype,
+                                 const bool zero_centered_gamma
 );

 std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       const at::Tensor &weight,
                                       const at::Tensor &bias,
                                       float eps,
-                                      const int sm_margin
+                                      const int sm_margin,
+                                      const bool zero_centered_gamma
 );

 at::Tensor layernorm_fwd_inf(const at::Tensor &input,
                              const at::Tensor &weight,
                              const at::Tensor &bias,
-                             float eps
+                             float eps,
+                             const bool zero_centered_gamma
 );

 at::Tensor cast_to_fp8(const at::Tensor &input,
......
@@ -133,7 +133,8 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                     at::Tensor amax,
                                     at::Tensor scale_inv,
                                     int64_t fp8_tensor,
-                                    int64_t otype) {
+                                    int64_t otype,
+                                    const bool zero_centered_gamma) {
   transformer_engine::DType otype_arg = reverse_map_dtype(otype);
   float eps_float = static_cast<float>(eps);
@@ -144,7 +145,8 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                              scale,
                                              amax,
                                              scale_inv,
-                                             otype_arg);
+                                             otype_arg,
+                                             zero_centered_gamma);

   return output;
 }
@@ -152,13 +154,15 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
 at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
                                 const at::Tensor &weight,
                                 const at::Tensor &bias,
-                                double eps) {
+                                double eps,
+                                const bool zero_centered_gamma) {
   float eps_float = static_cast<float>(eps);

   at::Tensor output = layernorm_fwd_inf(input,
                                         weight,
                                         bias,
-                                        eps_float);
+                                        eps_float,
+                                        zero_centered_gamma);

   return output;
 }
......
...@@ -668,6 +668,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -668,6 +668,7 @@ class _LayerNormLinear(torch.autograd.Function):
is_grad_enabled: bool, is_grad_enabled: bool,
fwd_ln_sm_margin: int, fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -698,6 +699,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -698,6 +699,7 @@ class _LayerNormLinear(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma,
) )
else: else:
mu = rsigma = None mu = rsigma = None
...@@ -709,15 +711,16 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -709,15 +711,16 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
zero_centered_gamma,
) )
else: else:
if is_grad_enabled: if is_grad_enabled:
ln_out_return, mu, rsigma = tex.layernorm_fwd( ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
) )
else: else:
ln_out_return, mu, rsigma = layernorm_fwd_inf( ln_out_return, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None ), None, None
ln_out = cast_to_fp8( ln_out = cast_to_fp8(
...@@ -729,11 +732,11 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -729,11 +732,11 @@ class _LayerNormLinear(torch.autograd.Function):
else: else:
if is_grad_enabled: if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd( ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
) )
else: else:
ln_out, mu, rsigma = layernorm_fwd_inf( ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None ), None, None
ln_out_return = ln_out ln_out_return = ln_out
...@@ -831,6 +834,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -831,6 +834,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output = return_layernorm_output
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and sequence_parallel: if parallel_mode == "row" and sequence_parallel:
...@@ -997,7 +1001,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -997,7 +1001,8 @@ class _LayerNormLinear(torch.autograd.Function):
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd( dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
if not ctx.use_bias: if not ctx.use_bias:
...@@ -1027,11 +1032,12 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -1027,11 +1032,12 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
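A note on the lengthened `None` tuple above: `autograd.Function.backward` must return exactly one value per `forward` input, and the new `zero_centered_gamma` argument is a non-differentiable bool, so it gets a `None` gradient slot. A minimal sketch of the pattern (illustrative only, not the repo's code):

```python
import torch

class _Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, flag: bool) -> torch.Tensor:
        ctx.flag = flag
        return x * 2 if flag else x

    @staticmethod
    def backward(ctx, grad_out):
        # one return per forward input; the non-tensor flag gets None
        dx = grad_out * 2 if ctx.flag else grad_out
        return dx, None

y = _Scale.apply(torch.randn(3, requires_grad=True), True)
```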
class LayerNormLinear(TransformerEngineBaseModule): class LayerNormLinear(TransformerEngineBaseModule):
""" r"""
Applies layer normalization followed by linear transformation to the incoming data. Applies layer normalization followed by linear transformation to the incoming data.
Parameters Parameters
...@@ -1057,6 +1063,13 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1057,6 +1063,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each, module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters. and the strings contained are the names of the split parameters.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -1112,6 +1125,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1112,6 +1125,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return_layernorm_output: bool = False, return_layernorm_output: bool = False,
skip_weight_param_allocation: bool = False, skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None, parameters_split: Optional[Tuple[str, ...]] = None,
zero_centered_gamma: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.in_features = in_features self.in_features = in_features
...@@ -1121,6 +1135,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1121,6 +1135,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_bias = return_bias self.return_bias = return_bias
self.return_layernorm_output = return_layernorm_output self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
...@@ -1256,7 +1271,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1256,7 +1271,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
if not self.zero_centered_gamma:
init.ones_(self.layer_norm_weight) init.ones_(self.layer_norm_weight)
else:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias) init.zeros_(self.layer_norm_bias)
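Since a freshly reset module has gamma = 1 in the standard convention and gamma = 0 in the zero-centered one, both conventions compute the same function at initialization. A quick check, reusing the hypothetical `ln_ref` sketch above:

```python
import torch

x, beta, eps = torch.randn(4, 8), torch.zeros(8), 1e-5
y_std = ln_ref(x, torch.ones(8), beta, eps, zero_centered_gamma=False)
y_1p = ln_ref(x, torch.zeros(8), beta, eps, zero_centered_gamma=True)
assert torch.allclose(y_std, y_1p)
```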
def forward( def forward(
...@@ -1339,6 +1357,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1339,6 +1357,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
...@@ -1997,6 +2016,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1997,6 +2016,7 @@ class _LayerNormMLP(torch.autograd.Function):
is_grad_enabled: bool, is_grad_enabled: bool,
fwd_ln_sm_margin: int, fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -2026,6 +2046,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2026,6 +2046,7 @@ class _LayerNormMLP(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma,
) )
else: else:
ln_out = layernorm_fwd_fp8_inf( ln_out = layernorm_fwd_fp8_inf(
...@@ -2036,10 +2057,11 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2036,10 +2057,11 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
zero_centered_gamma,
) )
else: else:
ln_out_return, mu, rsigma = tex.layernorm_fwd( ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
) )
ln_out = cast_to_fp8( ln_out = cast_to_fp8(
ln_out_return, ln_out_return,
...@@ -2050,11 +2072,11 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2050,11 +2072,11 @@ class _LayerNormMLP(torch.autograd.Function):
else: else:
if is_grad_enabled: if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd( ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
) )
else: else:
ln_out, mu, rsigma = layernorm_fwd_inf( ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None ), None, None
ln_out_return = ln_out ln_out_return = ln_out
...@@ -2225,6 +2247,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2225,6 +2247,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output = return_layernorm_output
ctx.set_parallel_mode = set_parallel_mode ctx.set_parallel_mode = set_parallel_mode
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
# Row Parallel Linear # Row Parallel Linear
if set_parallel_mode and sequence_parallel: if set_parallel_mode and sequence_parallel:
...@@ -2518,7 +2541,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2518,7 +2541,8 @@ class _LayerNormMLP(torch.autograd.Function):
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd( dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
if not ctx.use_bias: if not ctx.use_bias:
...@@ -2553,11 +2577,12 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2553,11 +2577,12 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
class LayerNormMLP(TransformerEngineBaseModule): class LayerNormMLP(TransformerEngineBaseModule):
""" r"""
Applies layer normalization on the input followed by the MLP module, consisting of Applies layer normalization on the input followed by the MLP module, consisting of
2 successive linear transformations, separated by the GeLU activation. 2 successive linear transformations, separated by the GeLU activation.
...@@ -2583,6 +2608,13 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2583,6 +2608,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
together with the output of the linear transformation. together with the output of the linear transformation.
Example use case: residual connection for transformer module Example use case: residual connection for transformer module
is taken post layernorm. is taken post layernorm.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -2643,6 +2675,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2643,6 +2675,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
seq_length: Optional[int] = None, seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None, micro_batch_size: Optional[int] = None,
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
zero_centered_gamma: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -2652,6 +2685,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2652,6 +2685,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_layernorm_output = return_layernorm_output self.return_layernorm_output = return_layernorm_output
self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
self.set_parallel_mode = set_parallel_mode self.set_parallel_mode = set_parallel_mode
self.zero_centered_gamma = zero_centered_gamma
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
...@@ -2776,7 +2810,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2776,7 +2810,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
if not self.zero_centered_gamma:
init.ones_(self.layer_norm_weight) init.ones_(self.layer_norm_weight)
else:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias) init.zeros_(self.layer_norm_bias)
def forward( def forward(
...@@ -2840,6 +2877,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2840,6 +2877,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
...@@ -2870,6 +2908,7 @@ class _LayerNorm(torch.autograd.Function): ...@@ -2870,6 +2908,7 @@ class _LayerNorm(torch.autograd.Function):
eps: float, eps: float,
fwd_ln_sm_margin: int, fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -2877,10 +2916,13 @@ class _LayerNorm(torch.autograd.Function): ...@@ -2877,10 +2916,13 @@ class _LayerNorm(torch.autograd.Function):
assert inp.shape[-1] == in_features, "LayerNorm not possible" assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features)) inputmat = inp.view((-1, in_features))
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin) ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
ln_bias, eps, fwd_ln_sm_margin,
zero_centered_gamma)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
return ln_out.view_as(inp) return ln_out.view_as(inp)
@staticmethod @staticmethod
...@@ -2891,9 +2933,10 @@ class _LayerNorm(torch.autograd.Function): ...@@ -2891,9 +2933,10 @@ class _LayerNorm(torch.autograd.Function):
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape) d_ln_out = grad_output.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd( dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None
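With the flag threaded into `tex.layernorm_bwd`, gradients now flow through the `(1 + gamma)` term. One way to sanity-check a reference implementation is plain autograd (a sketch using the `ln_ref` helper above, not the repo's own test suite):

```python
import torch

x = torch.randn(2, 8, dtype=torch.double, requires_grad=True)
gamma = torch.zeros(8, dtype=torch.double, requires_grad=True)
beta = torch.zeros(8, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(
    lambda x, g, b: ln_ref(x, g, b, 1e-5, zero_centered_gamma=True),
    (x, gamma, beta),
)
```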
class LayerNorm(torch.nn.Module): class LayerNorm(torch.nn.Module):
...@@ -2902,7 +2945,7 @@ class LayerNorm(torch.nn.Module): ...@@ -2902,7 +2945,7 @@ class LayerNorm(torch.nn.Module):
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__ the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size :attr:`hidden_size` size :attr:`hidden_size`
...@@ -2919,6 +2962,13 @@ class LayerNorm(torch.nn.Module): ...@@ -2919,6 +2962,13 @@ class LayerNorm(torch.nn.Module):
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
""" """
def __init__( def __init__(
...@@ -2927,9 +2977,11 @@ class LayerNorm(torch.nn.Module): ...@@ -2927,9 +2977,11 @@ class LayerNorm(torch.nn.Module):
eps: float = 1e-5, eps: float = 1e-5,
sequence_parallel: bool = False, sequence_parallel: bool = False,
params_dtype: torch.dtype = torch.float32, params_dtype: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.weight = Parameter( self.weight = Parameter(
torch.empty( torch.empty(
hidden_size, hidden_size,
...@@ -2974,7 +3026,10 @@ class LayerNorm(torch.nn.Module): ...@@ -2974,7 +3026,10 @@ class LayerNorm(torch.nn.Module):
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
if not self.zero_centered_gamma:
init.ones_(self.weight) init.ones_(self.weight)
else:
init.zeros_(self.weight)
init.zeros_(self.bias) init.zeros_(self.bias)
...@@ -2993,4 +3048,5 @@ class LayerNorm(torch.nn.Module): ...@@ -2993,4 +3048,5 @@ class LayerNorm(torch.nn.Module):
self.eps, self.eps,
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma
) )
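End to end, enabling the new behavior on the standalone module is a one-argument change (a sketch assuming a CUDA device, since the fused kernels run on GPU):

```python
import torch
import transformer_engine.pytorch as te

ln = te.LayerNorm(1024, eps=1e-5, zero_centered_gamma=True)
x = torch.randn(16, 1024, device="cuda")
y = ln(x)  # gamma starts at 0, so initially this is pure normalization plus beta
```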
...@@ -156,17 +156,18 @@ def onnx_te_gemm( ...@@ -156,17 +156,18 @@ def onnx_te_gemm(
return output return output
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i") @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, scale_inv, fp8_tensor, otype): def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
scale_inv, fp8_tensor, otype, zero_centered_gamma):
"""ONNX graph for layernorm_fwd_fp8""" """ONNX graph for layernorm_fwd_fp8"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps) ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
return fp8_ln return fp8_ln
@symbolic_helper.parse_args("v", "v", "v", "f") @symbolic_helper.parse_args("v", "v", "v", "f", "b")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps): def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
"""ONNX graph for layernorm_fwd""" """ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
...@@ -177,6 +178,9 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps): ...@@ -177,6 +178,9 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps):
# Normalization axis = 0, so normalized_shape uses all dims except dim = 0 # Normalization axis = 0, so normalized_shape uses all dims except dim = 0
normalized_shape = normalized_shape[1:] normalized_shape = normalized_shape[1:]
if zero_centered_gamma:
one = g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64, device="cuda"))
weight = g.op("Add", weight, one)
ln = torch.onnx.symbolic_opset9.layer_norm( ln = torch.onnx.symbolic_opset9.layer_norm(
g, g,
inputs, inputs,
......
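The export path folds the offset into the weight ahead of the standard `layer_norm` op, so the exported graph should match applying `gamma + 1` eagerly. A hedged equivalence check in plain PyTorch (again using the `ln_ref` sketch above):

```python
import torch
import torch.nn.functional as F

x, w, b = torch.randn(4, 8), torch.randn(8), torch.zeros(8)
y_folded = F.layer_norm(x, (8,), w + 1, b, 1e-5)  # what the ONNX graph encodes
y_ref = ln_ref(x, w, b, 1e-5, zero_centered_gamma=True)
assert torch.allclose(y_folded, y_ref, atol=1e-6)
```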
...@@ -248,6 +248,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -248,6 +248,7 @@ class MultiHeadAttention(torch.nn.Module):
attention_type: str = "self", attention_type: str = "self",
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
fuse_qkv_params: bool = False, fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_number = (layer_number,) self.layer_number = (layer_number,)
...@@ -293,6 +294,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -293,6 +294,7 @@ class MultiHeadAttention(torch.nn.Module):
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output, return_layernorm_output=return_layernorm_output,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None, parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
zero_centered_gamma=zero_centered_gamma,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -317,6 +319,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -317,6 +319,7 @@ class MultiHeadAttention(torch.nn.Module):
return_bias=False, return_bias=False,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output, return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
**common_gemm_kwargs, **common_gemm_kwargs,
) )
else: else:
...@@ -576,7 +579,7 @@ class MultiHeadAttention(torch.nn.Module): ...@@ -576,7 +579,7 @@ class MultiHeadAttention(torch.nn.Module):
class TransformerLayer(torch.nn.Module): class TransformerLayer(torch.nn.Module):
""" r"""
TransformerLayer is made up of an attention block and a feedforward network (MLP). TransformerLayer is made up of an attention block and a feedforward network (MLP).
This standard layer is based on the paper "Attention Is All You Need". This standard layer is based on the paper "Attention Is All You Need".
...@@ -629,6 +632,13 @@ class TransformerLayer(torch.nn.Module): ...@@ -629,6 +632,13 @@ class TransformerLayer(torch.nn.Module):
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`. :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'causal', 'padding'}, default = `causal` self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation. type of attention mask passed into softmax operation.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -703,6 +713,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -703,6 +713,7 @@ class TransformerLayer(torch.nn.Module):
drop_path_rate: float = 0.0, drop_path_rate: float = 0.0,
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
fuse_qkv_params: bool = False, fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -759,6 +770,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -759,6 +770,7 @@ class TransformerLayer(torch.nn.Module):
"return_layernorm_output": apply_residual_connection_post_layernorm, "return_layernorm_output": apply_residual_connection_post_layernorm,
"set_parallel_mode": set_parallel_mode, "set_parallel_mode": set_parallel_mode,
"fuse_qkv_params": fuse_qkv_params, "fuse_qkv_params": fuse_qkv_params,
"zero_centered_gamma": zero_centered_gamma
} }
self.self_attention = MultiHeadAttention( self.self_attention = MultiHeadAttention(
...@@ -799,6 +811,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -799,6 +811,7 @@ class TransformerLayer(torch.nn.Module):
seq_length=seq_length, seq_length=seq_length,
micro_batch_size=micro_batch_size, micro_batch_size=micro_batch_size,
set_parallel_mode=set_parallel_mode, set_parallel_mode=set_parallel_mode,
zero_centered_gamma=zero_centered_gamma,
) )
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
...@@ -828,6 +841,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -828,6 +841,7 @@ class TransformerLayer(torch.nn.Module):
eps=layernorm_epsilon, eps=layernorm_epsilon,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype, params_dtype=params_dtype,
zero_centered_gamma=zero_centered_gamma
) )
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
......
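At the top level the flag simply fans out to every LayerNorm in the block, so opting in looks like this (hidden sizes and head count are illustrative values only):

```python
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    1024,   # hidden_size
    4096,   # ffn_hidden_size
    16,     # num_attention_heads
    zero_centered_gamma=True,  # forwarded to the attention and MLP LayerNorms
)
```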