Unverified Commit 3cc2c1d2 authored by Kirthi Shankar Sivamani, committed by GitHub

AMP support for LN and RMSNorm (#371)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 88c88654
@@ -15,6 +15,8 @@ from transformer_engine.pytorch import (
    Linear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
    LayerNorm,
)
from transformer_engine.common import recipe
@@ -308,6 +310,50 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
    torch.cuda.synchronize()
def _test_sanity_normalization_amp(block, bs, dtype, config, skip_wgrad, skip_dgrad):
    if skip_dgrad and skip_wgrad:
        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")

    te_inp = torch.randn(
        config.seq_len, bs, config.hidden_size, requires_grad=True
    ).cuda()
    te_inp.retain_grad()

    with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
        te_out = block(te_inp)
        loss = te_out.sum()

    loss.backward()
    torch.cuda.synchronize()

    assert te_out.dtype == dtype, "AMP wrong output type."
    assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type."
    for name, p in block.named_parameters():
        if p.requires_grad:
            assert p.grad.dtype == torch.float32, f"AMP wrong wgrad type for {name}."


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_normalization_amp(dtype, bs, model, skip_wgrad, skip_dgrad, normalization):
    config = model_configs[model]
    module = RMSNorm if normalization == "RMSNorm" else LayerNorm

    block = (
        module(
            config.hidden_size,
            eps=config.eps,
        )
        .to(dtype=torch.float32)
        .cuda()
    )

    _test_sanity_normalization_amp(block, bs, dtype, config, skip_wgrad, skip_dgrad)
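In user code, the behavior exercised by this test amounts to running the normalization modules under torch.autocast while keeping fp32 parameters. A minimal sketch, assuming a CUDA build of transformer_engine is installed (shapes and dtype are arbitrary, not taken from the test configs):

import torch
from transformer_engine.pytorch import LayerNorm

ln = LayerNorm(1024).cuda()  # parameters remain fp32
x = torch.randn(128, 4, 1024, device="cuda", requires_grad=True)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = ln(x)  # output comes back in the autocast dtype

y.sum().backward()
assert y.dtype == torch.bfloat16
assert ln.weight.grad.dtype == torch.float32  # grads still accumulate in fp32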
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
......
@@ -11,10 +11,12 @@ from torch.nn.parameter import Parameter
from torch.nn import init

import transformer_engine_extensions as tex
from .base import TransformerEngineBaseModule
from ..cpp_extensions import (
    layernorm_fwd_inf,
)
from ..jit import no_torch_dynamo
from ..utils import cast_if_needed

__all__ = ["LayerNorm"]
@@ -33,6 +35,7 @@ class _LayerNorm(torch.autograd.Function):
        bwd_ln_sm_margin: int,
        zero_centered_gamma: bool,
        is_grad_enabled: bool,
        activation_dtype: torch.dtype,
    ) -> torch.Tensor:
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
@@ -40,6 +43,11 @@ class _LayerNorm(torch.autograd.Function):
        assert inp.shape[-1] == in_features, "LayerNorm not possible"
        inputmat = inp.view((-1, in_features))

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        ln_weight = cast_if_needed(ln_weight, activation_dtype)
        ln_bias = cast_if_needed(ln_bias, activation_dtype)

        if is_grad_enabled:
            ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
                                                   ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma)
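cast_if_needed is imported from ..utils; its role here is to move fp32 activations and parameters to the autocast dtype before the fused kernel runs. A rough sketch of what such a helper does, written here as an assumption for illustration rather than the library's verbatim code:

import torch

def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # No-op when the tensor is absent or already in the target dtype;
    # otherwise cast (e.g. an fp32 weight becomes bf16/fp16 under autocast).
    if tensor is None or tensor.dtype == dtype:
        return tensor
    return tensor.to(dtype)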
@@ -63,7 +71,7 @@ class _LayerNorm(torch.autograd.Function):
            d_ln_out, inputmat, mu, rsigma, ln_weight,
            ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
        )
-        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None
+        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None
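The extra trailing None is required because torch.autograd.Function.backward must return one entry per forward argument (after ctx), and forward now takes the additional activation_dtype argument; non-tensor inputs simply receive None. A toy illustration of that contract (the names below are made up for the example, not part of the library):

import torch

class Scale(torch.autograd.Function):
    """Toy function with one tensor input and one non-tensor input."""

    @staticmethod
    def forward(ctx, inp, factor):
        ctx.factor = factor
        return inp * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward argument:
        # a gradient for `inp`, None for the non-tensor `factor`.
        return grad_out * ctx.factor, None

x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()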
class LayerNorm(torch.nn.Module):
@@ -170,6 +178,9 @@ class LayerNorm(torch.nn.Module):
        if hasattr(self, "layer_norm_bias"):
            setattr(self, "bias", self.layer_norm_bias)

        # Set the activation type for AMP.
        TransformerEngineBaseModule.set_activation_dtype(self, inp)

        if torch.is_grad_enabled():
            fwd_fn = _LayerNorm.apply
            args = []
@@ -185,7 +196,8 @@ class LayerNorm(torch.nn.Module):
            self.fwd_ln_sm_margin,
            self.bwd_ln_sm_margin,
            self.zero_centered_gamma,
-            torch.is_grad_enabled()
+            torch.is_grad_enabled(),
            self.activation_dtype,
        )

        return fwd_fn(*args)
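set_activation_dtype is inherited from TransformerEngineBaseModule and records the dtype the fused kernels should run in. The gist, sketched here as an assumption about the base-class logic rather than its exact code: under native AMP the autocast dtype wins, otherwise the incoming activation's dtype is used.

import torch

def set_activation_dtype(self, inp: torch.Tensor) -> None:
    # Simplified sketch: honor the autocast dtype when AMP is active,
    # otherwise fall back to the dtype of the incoming activation.
    if torch.is_autocast_enabled():
        self.activation_dtype = torch.get_autocast_gpu_dtype()
    else:
        self.activation_dtype = inp.dtype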
@@ -10,8 +10,10 @@ import torch
from torch.nn.parameter import Parameter
from torch.nn import init

from .base import TransformerEngineBaseModule
from .. import cpp_extensions as tex
from ..jit import no_torch_dynamo
from ..utils import cast_if_needed

__all__ = ["RMSNorm"]
@@ -30,6 +32,7 @@ class _RMSNorm(torch.autograd.Function):
        bwd_rmsnorm_sm_margin: int,
        zero_centered_gamma: bool,
        is_grad_enabled: bool,
        activation_dtype: torch.dtype,
    ) -> torch.Tensor:
        # Make sure input dimensions are compatible
        in_features = rmsnorm_weight.numel()
@@ -37,6 +40,10 @@ class _RMSNorm(torch.autograd.Function):
        assert inp.shape[-1] == in_features, "RMSNorm not possible"
        inputmat = inp.view((-1, in_features))

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        rmsnorm_weight = cast_if_needed(rmsnorm_weight, activation_dtype)

        if is_grad_enabled:
            rmsnorm_out, rsigma = tex.rmsnorm_fwd(inputmat, rmsnorm_weight,
                                                  eps, fwd_rmsnorm_sm_margin,
@@ -70,6 +77,7 @@ class _RMSNorm(torch.autograd.Function):
            None,
            None,
            None,
            None,
        )
@@ -148,6 +156,10 @@ class RMSNorm(torch.nn.Module):
    @no_torch_dynamo
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """RMSNorm FWD"""

        # Set the activation type for AMP.
        TransformerEngineBaseModule.set_activation_dtype(self, inp)

        if torch.is_grad_enabled():
            fwd_fn = _RMSNorm.apply
            args = []
@@ -162,7 +174,8 @@ class RMSNorm(torch.nn.Module):
            self.fwd_rmsnorm_sm_margin,
            self.bwd_rmsnorm_sm_margin,
            self.zero_centered_gamma,
-            torch.is_grad_enabled()
+            torch.is_grad_enabled(),
            self.activation_dtype,
        )

        return fwd_fn(*args)