"...git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "26373edb788f695251be7b4a3fcdc6f0c9d7733a"
Unverified commit 6bd35bf9 authored by Neta Zmora, committed by GitHub
Browse files

Fix FP32 LayerNorm ONNX export (#313)



* Fix FP32 LayerNorm ONNX export

When running inference use a fwd method that is registered with torchscript.
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Bug fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 11c5d588
...@@ -620,20 +620,13 @@ def test_export_layernorm( ...@@ -620,20 +620,13 @@ def test_export_layernorm(
class Test_Layernorm(nn.Module): class Test_Layernorm(nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
normalized_shape = torch.Size(inp.shape[1:]) eps = 1e-6 # An arbitrary small value
self.weight = torch.randn(*normalized_shape, device="cuda", dtype = torch.float if fake_bf16_io else precision
dtype=torch.float if fake_bf16_io else precision) self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype,
self.bias = torch.zeros(*normalized_shape, device="cuda", zero_centered_gamma=False).eval().cuda()
dtype=torch.float if fake_bf16_io else precision)
self.eps = 1e-6 # An arbitrary small value
def forward(self, inp): def forward(self, inp):
ret = texcpp.layernorm_fwd_inf( ret = self.ln(inp)
inp,
self.weight,
self.bias,
self.eps,
zero_centered_gamma)
return ret return ret
class TestFP8_Layernorm(nn.Module): class TestFP8_Layernorm(nn.Module):
......
...@@ -11,7 +11,9 @@ from torch.nn.parameter import Parameter ...@@ -11,7 +11,9 @@ from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from ..cpp_extensions import (
layernorm_fwd_inf,
)
__all__ = ["LayerNorm"] __all__ = ["LayerNorm"]
...@@ -29,6 +31,7 @@ class _LayerNorm(torch.autograd.Function): ...@@ -29,6 +31,7 @@ class _LayerNorm(torch.autograd.Function):
fwd_ln_sm_margin: int, fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -36,13 +39,16 @@ class _LayerNorm(torch.autograd.Function): ...@@ -36,13 +39,16 @@ class _LayerNorm(torch.autograd.Function):
assert inp.shape[-1] == in_features, "LayerNorm not possible" assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features)) inputmat = inp.view((-1, in_features))
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, if is_grad_enabled:
ln_bias, eps, fwd_ln_sm_margin, ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
zero_centered_gamma) ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma ctx.zero_centered_gamma = zero_centered_gamma
else:
ln_out, mu, rsigma = layernorm_fwd_inf(inputmat, ln_weight,
ln_bias, eps, zero_centered_gamma), None, None
return ln_out.view_as(inp) return ln_out.view_as(inp)
@staticmethod @staticmethod
...@@ -56,7 +62,7 @@ class _LayerNorm(torch.autograd.Function): ...@@ -56,7 +62,7 @@ class _LayerNorm(torch.autograd.Function):
d_ln_out, inputmat, mu, rsigma, ln_weight, d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None
class LayerNorm(torch.nn.Module): class LayerNorm(torch.nn.Module):
...@@ -162,12 +168,22 @@ class LayerNorm(torch.nn.Module): ...@@ -162,12 +168,22 @@ class LayerNorm(torch.nn.Module):
if hasattr(self, "layer_norm_bias"): if hasattr(self, "layer_norm_bias"):
setattr(self, "bias", self.layer_norm_bias) setattr(self, "bias", self.layer_norm_bias)
return _LayerNorm.apply( if torch.is_grad_enabled():
fwd_fn = _LayerNorm.apply
args = []
else:
fwd_fn = _LayerNorm.forward
args = [None]
args += (
inp, inp,
self.weight, self.weight,
self.bias, self.bias,
self.eps, self.eps,
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma self.zero_centered_gamma,
torch.is_grad_enabled()
) )
return fwd_fn(*args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment