Unverified Commit 99f40677 authored by Przemyslaw Tredak, committed by GitHub

Fix return_bias option in LayerNormLinear and LayerNormMLP (#1569)



* Do not apply bias when apply_bias is False
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Bwd fix for LNMLP and tests
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix for the dbias calculation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Improve tests and clean up the logic
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Tightened test tolerances a little
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert "Tightened test tolerances a little"

This reverts commit 2e20a92c884a84759006541adc1d638ab91dde62.
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Update tests/pytorch/test_numerics.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

* Fix the Gelu Aux type
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove use_fc1_bias option
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent bee4649c
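
For context on what the fix exercises: with return_bias=True, the Transformer Engine modules are expected to skip the bias addition and hand the bias tensor back to the caller (so the add can be fused into a later operation), and this PR makes LayerNormLinear and LayerNormMLP honor that together with bias=False. Below is a minimal usage sketch of that calling convention; it is an illustration only, not code from this commit, and it assumes transformer_engine.pytorch is installed and a CUDA device is available.

    # Minimal sketch of the return_bias contract exercised by the tests in this diff
    # (illustration only, not part of this commit).
    import torch
    import transformer_engine.pytorch as te

    hidden = 256
    layer = te.LayerNormLinear(
        hidden,                      # in_features
        4 * hidden,                  # out_features
        bias=True,                   # bias parameter exists ...
        return_bias=True,            # ... but the module does not add it to the output
        params_dtype=torch.bfloat16,
        device="cuda",
    )

    x = torch.randn(32, hidden, dtype=torch.bfloat16, device="cuda")
    out, bias = layer(x)             # bias is returned for the caller to apply
    out = out + bias                 # e.g. fused with a residual add in real models
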
......@@ -5,7 +5,7 @@
from collections import OrderedDict
import math
import os
from typing import Dict, List, Optional
from typing import Dict, List, Tuple, Optional
import pytest
import copy
import random
......@@ -331,9 +331,9 @@ class TorchLayerNormLinear(nn.Module):
in_features: int,
out_features: int,
eps: float,
bias: bool = True,
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False,
bias: bool = True,
):
super().__init__()
if normalization == "LayerNorm":
......@@ -347,7 +347,7 @@ class TorchLayerNormLinear(nn.Module):
else:
raise RuntimeError("Unsupported normalization")
self.linear = nn.Linear(in_features, out_features)
self.linear = nn.Linear(in_features, out_features, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.layernorm(x))
......@@ -447,6 +447,7 @@ class TorchLayerNormMLP(nn.Module):
eps: float = 1e-5,
activation="gelu",
normalization: str = "LayerNorm",
bias: bool = True,
):
super().__init__()
if normalization == "LayerNorm":
......@@ -462,8 +463,8 @@ class TorchLayerNormMLP(nn.Module):
fc1_output_features = ffn_hidden_size
self.gelu = _supported_act[activation]
self.fc1 = nn.Linear(hidden_size, fc1_output_features)
self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)
def forward(self, x):
t = self.gelu(self.fc1(self.ln(x)))
......@@ -1039,6 +1040,8 @@ def _test_granular_accuracy(block, bs, dtype, config):
inp_hidden_states.retain_grad()
out = block(inp_hidden_states)
if isinstance(out, (List, Tuple)):
out = out[0]
loss = out.sum()
loss.backward()
......@@ -1117,32 +1120,53 @@ def test_dpa_accuracy(dtype, bs, model):
assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)
class TestReturnBiasModule(nn.Module):
def __init__(self, mod, **kwargs):
super().__init__()
self.te_module = mod(**kwargs)
self.return_bias = kwargs["return_bias"]
self.bias = kwargs["bias"]
def forward(self, x):
if self.return_bias:
out, bias = self.te_module(x)
if self.bias:
out = out + bias
return out
return self.te_module(x)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
def test_linear_accuracy(dtype, bs, model):
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_linear_accuracy(dtype, bs, model, return_bias, bias):
config = model_configs[model]
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
te_linear = TestReturnBiasModule(
Linear,
in_features=config.hidden_size,
out_features=4 * config.hidden_size,
params_dtype=dtype,
return_bias=return_bias,
bias=bias,
device="cuda",
).eval()
)
torch_linear = torch.nn.Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
bias=bias,
device="cuda",
dtype=dtype,
).eval()
)
# Share params
with torch.no_grad():
torch_linear.weight = Parameter(te_linear.weight.clone())
torch_linear.bias = Parameter(te_linear.bias.clone())
torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
if bias:
torch_linear.bias = Parameter(te_linear.te_module.bias.clone())
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)
......@@ -1265,41 +1289,51 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_linear_accuracy(
dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
):
config = model_configs[model]
te_ln_linear = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
te_ln_linear = TestReturnBiasModule(
LayerNormLinear,
in_features=config.hidden_size,
out_features=4 * config.hidden_size,
eps=config.eps,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
return_bias=return_bias,
bias=bias,
device="cuda",
).eval()
)
torch_ln_linear = (
TorchLayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=True,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
bias=bias,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
torch_ln_linear.layernorm.weight = Parameter(
te_ln_linear.te_module.layer_norm_weight.clone()
)
if normalization != "RMSNorm":
torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())
torch_ln_linear.layernorm.bias = Parameter(
te_ln_linear.te_module.layer_norm_bias.clone()
)
torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
if bias:
torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())
te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
......@@ -1339,17 +1373,22 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
@pytest.mark.parametrize("return_bias", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
config = model_configs[model]
te_ln_mlp = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
te_ln_mlp = TestReturnBiasModule(
LayerNormMLP,
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
activation=activation,
normalization=normalization,
params_dtype=dtype,
return_bias=return_bias,
bias=bias,
device="cuda",
).eval()
)
torch_ln_mlp = (
TorchLayerNormMLP(
......@@ -1357,21 +1396,22 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
4 * config.hidden_size,
activation=activation,
normalization=normalization,
bias=bias,
)
.to(dtype=dtype)
.cuda()
.eval()
)
# Share params
with torch.no_grad():
torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
if normalization != "RMSNorm":
torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())
torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
if bias:
torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())
te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)
......
......@@ -351,6 +351,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
}
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
......
......@@ -36,6 +36,10 @@ namespace transformer_engine::pytorch {
namespace detail {
bool is_low_precision(const DType type) {
return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
}
std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
const NVTEShape& B_shape, const bool transb) {
// Flatten outer dims to get 2D matrices
......@@ -96,6 +100,9 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
const bool low_precision =
detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
// Check tensor dimensions
const auto& A_shape = A_tensor.shape();
const auto& B_shape = B_tensor.shape();
......@@ -137,7 +144,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
// Activation input tensor
MaybeTensor pre_gelu_out = std::nullopt;
DType gelu_type = bias_type;
DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
if (gelu) {
if (!grad) {
auto dtype = GetATenDType(gelu_type);
......
......@@ -79,7 +79,6 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias: Union[torch.Tensor, None],
weight: torch.Tensor,
bias: torch.Tensor,
use_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
......@@ -422,7 +421,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.use_bias = bias is not None
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp_shape
......@@ -756,10 +755,6 @@ class _LayerNormLinear(torch.autograd.Function):
# TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme
clear_tensor_data(ln_out_total)
# Don't return grad bias if not needed
if not ctx.use_bias:
grad_bias = None
# Synchronize tensor parallel communication
if ln_out_total_work is not None:
ln_out_total_work.wait()
......@@ -841,7 +836,6 @@ class _LayerNormLinear(torch.autograd.Function):
dbeta,
wgrad,
grad_bias,
None, # use_bias
None, # eps
None, # is_first_microbatch
None, # fp8
......@@ -1344,8 +1338,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.eps,
is_first_microbatch,
self.fp8,
......
......@@ -139,10 +139,8 @@ class _LayerNormMLP(torch.autograd.Function):
ln_bias: torch.Tensor,
fc1_weight: torch.Tensor,
fc1_bias: torch.Tensor,
use_fc1_bias: bool,
fc2_weight: torch.Tensor,
fc2_bias: torch.Tensor,
use_fc2_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
......@@ -367,7 +365,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC1 GEMM
# There are 2 fussions possible:
# There are 2 fusions possible:
# - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
# - bias_gelu_fusion - only for full precision.
# If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer
......@@ -452,8 +450,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
if not is_grad_enabled:
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
if is_grad_enabled:
else:
if cpu_offloading:
if fp8 and fc1_weight_final is not None:
set_offloading_param(fc1_weight_final, "weight_offloading", True)
......@@ -536,9 +533,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_fc1_bias = use_fc1_bias
ctx.use_fc2_bias = use_fc2_bias
ctx.use_bias = ctx.use_fc1_bias
ctx.use_bias = fc2_bias is not None
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp_shape
......@@ -773,14 +768,13 @@ class _LayerNormMLP(torch.autograd.Function):
quantization_params=None, # wgrad in high precision
layout="NT",
grad=True,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
bias=fc2_bias if fc2_bias_grad is None else None,
accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if fc2_bias_grad is None:
fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
clear_tensor_data(act_out)
# bias computation
......@@ -1045,11 +1039,9 @@ class _LayerNormMLP(torch.autograd.Function):
dgamma,
dbeta,
fc1_wgrad,
fc1_bias_grad if ctx.use_fc1_bias else None,
None, # use_fc1_bias
fc1_bias_grad if fc1_bias is not None else None,
fc2_wgrad, # pylint: disable=possibly-used-before-assignment
fc2_bias_grad if ctx.use_fc2_bias else None,
None, # use_fc2_bias
fc2_bias_grad,
None, # eps
None, # is_first_microbatch
None, # fp8
......@@ -1469,10 +1461,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_bias,
fc1_weight,
fc1_bias,
self.use_bias,
fc2_weight,
fc2_bias,
self.apply_bias and not self.gemm_bias_unfused_add,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.eps,
is_first_microbatch,
self.fp8,
......