Unverified Commit 01be5b48 authored by Rohit Dwivedula's avatar Rohit Dwivedula Committed by GitHub
Browse files

adds: extra_repr() to MambaRMSNorm to include hidden size / size of weights in the layer (#32171)

* adds: extra_repr() to MambaRMSNorm to include the hidden size of the layer

* style fix with ruff:
parent c85510f9
......@@ -327,6 +327,9 @@ class MambaRMSNorm(nn.Module):
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
class MambaBlock(nn.Module):
def __init__(self, config, layer_idx):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment