Unverified Commit 6bd35bf9 authored by Neta Zmora, committed by GitHub

Fix FP32 LayerNorm ONNX export (#313)



* Fix FP32 LayerNorm ONNX export

When running inference, use a fwd method that is registered with TorchScript.
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Bug fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 11c5d588
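For orientation before the diff: a minimal sketch of the dispatch pattern this commit applies, with a hypothetical _Scale function and scale_op wrapper standing in for _LayerNorm and LayerNorm.forward (these names are illustrative, not Transformer Engine code). When gradients are disabled, the wrapper calls the Function's static forward directly instead of .apply(); in the real diff that branch runs layernorm_fwd_inf, the TorchScript-registered inference kernel the commit message refers to.

import torch

class _Scale(torch.autograd.Function):
    # Toy stand-in for _LayerNorm, just to show the forward/apply split.
    @staticmethod
    def forward(ctx, inp, scale, is_grad_enabled):
        if is_grad_enabled:
            # Training path: only touch ctx when autograd is active.
            ctx.scale = scale
        # Inference path does no ctx bookkeeping, so the static forward can
        # be called directly with ctx=None while tracing for export.
        return inp * scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.scale, None, None

def scale_op(inp, scale):
    if torch.is_grad_enabled():
        return _Scale.apply(inp, scale, True)       # normal autograd path
    return _Scale.forward(None, inp, scale, False)  # export/inference path

x = torch.randn(4, requires_grad=True)
y = scale_op(x, 2.0)                 # records an autograd node
with torch.no_grad():
    y_inf = scale_op(x, 2.0)         # plain function call, no autograd node

Calling forward(None, ...) skips all autograd bookkeeping, which is why the new is_grad_enabled argument in the diff below guards every ctx access.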
@@ -620,20 +620,13 @@ def test_export_layernorm(
     class Test_Layernorm(nn.Module):
         def __init__(self) -> None:
             super().__init__()
-            normalized_shape = torch.Size(inp.shape[1:])
-            self.weight = torch.randn(*normalized_shape, device="cuda",
-                                      dtype=torch.float if fake_bf16_io else precision)
-            self.bias = torch.zeros(*normalized_shape, device="cuda",
-                                    dtype=torch.float if fake_bf16_io else precision)
-            self.eps = 1e-6  # An arbitrary small value
+            eps = 1e-6  # An arbitrary small value
+            dtype = torch.float if fake_bf16_io else precision
+            self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype,
+                                   zero_centered_gamma=False).eval().cuda()

         def forward(self, inp):
-            ret = texcpp.layernorm_fwd_inf(
-                inp,
-                self.weight,
-                self.bias,
-                self.eps,
-                zero_centered_gamma)
+            ret = self.ln(inp)
             return ret

     class TestFP8_Layernorm(nn.Module):
@@ -11,7 +11,9 @@ from torch.nn.parameter import Parameter
 from torch.nn import init

 import transformer_engine_extensions as tex
+from ..cpp_extensions import (
+    layernorm_fwd_inf,
+)

 __all__ = ["LayerNorm"]
@@ -29,6 +31,7 @@ class _LayerNorm(torch.autograd.Function):
         fwd_ln_sm_margin: int,
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
+        is_grad_enabled: bool,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -36,13 +39,16 @@ class _LayerNorm(torch.autograd.Function):
         assert inp.shape[-1] == in_features, "LayerNorm not possible"
         inputmat = inp.view((-1, in_features))

-        ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
-                                                ln_bias, eps, fwd_ln_sm_margin,
-                                                zero_centered_gamma)
-        ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
-        ctx.inp_shape = inp.shape
-        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
-        ctx.zero_centered_gamma = zero_centered_gamma
+        if is_grad_enabled:
+            ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
+                ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma)
+            ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
+            ctx.inp_shape = inp.shape
+            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
+            ctx.zero_centered_gamma = zero_centered_gamma
+        else:
+            ln_out, mu, rsigma = layernorm_fwd_inf(inputmat, ln_weight,
+                ln_bias, eps, zero_centered_gamma), None, None
         return ln_out.view_as(inp)

     @staticmethod
@@ -56,7 +62,7 @@ class _LayerNorm(torch.autograd.Function):
             d_ln_out, inputmat, mu, rsigma, ln_weight,
             ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
         )
-        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None
+        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None


 class LayerNorm(torch.nn.Module):
@@ -162,12 +168,22 @@ class LayerNorm(torch.nn.Module):
         if hasattr(self, "layer_norm_bias"):
             setattr(self, "bias", self.layer_norm_bias)

-        return _LayerNorm.apply(
+        if torch.is_grad_enabled():
+            fwd_fn = _LayerNorm.apply
+            args = []
+        else:
+            fwd_fn = _LayerNorm.forward
+            args = [None]
+
+        args += (
             inp,
             self.weight,
             self.bias,
             self.eps,
             self.fwd_ln_sm_margin,
             self.bwd_ln_sm_margin,
-            self.zero_centered_gamma
+            self.zero_centered_gamma,
+            torch.is_grad_enabled()
         )
+
+        return fwd_fn(*args)
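A hypothetical end-to-end check of the new inference branch (not part of the commit; assumes transformer_engine is installed with a CUDA device, and the variable names are made up). With gradients enabled the module still goes through _LayerNorm.apply; under torch.no_grad(), as when an eval-mode model is traced for ONNX export in the updated test, it now dispatches to _LayerNorm.forward and the layernorm_fwd_inf kernel:

import torch
import transformer_engine.pytorch as te

hidden_size = 1024
# FP32 LayerNorm in eval mode, mirroring the Test_Layernorm wrapper above.
ln = te.LayerNorm(hidden_size, eps=1e-6, params_dtype=torch.float32).eval().cuda()
inp = torch.randn(64, hidden_size, device="cuda", dtype=torch.float32)

out_train = ln(inp)            # grads enabled: _LayerNorm.apply path

with torch.no_grad():          # grads disabled: _LayerNorm.forward / layernorm_fwd_inf path
    out_infer = ln(inp)

# Both branches compute the same math, so the outputs should match closely.
torch.testing.assert_close(out_infer, out_train)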