Unverified Commit a88ce94b authored by Linkun's avatar Linkun Committed by GitHub
Browse files

[IR][RmsNorm] pass None if not has_weight (#38961)


Signed-off-by: default avatarLinkun Chen <github@lkchen.net>
Signed-off-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 2a36d8fb
......@@ -241,8 +241,12 @@ class RMSNorm(CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
if residual is None:
# TODO(luka): address the weight=None passing issue more generally
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
x,
self.weight.data if self.has_weight else None,
self.variance_epsilon,
self.variance_size_override,
)
return self.forward_static(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment