Unverified Commit 41425476 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[PyTorch] Fix ONNX export errors (#2406)



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8ef8285c
......@@ -356,7 +356,9 @@ def onnx_layernorm(
)
if normalization == "RMSNorm":
ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps)
variance = inp.pow(2).mean(-1, keepdim=True)
ln_out = inp * torch.rsqrt(variance + eps)
ln_out = ln_out * ln_weight
else:
ln_out = torch.nn.functional.layer_norm(
inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps
......
......@@ -249,4 +249,6 @@ class RMSNorm(BasicOperation):
) -> torch.Tensor:
"""Every operand in this function has a defined ONNX translation."""
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps)
variance = input_.pow(2).mean(-1, keepdim=True)
normalized = input_ * torch.rsqrt(variance + self.eps)
return normalized * weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment