"docs/vscode:/vscode.git/clone" did not exist on "96e90fdeb3c4ebacfe24513556afccb918722b7c"
Unverified Commit 89361181 authored by Tao He's avatar Tao He Committed by GitHub
Browse files

[Qwen][Bugfix] Fixes sigmoid activation in torch impl of RMSNormGated. (#40245)


Signed-off-by: default avatarTao He <linzhu.ht@alibaba-inc.com>
parent 67ed01c3
......@@ -478,9 +478,12 @@ class RMSNormGated(CustomOp):
weight = self.weight.float()
z = z.float() if z is not None else None
assert self.activation in ["silu", "sigmoid", "swish"]
act_fn = F.sigmoid if self.activation == "sigmoid" else F.silu
# Apply gating before normalization if needed
if z is not None and not self.norm_before_gate:
x = x * F.silu(z)
x = x * act_fn(z)
# RMS Normalization
if self.group_size is None:
......@@ -499,7 +502,7 @@ class RMSNormGated(CustomOp):
# Apply gating after normalization if needed
if z is not None and self.norm_before_gate:
out = out * F.silu(z)
out = out * act_fn(z)
return out.to(orig_dtype)
......
......@@ -357,11 +357,19 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
output_gate_type = getattr(config, "output_gate_type", "silu")
if output_gate_type == "swish":
output_gate_type = "silu"
assert output_gate_type in ["silu", "swish", "sigmoid"], (
f"unsupported {output_gate_type=}"
)
self.norm = RMSNormGated(
self.head_v_dim,
eps=self.layer_norm_epsilon,
group_size=None,
norm_before_gate=True,
activation=output_gate_type,
device=current_platform.current_device(),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment