Unverified commit 7324fe2b, authored by Przemyslaw Tredak and committed by GitHub

Zero-centered gamma support in LayerNorm (LayerNorm1p) (#67)



* C++ implementation of LayerNorm1P
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Expose zero-centered gamma to PyTorch
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix ONNX export
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix ONNX export and tests
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* Fix lint
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix backward handling - C++ part
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix for backward - Python side
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix FP8 path
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Re-enable the pylint check
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix the NVTX marker
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Change in the bwd kernel
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>
parent 2f643ada
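
For orientation before the diff: the change adds a "LayerNorm1p" variant whose learnable weight is stored zero-centered and applied as (1 + gamma), so a weight initialized to zeros reproduces standard LayerNorm with unit gain. A minimal PyTorch sketch of the formula, written for this review rather than taken from the commit:

```python
# Illustrative sketch only, not code from this commit. Assumes PyTorch.
import torch
import torch.nn.functional as F

def layernorm_zero_centered(x, gamma, beta, eps=1e-5):
    """LayerNorm1p: the stored weight enters the formula as (1 + gamma)."""
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)  # biased variance, as in LayerNorm
    return (x - mu) * torch.rsqrt(var + eps) * (1 + gamma) + beta

x = torch.randn(4, 16)
gamma, beta = torch.zeros(16), torch.zeros(16)  # the new init for this variant
ref = F.layer_norm(x, (16,), weight=gamma + 1, bias=beta)
assert torch.allclose(layernorm_zero_centered(x, gamma, beta), ref, atol=1e-6)
```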
......@@ -49,14 +49,17 @@ template <typename InputType, typename OutputType>
void compute_ref_output(const InputType *data, const InputType *gamma, const InputType *beta,
OutputType *output, const float *mu, const float *rsigma,
const size_t N, const size_t H,
float *amax, float scale) {
float *amax, float scale, const bool zero_centered_gamma) {
using compute_t = float;
compute_t current_max = -1e100;
for (size_t i = 0 ; i < N; ++i) {
for (size_t j = 0; j < H; ++j) {
compute_t current = static_cast<compute_t>(data[i * H + j]);
compute_t tmp = (current - mu[i]) * rsigma[i] * static_cast<compute_t>(gamma[j]) +
static_cast<compute_t>(beta[j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
compute_t tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
}
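
A side note on the FP8 bookkeeping visible above: amax is reduced over the unscaled normalized values (tmp), while the stored output is tmp * scale. The same reference computation, vectorized in NumPy as an illustration:

```python
# Vectorized sketch of the C++ reference above (illustrative, assumes NumPy).
import numpy as np

def ref_output(x, gamma, beta, scale, eps=1e-5, zero_centered_gamma=False):
    mu = x.mean(axis=-1, keepdims=True)
    rsigma = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    g = gamma + 1 if zero_centered_gamma else gamma
    y = (x - mu) * rsigma * g + beta  # matches `tmp` in the loop above
    amax = np.abs(y).max()            # amax is taken before scaling
    return y * scale, amax            # the scaled value is what gets stored
```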
......@@ -70,7 +73,8 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
const InputType *gamma,
InputType *data_grad,
InputType *gamma_grad, InputType *beta_grad,
const size_t N, const size_t H) {
const size_t N, const size_t H,
const bool zero_centered_gamma) {
using compute_t = float;
std::vector<compute_t> dgamma(H, 0.f);
std::vector<compute_t> dbeta(H, 0.f);
......@@ -81,7 +85,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - mu[i]) * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
dgamma[j] += y * dz;
......@@ -96,7 +103,10 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
for (size_t j = 0; j < H; ++j) {
const compute_t x = static_cast<compute_t>(data[i * H + j]);
const compute_t y = (x - mu[i]) * rsigma[i];
const compute_t g = static_cast<compute_t>(gamma[j]);
compute_t g = static_cast<compute_t>(gamma[j]);
if (zero_centered_gamma) {
g += 1;
}
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
......@@ -112,7 +122,7 @@ void compute_ref_backward(const OutputType *output_grad, const InputType *data,
}
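
The reference backward is the standard LayerNorm gradient with one substitution: dy = g * dz uses g = gamma + 1 when zero-centered, while the dgamma accumulation is unchanged because the +1 shift has zero derivative. A vectorized sketch of the loops above (illustrative, NumPy):

```python
# Vectorized sketch of compute_ref_backward (illustrative, assumes NumPy).
import numpy as np

def ref_backward(dz, x, mu, rsigma, gamma, zero_centered_gamma=False):
    y = (x - mu[:, None]) * rsigma[:, None]       # normalized input, shape [N, H]
    g = gamma + 1 if zero_centered_gamma else gamma
    dy = g * dz
    dgamma = (y * dz).sum(axis=0)                 # same formula in both variants
    dbeta = dz.sum(axis=0)
    mdy = dy.mean(axis=-1, keepdims=True)         # per-row mean of dy
    mdyy = (dy * y).mean(axis=-1, keepdims=True)  # per-row mean of dy * y
    dx = rsigma[:, None] * (dy - mdyy * y - mdy)
    return dx, dgamma, dbeta
```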
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H) {
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
......@@ -158,32 +168,34 @@ void performTest(const size_t N, const size_t H) {
// Forward kernel
float epsilon = 1e-5;
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data());
auto fwd_function = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
fwd_function(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data());
fwd_function(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data());
// Backward kernel
nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
dgamma_part.data(), dbeta_part.data(),
0, prop.multiProcessorCount,
workspace.data(), barrier.data());
auto bwd_function = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
bwd_function(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
dgamma_part.data(), dbeta_part.data(),
0, prop.multiProcessorCount,
workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
dbeta_part = Tensor(dbeta_part.shape(), dbeta_part.dtype());
nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
dgamma_part.data(), dbeta_part.data(),
0, prop.multiProcessorCount,
workspace.data(), barrier.data());
bwd_function(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
dgamma_part.data(), dbeta_part.data(),
0, prop.multiProcessorCount,
workspace.data(), barrier.data());
// Reference implementations
// use the GPU stats to tighten the tolerances
......@@ -201,12 +213,13 @@ void performTest(const size_t N, const size_t H) {
rsigma.cpu_dptr<float>(),
N, H,
&ref_amax,
ref_scale);
ref_scale,
zero_centered_gamma);
compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
mu.cpu_dptr<float>(), rsigma.cpu_dptr<float>(),
gamma.cpu_dptr<WeightType>(),
ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
N, H);
N, H, zero_centered_gamma);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
......@@ -248,7 +261,8 @@ std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
class LNTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>>> {};
std::pair<size_t, size_t>,
bool>> {};
TEST_P(LNTestSuite, TestLN) {
using namespace transformer_engine;
......@@ -257,10 +271,11 @@ TEST_P(LNTestSuite, TestLN) {
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
const bool zero_centered_gamma = std::get<3>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second);
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma);
);
);
}
......@@ -271,11 +286,13 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases)),
::testing::ValuesIn(test_cases),
::testing::Values(false, true)),
[](const testing::TestParamInfo<LNTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
std::to_string(std::get<2>(info.param).second) + "X" +
std::to_string(std::get<3>(info.param));
return name;
});
......@@ -434,10 +434,12 @@ def test_export_gemm(
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm(
use_fp8: bool,
scale_factor: float,
precision: torch.dtype
precision: torch.dtype,
zero_centered_gamma: bool
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
......@@ -459,7 +461,8 @@ def test_export_layernorm(
inp,
self.weight,
self.bias,
self.eps)
self.eps,
zero_centered_gamma)
return ret
class TestFP8_Layernorm(nn.Module):
......@@ -482,7 +485,8 @@ def test_export_layernorm(
self.eps,
self.meta,
self.fp8_tensor,
self.fp8_type)
self.fp8_type,
zero_centered_gamma)
ret = cast_from_fp8(
ret,
......@@ -500,7 +504,7 @@ def test_export_layernorm(
do_export(model, inp, fname, use_fp8=use_fp8)
if precision not in (torch.bfloat16, ):
# TODO: FP32 has a small threshold (1e-5)
validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8)
validate_result(fname, inp, model, atol=4e-3, is_fp8=use_fp8)
@skip_FP8
......@@ -646,13 +650,15 @@ def test_export_linear(
(torch.float16, True),
(torch.float16, False),
])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_linear(
scale_factor: float,
use_fp8: bool,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype
precision: torch.dtype,
zero_centered_gamma: bool
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
......@@ -676,6 +682,7 @@ def test_export_layernorm_linear(
return_bias=return_bias,
return_layernorm_output=return_layernorm_output,
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
).to(device='cuda')
if use_fp8:
set_layer_scale(model, scale_factor)
......@@ -698,13 +705,15 @@ def test_export_layernorm_linear(
(torch.float16, True),
(torch.float16, False),
])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_layernorm_mlp(
scale_factor: float,
use_fp8: bool,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype
precision: torch.dtype,
zero_centered_gamma: bool
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
......@@ -729,6 +738,7 @@ def test_export_layernorm_mlp(
return_bias=return_bias,
return_layernorm_output=return_layernorm_output,
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
).to(device='cuda')
if use_fp8:
set_layer_scale(model, scale_factor)
......@@ -902,6 +912,7 @@ def test_export_multihead_attention(
@pytest.mark.parametrize("precision", [torch.float32, torch.float16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("apply_query_key_layer_scaling", [True, False])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_transformer_layer(
use_fp8: bool,
use_mask: bool,
......@@ -909,7 +920,8 @@ def test_export_transformer_layer(
output_layernorm: bool,
precision: torch.dtype,
fuse_qkv_params: bool,
apply_query_key_layer_scaling: bool
apply_query_key_layer_scaling: bool,
zero_centered_gamma: bool
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
......@@ -947,7 +959,8 @@ def test_export_transformer_layer(
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
apply_query_key_layer_scaling=apply_query_key_layer_scaling).to(device='cuda')
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')
do_export(model, inp, fname, use_fp8)
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3)
......
......@@ -37,7 +37,8 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
batch_sizes = [1, 2]
skip_wgrad = [True, False]
all_boolean = [True, False]
def _disable_wgrads(block):
......@@ -151,8 +152,9 @@ def _test_sanity_common(block, bs, dtype, config, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
......@@ -164,6 +166,7 @@ def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
config.hidden_size * 3,
eps=config.eps,
init_method=init_method,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
......@@ -174,7 +177,7 @@ def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_linear(dtype, bs, model, skip_wgrad):
config = model_configs[model]
......@@ -195,8 +198,9 @@ def test_sanity_linear(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
......@@ -210,6 +214,7 @@ def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
eps=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
......@@ -220,8 +225,9 @@ def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_gpt(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gpt(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
......@@ -241,6 +247,7 @@ def test_sanity_gpt(dtype, bs, model, skip_wgrad):
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
......@@ -252,8 +259,9 @@ def test_sanity_gpt(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_bert(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_bert(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
......@@ -273,6 +281,7 @@ def test_sanity_bert(dtype, bs, model, skip_wgrad):
kv_channels=config.embed,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
......@@ -284,8 +293,9 @@ def test_sanity_bert(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_T5(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_T5(dtype, bs, model, skip_wgrad, zero_centered_gamma):
config = model_configs[model]
sigma = 0.023
......@@ -306,6 +316,7 @@ def test_sanity_T5(dtype, bs, model, skip_wgrad):
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
)
.to(dtype=dtype)
.cuda()
......@@ -317,7 +328,7 @@ def test_sanity_T5(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
config = model_configs[model]
......@@ -347,7 +358,7 @@ def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
config = model_configs[model]
......@@ -380,7 +391,7 @@ def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, bs, model, skip_wgrad):
config = model_configs[model]
......
......@@ -18,6 +18,11 @@ extern "C" {
#endif
/*! \brief Compute LayerNorm on the input.
*
* The formula used:
* @f[
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta
* @f]
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
......@@ -49,8 +54,51 @@ void nvte_layernorm_fwd(const NVTETensor x,
NVTETensor workspace,
NVTETensor barrier);
/*! \brief Compute LayerNorm with zero-centered gamma on the input.
*
* The formula used:
* @f[
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta
* @f]
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
* and barrier tensors to the required values.
*
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
* \param[in] beta Beta tensor of shape [H].
* \param[in] epsilon Value added to denominator for numerical stability.
* \param[in,out] z Output tensor of shape [N, H].
* \param[out] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[out] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm1p_fwd(const NVTETensor x,
const NVTETensor gamma,
const NVTETensor beta,
const float epsilon,
NVTETensor z,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier);
/*! \brief Compute backward of LayerNorm.
*
* This function computes the gradient of function:
* @f[
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta
* @f]
* with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
*
* Calling this function with workspace, barrier, dgamma_part and dbeta_part set
* to empty tensor will not perform the operation, but instead set the shape and type
......@@ -88,6 +136,49 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
NVTETensor workspace,
NVTETensor barrier);
/*! \brief Compute backward of LayerNorm with zero-centered gamma.
*
* This function computes the gradient of function:
* @f[
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta
* @f]
* with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
*
* Calling this function with workspace, barrier, dgamma_part and dbeta_part set
* to empty tensor will not perform the operation, but instead set the shape and type
* of these tensors to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[in] rsigma Inverse of the variance of the input calculated over
* the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] dbeta Gradient for beta tensor of shape [H].
* \param[out] dgamma_part Storage for partial gamma gradient.
* \param[out] dbeta_part Storage for partial bias gradient.
* \param[in] stream CUDA stream used for the operation.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
*/
void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dbeta,
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier);
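
Both new entry points keep the query-then-run convention of the existing API: a call with empty workspace/barrier tensors only reports the required shapes, and a second call does the work (the C++ tests in this commit exercise exactly that). A schematic of the protocol with a stand-in stub, not the real binding:

```python
# Schematic of the two-phase workspace-query protocol documented above.
# `ln1p_fwd` is a hypothetical stand-in, NOT the real nvte_layernorm1p_fwd
# binding; it only illustrates the calling convention.
import numpy as np

def ln1p_fwd(x, gamma, beta, eps, workspace):
    if workspace.size == 0:                  # query phase: report, don't compute
        return {"workspace_shape": (1024,)}  # stand-in for the kernel's answer
    mu = x.mean(axis=-1, keepdims=True)
    rs = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    return (x - mu) * rs * (1 + gamma) + beta

x, g, b = np.random.randn(4, 16), np.zeros(16), np.zeros(16)
query = ln1p_fwd(x, g, b, 1e-5, np.empty(0))  # phase 1: query sizes
ws = np.empty(query["workspace_shape"])       # allocate as requested
z = ln1p_fwd(x, g, b, 1e-5, ws)               # phase 2: actual run
```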
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -44,7 +44,8 @@ struct ParamsBase {
, rs(nullptr)
, gamma(nullptr)
, workspace(nullptr)
, barrier(nullptr) {}
, barrier(nullptr)
, zero_centered_gamma(false) {}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
......@@ -67,6 +68,9 @@ struct ParamsBase {
// Multi-CTA sync barriers in gmem.
int *barrier;
// Whether gamma is centered around 0
bool zero_centered_gamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -151,7 +151,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
cudaStream_t stream,
const int multiprocessorCount,
Tensor* workspace,
Tensor* barrier) {
Tensor* barrier,
const bool zero_centered_gamma) {
const auto itype = x.data.dtype;
const auto wtype = gamma.data.dtype;
const auto otype = z->data.dtype;
......@@ -208,6 +209,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.fp8_out = fp8_out;
params.zero_centered_gamma = zero_centered_gamma;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
......@@ -261,7 +263,8 @@ void layernorm_bwd(const Tensor& dz,
cudaStream_t stream,
const int multiprocessorCount,
Tensor* workspace,
Tensor* barrier
Tensor* barrier,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
......@@ -325,6 +328,7 @@ void layernorm_bwd(const Tensor& dz,
params.dgamma = dgamma->data.dptr;
params.dbeta_part = dbeta_part->data.dptr;
params.dgamma_part = dgamma_part->data.dptr;
params.zero_centered_gamma = zero_centered_gamma;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
......@@ -386,7 +390,8 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier));
reinterpret_cast<Tensor*>(barrier),
false);
}
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
......@@ -418,5 +423,66 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier));
reinterpret_cast<Tensor*>(barrier),
false);
}
void nvte_layernorm1p_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const float epsilon,
NVTETensor z,
NVTETensor mu,
NVTETensor rsigma,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm1p_fwd);
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta),
epsilon,
reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu),
reinterpret_cast<Tensor*>(rsigma),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
true);
}
void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size
const NVTETensor x, // BxSxhidden_size
const NVTETensor mu, // BxS, FP32!
const NVTETensor rsigma, // BxS, FP32!
const NVTETensor gamma, // hidden_size
NVTETensor dx,
NVTETensor dgamma,
NVTETensor dbeta,
NVTETensor dgamma_part,
NVTETensor dbeta_part,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm1p_bwd);
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu),
*reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma),
reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma),
reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(dgamma_part),
reinterpret_cast<Tensor*>(dbeta_part),
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
true);
}
......@@ -100,9 +100,10 @@ void ln_bwd_tuned_kernel(layer_norm::BwdParams params) {
for ( int it = 0; it < LDGS; it++ ) {
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x[it].data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
const compute_t x_tmp = x[it].data.elt[jt];
const compute_t y_tmp = rs_r * (x_tmp - mu_r);
const compute_t dy_tmp_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) + dy_tmp_shift;
dy_tmp *= compute_t(dz[it].data.elt[jt]);
compute_t dz_tmp = dz[it].data.elt[jt];
......@@ -411,11 +412,12 @@ void ln_bwd_general_kernel(layer_norm::BwdParams params) {
dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);
#pragma unroll
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_ij = x.data.elt[jt];
compute_t y_ij = rs * (x_ij - mu);
compute_t g_ij = gamma[it].data.elt[jt];
compute_t dz_ij = dz.data.elt[jt];
compute_t dy_ij = g_ij * dz_ij;
const compute_t x_ij = x.data.elt[jt];
const compute_t y_ij = rs * (x_ij - mu);
const compute_t g_ij_shift = (params.zero_centered_gamma) ? 1.0f : 0.f;
const compute_t g_ij = gamma[it].data.elt[jt] + g_ij_shift;
const compute_t dz_ij = dz.data.elt[jt];
const compute_t dy_ij = g_ij * dz_ij;
y[it].data.elt[jt] = y_ij;
dy[it].data.elt[jt] = dy_ij;
......
......@@ -61,7 +61,7 @@ void ln_fwd_tuned_kernel(FwdParams params) {
Wvec beta[LDGS];
index_t idx = c;
#pragma unroll
for ( int it = 0; it < LDGS; it++ ) {
for ( int it = 0; it < LDGS; ++it ) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
idx += VEC_COLS_PER_LDG;
......@@ -113,6 +113,9 @@ void ln_fwd_tuned_kernel(FwdParams params) {
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t y_ij = rs * (xf[it * NUM_ELTS + jt] - mu);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
compute_t b_ij = beta[it].data.elt[jt];
compute_t temp_output = g_ij * y_ij + b_ij;
......@@ -187,7 +190,7 @@ void ln_fwd_general_kernel(FwdParams params) {
#pragma unroll
for ( int it = 0, col = gidn * NUM_ELTS;
it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS ) {
++it, col += gdimn * NUM_ELTS ) {
Wvec gamma_in, beta_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
beta_in.load_from_elts(params.beta, col, params.cols - col);
......@@ -269,6 +272,9 @@ void ln_fwd_general_kernel(FwdParams params) {
for ( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t y_ij = rs * (x[it].data.elt[jt] - mu);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
compute_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = g_ij * y_ij + b_ij;
}
......
......@@ -255,6 +255,7 @@ def layernorm_fwd_fp8(
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""LayerNorm with FP8 output"""
return tex.layernorm_fwd_fp8(
......@@ -267,6 +268,7 @@ def layernorm_fwd_fp8(
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
sm_margin,
zero_centered_gamma
)
......@@ -278,6 +280,7 @@ def layernorm_fwd_fp8_inf(
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
zero_centered_gamma: bool,
) -> torch.Tensor:
"""LayerNorm with FP8 output.
......@@ -293,7 +296,8 @@ def layernorm_fwd_fp8_inf(
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype)
otype,
zero_centered_gamma)
return ret
......@@ -302,6 +306,7 @@ def layernorm_fwd_inf(
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
zero_centered_gamma: bool,
) -> torch.Tensor:
"""LayerNorm with FP8 output"""
return torch.ops.tex_ts.layernorm_fwd_inf_ts(
......@@ -309,6 +314,7 @@ def layernorm_fwd_inf(
weight,
bias,
eps,
zero_centered_gamma,
)
......
......@@ -378,7 +378,8 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
const at::Tensor &mu,
const at::Tensor &rsigma,
const at::Tensor &gamma,
const int sm_margin
const int sm_margin,
const bool zero_centered_gamma
) {
auto dx = at::empty_like(x);
auto dgamma = at::empty_like(gamma);
......@@ -395,11 +396,12 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
auto dbeta_cu = makeTransformerEngineTensor(dbeta);
// This call populates tensors with the required config.
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Alloc space for Tensors.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -420,11 +422,11 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
dbeta_part.dtype());
// Actual call to bwd kernel.
nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return { dx, dgamma, dbeta };
}
......@@ -438,7 +440,8 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
......@@ -461,10 +464,11 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
......@@ -480,10 +484,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
barrier.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
......@@ -496,12 +500,13 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
transformer_engine::DType otype,
const bool zero_centered_gamma
) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd_fp8(
input, weight, bias, eps, scale, amax, scale_inv, otype, 0);
input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma);
return out[0];
}
......@@ -510,7 +515,8 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
const int sm_margin
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
......@@ -531,10 +537,11 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
......@@ -550,10 +557,10 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
barrier.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
......@@ -562,11 +569,12 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps
float eps,
const bool zero_centered_gamma
) {
// This is a specialized version of layernorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0);
std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma);
return out[0];
}
......
......@@ -85,7 +85,8 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
const at::Tensor &mu,
const at::Tensor &rsigma,
const at::Tensor &gamma,
const int sm_margin
const int sm_margin,
const bool zero_centered_gamma
);
......@@ -97,7 +98,8 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin
const int sm_margin,
const bool zero_centered_gamma
);
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
......@@ -107,20 +109,23 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
transformer_engine::DType otype,
const bool zero_centered_gamma
);
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
const int sm_margin
const int sm_margin,
const bool zero_centered_gamma
);
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps
float eps,
const bool zero_centered_gamma
);
at::Tensor cast_to_fp8(const at::Tensor &input,
......
......@@ -133,7 +133,8 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
int64_t otype,
const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps);
......@@ -144,7 +145,8 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
scale,
amax,
scale_inv,
otype_arg);
otype_arg,
zero_centered_gamma);
return output;
}
......@@ -152,13 +154,15 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
double eps) {
double eps,
const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps);
at::Tensor output = layernorm_fwd_inf(input,
weight,
bias,
eps_float);
eps_float,
zero_centered_gamma);
return output;
}
......
......@@ -668,6 +668,7 @@ class _LayerNormLinear(torch.autograd.Function):
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -698,6 +699,7 @@ class _LayerNormLinear(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
)
else:
mu = rsigma = None
......@@ -709,15 +711,16 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
zero_centered_gamma,
)
else:
if is_grad_enabled:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out_return, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out = cast_to_fp8(
......@@ -729,11 +732,11 @@ class _LayerNormLinear(torch.autograd.Function):
else:
if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out_return = ln_out
......@@ -831,6 +834,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.tp_group = tp_group
ctx.return_layernorm_output = return_layernorm_output
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
......@@ -997,7 +1001,8 @@ class _LayerNormLinear(torch.autograd.Function):
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
if not ctx.use_bias:
......@@ -1027,11 +1032,12 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
class LayerNormLinear(TransformerEngineBaseModule):
"""
r"""
Applies layer normalization followed by linear transformation to the incoming data.
Parameters
......@@ -1057,6 +1063,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
Parallelism parameters
----------------------
......@@ -1112,6 +1125,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return_layernorm_output: bool = False,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
zero_centered_gamma: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
......@@ -1121,6 +1135,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_bias = return_bias
self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
if tp_group is None:
self.tp_size = tp_size
......@@ -1256,7 +1271,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
init.ones_(self.layer_norm_weight)
if not self.zero_centered_gamma:
init.ones_(self.layer_norm_weight)
else:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)
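
The init switch above pairs with the (1 + gamma) formula so that both parameterizations start out computing the same identity-gain LayerNorm. A quick sketch of that equivalence (illustrative):

```python
# Why zeros-init is the counterpart of ones-init (illustrative, PyTorch).
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
w_standard = torch.ones(8)    # init.ones_ path (zero_centered_gamma=False)
w_centered = torch.zeros(8)   # init.zeros_ path (zero_centered_gamma=True)
assert torch.allclose(
    F.layer_norm(x, (8,), weight=w_standard),
    F.layer_norm(x, (8,), weight=w_centered + 1))  # the kernel adds the 1
```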
def forward(
......@@ -1339,6 +1357,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
torch.is_grad_enabled(),
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
)
out = fwd_fn(*args)
......@@ -1997,6 +2016,7 @@ class _LayerNormMLP(torch.autograd.Function):
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -2026,6 +2046,7 @@ class _LayerNormMLP(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
)
else:
ln_out = layernorm_fwd_fp8_inf(
......@@ -2036,10 +2057,11 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
zero_centered_gamma,
)
else:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
ln_out = cast_to_fp8(
ln_out_return,
......@@ -2050,11 +2072,11 @@ class _LayerNormMLP(torch.autograd.Function):
else:
if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out_return = ln_out
......@@ -2225,6 +2247,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.return_layernorm_output = return_layernorm_output
ctx.set_parallel_mode = set_parallel_mode
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
# Row Parallel Linear
if set_parallel_mode and sequence_parallel:
......@@ -2518,7 +2541,8 @@ class _LayerNormMLP(torch.autograd.Function):
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
if not ctx.use_bias:
......@@ -2553,11 +2577,12 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
class LayerNormMLP(TransformerEngineBaseModule):
"""
r"""
Applies layer normalization on the input followed by the MLP module, consisting of
2 successive linear transformations, separated by the GeLU activation.
......@@ -2583,6 +2608,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
together with the output of the linear transformation.
Example use case: residual connection for transformer module
is taken post layernorm.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
Parallelism parameters
----------------------
......@@ -2643,6 +2675,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
set_parallel_mode: bool = False,
zero_centered_gamma: bool = False,
) -> None:
super().__init__()
......@@ -2652,6 +2685,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_layernorm_output = return_layernorm_output
self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
self.set_parallel_mode = set_parallel_mode
self.zero_centered_gamma = zero_centered_gamma
if tp_group is None:
self.tp_size = tp_size
......@@ -2776,7 +2810,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
init.ones_(self.layer_norm_weight)
if not self.zero_centered_gamma:
init.ones_(self.layer_norm_weight)
else:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)
def forward(
......@@ -2840,6 +2877,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
torch.is_grad_enabled(),
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
)
out = fwd_fn(*args)
......@@ -2870,6 +2908,7 @@ class _LayerNorm(torch.autograd.Function):
eps: float,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -2877,10 +2916,13 @@ class _LayerNorm(torch.autograd.Function):
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features))
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin)
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
ln_bias, eps, fwd_ln_sm_margin,
zero_centered_gamma)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
return ln_out.view_as(inp)
@staticmethod
......@@ -2891,9 +2933,10 @@ class _LayerNorm(torch.autograd.Function):
grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None
class LayerNorm(torch.nn.Module):
......@@ -2902,7 +2945,7 @@ class LayerNorm(torch.nn.Module):
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size :attr:`hidden_size`
......@@ -2919,6 +2962,13 @@ class LayerNorm(torch.nn.Module):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
"""
def __init__(
......@@ -2927,9 +2977,11 @@ class LayerNorm(torch.nn.Module):
eps: float = 1e-5,
sequence_parallel: bool = False,
params_dtype: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
) -> None:
super().__init__()
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.weight = Parameter(
torch.empty(
hidden_size,
......@@ -2974,7 +3026,10 @@ class LayerNorm(torch.nn.Module):
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
init.ones_(self.weight)
if not self.zero_centered_gamma:
init.ones_(self.weight)
else:
init.zeros_(self.weight)
init.zeros_(self.bias)
......@@ -2993,4 +3048,5 @@ class LayerNorm(torch.nn.Module):
self.eps,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma
)
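
With the plumbing above in place, enabling the variant from user code is a single flag. A usage sketch, assuming a CUDA device and a transformer_engine build that contains this change:

```python
# Usage sketch of the new flag (assumes a GPU and this commit).
import torch
import transformer_engine.pytorch as te

ln = te.LayerNorm(1024, eps=1e-5, zero_centered_gamma=True).cuda()
x = torch.randn(32, 1024, device="cuda")
y = ln(x)  # weight starts at 0 and is applied as (1 + weight)
```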
......@@ -156,17 +156,18 @@ def onnx_te_gemm(
return output
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, scale_inv, fp8_tensor, otype):
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
scale_inv, fp8_tensor, otype, zero_centered_gamma):
"""ONNX graph for layernorm_fwd_fp8"""
# pylint: disable=unused-argument
ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps)
ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
return fp8_ln
@symbolic_helper.parse_args("v", "v", "v", "f")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps):
@symbolic_helper.parse_args("v", "v", "v", "f", "b")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
"""ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
......@@ -177,6 +178,9 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps):
# Normalization axis = 0, so normalized_shape uses all dims except dim = 0
normalized_shape = normalized_shape[1:]
if zero_centered_gamma:
one = g.op("Constant", value_t=torch.tensor([1], dtype=torch.float32, device="cuda"))
weight = g.op("Add", weight, one)
ln = torch.onnx.symbolic_opset9.layer_norm(
g,
inputs,
......
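
For export, the symbolic simply folds the +1 into the weight with an ONNX Add node before emitting the standard layer_norm op. In eager terms, the exported graph computes the following (illustrative sketch):

```python
# Eager-mode equivalent of the exported ONNX graph (illustrative, PyTorch).
import torch.nn.functional as F

def exported_equivalent(x, weight, bias, eps, zero_centered_gamma):
    if zero_centered_gamma:
        weight = weight + 1  # mirrors the g.op("Add", weight, one) node
    return F.layer_norm(x, x.shape[-1:], weight=weight, bias=bias, eps=eps)
```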
......@@ -248,6 +248,7 @@ class MultiHeadAttention(torch.nn.Module):
attention_type: str = "self",
set_parallel_mode: bool = False,
fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
) -> None:
super().__init__()
self.layer_number = (layer_number,)
......@@ -293,6 +294,7 @@ class MultiHeadAttention(torch.nn.Module):
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
zero_centered_gamma=zero_centered_gamma,
**common_gemm_kwargs,
)
else:
......@@ -317,6 +319,7 @@ class MultiHeadAttention(torch.nn.Module):
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
**common_gemm_kwargs,
)
else:
......@@ -576,7 +579,7 @@ class MultiHeadAttention(torch.nn.Module):
class TransformerLayer(torch.nn.Module):
"""
r"""
TransformerLayer is made up of an attention block and a feedforward network (MLP).
This standard layer is based on the paper "Attention Is All You Need".
......@@ -629,6 +632,13 @@ class TransformerLayer(torch.nn.Module):
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
Parallelism parameters
----------------------
......@@ -703,6 +713,7 @@ class TransformerLayer(torch.nn.Module):
drop_path_rate: float = 0.0,
set_parallel_mode: bool = False,
fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
) -> None:
super().__init__()
......@@ -759,6 +770,7 @@ class TransformerLayer(torch.nn.Module):
"return_layernorm_output": apply_residual_connection_post_layernorm,
"set_parallel_mode": set_parallel_mode,
"fuse_qkv_params": fuse_qkv_params,
"zero_centered_gamma": zero_centered_gamma
}
self.self_attention = MultiHeadAttention(
......@@ -799,6 +811,7 @@ class TransformerLayer(torch.nn.Module):
seq_length=seq_length,
micro_batch_size=micro_batch_size,
set_parallel_mode=set_parallel_mode,
zero_centered_gamma=zero_centered_gamma,
)
self.hidden_dropout = hidden_dropout
......@@ -828,6 +841,7 @@ class TransformerLayer(torch.nn.Module):
eps=layernorm_epsilon,
sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype,
zero_centered_gamma=zero_centered_gamma
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
......
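
Finally, the flag threads from TransformerLayer through MultiHeadAttention, LayerNormLinear, LayerNormMLP and the output LayerNorm, so one argument switches every norm in the layer. An end-to-end usage sketch, assuming a GPU and a build with this change:

```python
# End-to-end usage sketch (assumes a GPU and this commit).
import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    zero_centered_gamma=True,  # reaches every LayerNorm inside the layer
).to("cuda")
x = torch.randn(128, 4, 1024, device="cuda")  # [seq, batch, hidden]
y = layer(x)
```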