Unverified commit 6467b635, authored by Tib, committed by GitHub
Browse files

[Bugfix] Add missing activation attr to RMSNormGated (#35423)


Signed-off-by: tibG <naps@qubes.milou>
Co-authored-by: tibG <naps@qubes.milou>
parent 9c3fe993
......@@ -510,6 +510,7 @@ class RMSNormGated(CustomOp):
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
activation: str = "swish",
):
"""Initialize RMSNormGated.
......@@ -524,10 +525,12 @@ class RMSNormGated(CustomOp):
If False and z is provided: out = norm(x * silu(z))
device: Device to create parameters on
dtype: Data type for parameters
activation: Activation function name for gating
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.activation = activation
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment