"vscode:/vscode.git/clone" did not exist on "de0fabbc5c84e6771d70b92014ae06fe82654ff0"
Unverified commit 41425476, authored by Paweł Gadziński, committed by GitHub
Browse files

[PyTorch] Fix ONNX export errors (#2406)



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8ef8285c
...@@ -356,7 +356,9 @@ def onnx_layernorm( ...@@ -356,7 +356,9 @@ def onnx_layernorm(
) )
if normalization == "RMSNorm": if normalization == "RMSNorm":
ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps) variance = inp.pow(2).mean(-1, keepdim=True)
ln_out = inp * torch.rsqrt(variance + eps)
ln_out = ln_out * ln_weight
else: else:
ln_out = torch.nn.functional.layer_norm( ln_out = torch.nn.functional.layer_norm(
inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps
......
...@@ -249,4 +249,6 @@ class RMSNorm(BasicOperation): ...@@ -249,4 +249,6 @@ class RMSNorm(BasicOperation):
) -> torch.Tensor: ) -> torch.Tensor:
"""Every operand in this function has a defined ONNX translation.""" """Every operand in this function has a defined ONNX translation."""
weight = self.weight + 1 if self.zero_centered_gamma else self.weight weight = self.weight + 1 if self.zero_centered_gamma else self.weight
return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps) variance = input_.pow(2).mean(-1, keepdim=True)
normalized = input_ * torch.rsqrt(variance + self.eps)
return normalized * weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment