Unverified commit 99f40677, authored by Przemyslaw Tredak, committed by GitHub

Fix return_bias option in LayerNormLinear and LayerNormMLP (#1569)



* Do not apply bias when apply_bias is False
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Bwd fix for LNMLP and tests
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix for the dbias calculation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Improve tests and clean up the logic
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Tightened test tolerances a little
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert "Tightened test tolerances a little"

This reverts commit 2e20a92c884a84759006541adc1d638ab91dde62.
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Update tests/pytorch/test_numerics.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

* Fix the GELU aux type
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove use_fc1_bias option
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent bee4649c
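
For context: with `return_bias=True`, TE modules hand the bias back to the caller instead of adding it to the output, so the addition can be fused into a later op. A minimal sketch of that contract, assuming the `transformer_engine.pytorch` API exercised by the tests below (shapes are illustrative):

```python
# Minimal sketch of the return_bias contract this commit fixes
# (assumes transformer_engine.pytorch is installed and a CUDA device).
import torch
import transformer_engine.pytorch as te

linear = te.Linear(1024, 4096, bias=True, return_bias=True, device="cuda")
x = torch.randn(8, 1024, device="cuda")

out, bias = linear(x)  # bias is returned, not applied to out
out = out + bias       # the caller (or a fused downstream kernel) applies it
```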
@@ -5,7 +5,7 @@
 from collections import OrderedDict
 import math
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Tuple, Optional
 import pytest
 import copy
 import random
@@ -331,9 +331,9 @@ class TorchLayerNormLinear(nn.Module):
         in_features: int,
         out_features: int,
         eps: float,
-        bias: bool = True,
         normalization: str = "LayerNorm",
         zero_centered_gamma: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
         if normalization == "LayerNorm":
@@ -347,7 +347,7 @@ class TorchLayerNormLinear(nn.Module):
         else:
             raise RuntimeError("Unsupported normalization")

-        self.linear = nn.Linear(in_features, out_features)
+        self.linear = nn.Linear(in_features, out_features, bias=bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(self.layernorm(x))
@@ -447,6 +447,7 @@ class TorchLayerNormMLP(nn.Module):
         eps: float = 1e-5,
         activation="gelu",
         normalization: str = "LayerNorm",
+        bias: bool = True,
     ):
         super().__init__()
         if normalization == "LayerNorm":
@@ -462,8 +463,8 @@ class TorchLayerNormMLP(nn.Module):
             fc1_output_features = ffn_hidden_size
         self.gelu = _supported_act[activation]

-        self.fc1 = nn.Linear(hidden_size, fc1_output_features)
-        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
+        self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
+        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)

     def forward(self, x):
         t = self.gelu(self.fc1(self.ln(x)))
@@ -1039,6 +1040,8 @@ def _test_granular_accuracy(block, bs, dtype, config):
     inp_hidden_states.retain_grad()

     out = block(inp_hidden_states)
+    if isinstance(out, (List, Tuple)):
+        out = out[0]
     loss = out.sum()
     loss.backward()
@@ -1117,32 +1120,53 @@ def test_dpa_accuracy(dtype, bs, model):
     assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)

+class TestReturnBiasModule(nn.Module):
+    def __init__(self, mod, **kwargs):
+        super().__init__()
+        self.te_module = mod(**kwargs)
+        self.return_bias = kwargs["return_bias"]
+        self.bias = kwargs["bias"]
+
+    def forward(self, x):
+        if self.return_bias:
+            out, bias = self.te_module(x)
+            if self.bias:
+                out = out + bias
+            return out
+        return self.te_module(x)
+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", ["small"])
-def test_linear_accuracy(dtype, bs, model):
+@pytest.mark.parametrize("return_bias", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+def test_linear_accuracy(dtype, bs, model, return_bias, bias):
     config = model_configs[model]
-    te_linear = Linear(
-        config.hidden_size,
-        4 * config.hidden_size,
-        bias=True,
+    te_linear = TestReturnBiasModule(
+        Linear,
+        in_features=config.hidden_size,
+        out_features=4 * config.hidden_size,
         params_dtype=dtype,
+        return_bias=return_bias,
+        bias=bias,
         device="cuda",
-    ).eval()
+    )

     torch_linear = torch.nn.Linear(
         config.hidden_size,
         4 * config.hidden_size,
-        bias=True,
+        bias=bias,
         device="cuda",
         dtype=dtype,
-    ).eval()
+    )

     # Share params
     with torch.no_grad():
-        torch_linear.weight = Parameter(te_linear.weight.clone())
-        torch_linear.bias = Parameter(te_linear.bias.clone())
+        torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
+        if bias:
+            torch_linear.bias = Parameter(te_linear.te_module.bias.clone())

     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)
@@ -1265,41 +1289,51 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
+@pytest.mark.parametrize("return_bias", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+def test_layernorm_linear_accuracy(
+    dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
+):
     config = model_configs[model]
-    te_ln_linear = LayerNormLinear(
-        config.hidden_size,
-        4 * config.hidden_size,
-        config.eps,
-        bias=True,
+    te_ln_linear = TestReturnBiasModule(
+        LayerNormLinear,
+        in_features=config.hidden_size,
+        out_features=4 * config.hidden_size,
+        eps=config.eps,
         normalization=normalization,
         params_dtype=dtype,
         zero_centered_gamma=zero_centered_gamma,
+        return_bias=return_bias,
+        bias=bias,
         device="cuda",
-    ).eval()
+    )

     torch_ln_linear = (
         TorchLayerNormLinear(
             config.hidden_size,
             4 * config.hidden_size,
             config.eps,
-            bias=True,
             normalization=normalization,
             zero_centered_gamma=zero_centered_gamma,
+            bias=bias,
         )
         .to(dtype=dtype)
         .cuda()
-        .eval()
     )

     # Share params
     with torch.no_grad():
-        torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
+        torch_ln_linear.layernorm.weight = Parameter(
+            te_ln_linear.te_module.layer_norm_weight.clone()
+        )
         if normalization != "RMSNorm":
-            torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
-        torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
-        torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())
+            torch_ln_linear.layernorm.bias = Parameter(
+                te_ln_linear.te_module.layer_norm_bias.clone()
+            )
+        torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
+        if bias:
+            torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())

     te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
@@ -1339,17 +1373,22 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
-def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
+@pytest.mark.parametrize("return_bias", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
     config = model_configs[model]
-    te_ln_mlp = LayerNormMLP(
-        config.hidden_size,
-        4 * config.hidden_size,
+    te_ln_mlp = TestReturnBiasModule(
+        LayerNormMLP,
+        hidden_size=config.hidden_size,
+        ffn_hidden_size=4 * config.hidden_size,
         activation=activation,
         normalization=normalization,
        params_dtype=dtype,
+        return_bias=return_bias,
+        bias=bias,
         device="cuda",
-    ).eval()
+    )

     torch_ln_mlp = (
         TorchLayerNormMLP(
@@ -1357,21 +1396,22 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
             4 * config.hidden_size,
             activation=activation,
             normalization=normalization,
+            bias=bias,
         )
         .to(dtype=dtype)
         .cuda()
-        .eval()
     )

     # Share params
     with torch.no_grad():
-        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
+        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
         if normalization != "RMSNorm":
-            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
-        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
-        torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
-        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
-        torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())
+            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
+        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
+        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
+        if bias:
+            torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
+            torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())

     te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)
...
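The `TestReturnBiasModule` wrapper above re-applies the returned bias inside `forward`, so the TE and pure-PyTorch outputs stay comparable across all four `(return_bias, bias)` combinations. A hypothetical instantiation mirroring the test code (sizes are illustrative, not taken from `model_configs`):

```python
# Hypothetical instantiation of the TestReturnBiasModule wrapper from this diff.
mod = TestReturnBiasModule(
    Linear,
    in_features=1024,
    out_features=4096,
    params_dtype=torch.float32,
    return_bias=True,
    bias=True,
    device="cuda",
)
out = mod(torch.randn(8, 1024, device="cuda"))  # bias re-applied in forward
```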
@@ -351,6 +351,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                      &pre_gelu_out, sizeof(pre_gelu_out)));
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
+    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
   }
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
...
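The new `CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE` attribute pins the dtype of the auxiliary (pre-GELU) epilogue output to the dtype of the tensor actually allocated for it. If the attribute is left unset, cuBLASLt presumably falls back to its default inference for the aux dtype, which no longer matches once the pre-GELU buffer follows the output dtype for high-precision GEMMs (see the `gelu_type` change in the next file).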
@@ -36,6 +36,10 @@ namespace transformer_engine::pytorch {

 namespace detail {

+bool is_low_precision(const DType type) {
+  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
+}
+
 std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
                                        const NVTEShape& B_shape, const bool transb) {
   // Flatten outer dims to get 2D matrices
@@ -96,6 +100,9 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
   TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
+  const bool low_precision =
+      detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
+
   // Check tensor dimensions
   const auto& A_shape = A_tensor.shape();
   const auto& B_shape = B_tensor.shape();
@@ -137,7 +144,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   // Activation input tensor
   MaybeTensor pre_gelu_out = std::nullopt;
-  DType gelu_type = bias_type;
+  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
   if (gelu) {
     if (!grad) {
       auto dtype = GetATenDType(gelu_type);
...
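The `gelu_type` change is the heart of the "Fix the GELU aux type" commit: for high-precision GEMMs the pre-GELU buffer now follows the output dtype rather than the bias dtype, and only FP8 inputs keep the old behavior. A standalone sketch of the rule (dtype tokens and helper name are hypothetical; only the decision logic mirrors the diff):

```python
# Standalone sketch of the aux-dtype rule; names are illustrative.
LOW_PRECISION = {"float8_e4m3", "float8_e5m2"}

def pick_gelu_aux_dtype(a_dtype: str, b_dtype: str, bias_dtype: str, out_dtype: str) -> str:
    low_precision = a_dtype in LOW_PRECISION or b_dtype in LOW_PRECISION
    return bias_dtype if low_precision else out_dtype

assert pick_gelu_aux_dtype("float8_e4m3", "float8_e4m3", "bfloat16", "bfloat16") == "bfloat16"
assert pick_gelu_aux_dtype("bfloat16", "bfloat16", "float32", "bfloat16") == "bfloat16"
```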
@@ -79,7 +79,6 @@ class _LayerNormLinear(torch.autograd.Function):
         ln_bias: Union[torch.Tensor, None],
         weight: torch.Tensor,
         bias: torch.Tensor,
-        use_bias: bool,
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
@@ -422,7 +421,7 @@ class _LayerNormLinear(torch.autograd.Function):
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
             ctx.cpu_offloading = cpu_offloading
             ctx.is_first_microbatch = is_first_microbatch
-            ctx.use_bias = use_bias
+            ctx.use_bias = bias is not None
             ctx.sequence_parallel = sequence_parallel
             ctx.tensor_parallel = tensor_parallel
             ctx.inp_shape = inp_shape
@@ -756,10 +755,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                 clear_tensor_data(ln_out_total)

-        # Don't return grad bias if not needed
-        if not ctx.use_bias:
-            grad_bias = None
-
         # Synchronize tensor parallel communication
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
@@ -841,7 +836,6 @@ class _LayerNormLinear(torch.autograd.Function):
             dbeta,
             wgrad,
             grad_bias,
-            None,  # use_bias
             None,  # eps
             None,  # is_first_microbatch
             None,  # fp8
@@ -1344,8 +1338,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.layer_norm_weight,
             self.layer_norm_bias,
             weight_tensor,
-            bias_tensor,
-            self.apply_bias and not self.gemm_bias_unfused_add,
+            bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
             self.eps,
             is_first_microbatch,
             self.fp8,
...
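The simplification in this file: instead of threading both a bias tensor and a separate `use_bias` flag through the autograd function, "apply the bias" is now encoded as the bias argument being a tensor vs. `None`, and the backward pass derives `ctx.use_bias` from that. A minimal sketch of the pattern (the function below is illustrative, not `_LayerNormLinear`'s real forward):

```python
# Hedged sketch of the Optional-bias convention adopted in this diff.
from typing import Optional
import torch

def linear_forward(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    out = x @ w.t()
    if bias is not None:  # replaces the old use_bias boolean
        out = out + bias
    return out

x, w, b = torch.randn(8, 16), torch.randn(32, 16), torch.randn(32)
assert torch.allclose(linear_forward(x, w, None) + b, linear_forward(x, w, b))
```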
@@ -139,10 +139,8 @@ class _LayerNormMLP(torch.autograd.Function):
         ln_bias: torch.Tensor,
         fc1_weight: torch.Tensor,
         fc1_bias: torch.Tensor,
-        use_fc1_bias: bool,
         fc2_weight: torch.Tensor,
         fc2_bias: torch.Tensor,
-        use_fc2_bias: bool,
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
@@ -367,7 +365,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # FC1 GEMM
-        # There are 2 fussions possible:
+        # There are 2 fusions possible:
         # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
         # - bias_gelu_fusion - only for full precision.
         # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer
@@ -452,8 +450,7 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             if not is_grad_enabled:
                 clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
-
-        if is_grad_enabled:
+        else:
             if cpu_offloading:
                 if fp8 and fc1_weight_final is not None:
                     set_offloading_param(fc1_weight_final, "weight_offloading", True)
@@ -536,9 +533,7 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
             ctx.cpu_offloading = cpu_offloading
             ctx.is_first_microbatch = is_first_microbatch
-            ctx.use_fc1_bias = use_fc1_bias
-            ctx.use_fc2_bias = use_fc2_bias
-            ctx.use_bias = ctx.use_fc1_bias
+            ctx.use_bias = fc2_bias is not None
             ctx.sequence_parallel = sequence_parallel
             ctx.tensor_parallel = tensor_parallel
             ctx.inp_shape = inp_shape
@@ -773,14 +768,13 @@ class _LayerNormMLP(torch.autograd.Function):
                     quantization_params=None,  # wgrad in high precision
                     layout="NT",
                     grad=True,
-                    bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
+                    bias=fc2_bias if fc2_bias_grad is None else None,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     use_split_accumulator=_2X_ACC_WGRAD,
                     out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 )
                 if fc2_bias_grad is None:
                     fc2_bias_grad = fc2_bias_grad_
-                del fc2_bias_grad_
                 clear_tensor_data(act_out)

             # bias computation
@@ -1045,11 +1039,9 @@ class _LayerNormMLP(torch.autograd.Function):
             dgamma,
             dbeta,
             fc1_wgrad,
-            fc1_bias_grad if ctx.use_fc1_bias else None,
-            None,  # use_fc1_bias
+            fc1_bias_grad if fc1_bias is not None else None,
             fc2_wgrad,  # pylint: disable=possibly-used-before-assignment
-            fc2_bias_grad if ctx.use_fc2_bias else None,
-            None,  # use_fc2_bias
+            fc2_bias_grad,
             None,  # eps
             None,  # is_first_microbatch
             None,  # fp8
@@ -1469,10 +1461,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.layer_norm_bias,
             fc1_weight,
             fc1_bias,
-            self.use_bias,
             fc2_weight,
-            fc2_bias,
-            self.apply_bias and not self.gemm_bias_unfused_add,
+            fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
             self.eps,
             is_first_microbatch,
             self.fp8,
...