Commit 4799ca5f authored by laibao
Browse files

fix: 修复 RMSNormGated 输出类型不一致问题

parent d146a231
......@@ -458,16 +458,18 @@ class RMSNormGated(CustomOp):
- norm_before_gate=True: out = norm(x) * silu(z)
- norm_before_gate=False: out = norm(x * silu(z))
"""
input_dtype = x.dtype
# Apply gating before normalization if needed
if z is not None and not self.norm_before_gate:
x = x * F.silu(z)
x = torch.mul(x, F.silu(z))
# RMS Normalization
if self.group_size is None:
# Standard RMS norm across the last dimension
variance = x.pow(2).mean(dim=-1, keepdim=True)
x_normed = x * torch.rsqrt(variance + self.eps)
out = x_normed * self.weight
out = torch.mul(x_normed, self.weight)
else:
# Group RMS norm
from einops import rearrange
......@@ -475,11 +477,14 @@ class RMSNormGated(CustomOp):
x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
variance = x_group.pow(2).mean(dim=-1, keepdim=True)
x_normed = x_group * torch.rsqrt(variance + self.eps)
out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
out = torch.mul(rearrange(x_normed, "... g d -> ... (g d)"), self.weight)
# Apply gating after normalization if needed
if z is not None and self.norm_before_gate:
out = out * F.silu(z)
out = torch.mul(out, F.silu(z))
if out.dtype != input_dtype:
out = out.to(input_dtype)
return out
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment