Unverified Commit f9756d9e authored by Rohit Dwivedula, committed by GitHub

Adds: extra_repr for RMSNorm layers in most models (#32204)

* adds: extra_repr() to RMSNorm layers in multiple models

* adds: extra_repr for deprecated models as well

* formatting as per style guide
parent b8e5cd53
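
For context, PyTorch builds a module's printed representation from each submodule's extra_repr(): nn.Module.__repr__ appends whatever string extra_repr() returns inside the parentheses after the class name. The sketch below (DemoRMSNorm is a hypothetical stand-in, not a class from this commit) shows the effect the added method has when a module is printed:

import torch
from torch import nn


class DemoRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Standard RMSNorm: normalize by the root mean square, then scale.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        # Picked up by nn.Module.__repr__ when the module is printed.
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


print(DemoRMSNorm(4096))
# Without extra_repr: DemoRMSNorm()
# With extra_repr:    DemoRMSNorm((4096,), eps=1e-06)

With the same method on the RMSNorm classes in the hunks below, print(model) now reports each norm layer's weight shape and epsilon instead of an empty ClassName().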
@@ -76,6 +76,9 @@ class ChameleonRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 ALL_LAYERNORM_LAYERS.append(ChameleonRMSNorm)
......
@@ -239,6 +239,9 @@ class ClvpRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 class ClvpRotaryPositionalEmbedding(nn.Module):
     """
......
@@ -250,6 +250,9 @@ class MegaRMSNorm(nn.Module):
         input * torch.rsqrt(mean_square + self.eps)
         return input
+
+    def extra_repr(self):
+        return f"{self.num_features}, eps={self.eps}, affine={self.affine}"
 
 class MegaScaleNorm(nn.Module):
     """
......
@@ -62,6 +62,9 @@ class OpenLlamaRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 class OpenLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
......
@@ -180,6 +180,9 @@ class GemmaRMSNorm(nn.Module):
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
......
@@ -68,6 +68,9 @@ class GemmaRMSNorm(nn.Module):
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
......
@@ -70,6 +70,9 @@ class Gemma2RMSNorm(nn.Module):
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 class Gemma2RotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
......
@@ -431,6 +431,9 @@ class IdeficsRMSNorm(nn.Module):
         return self.weight * hidden_states
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 ALL_LAYERNORM_LAYERS.append(IdeficsRMSNorm)
......
@@ -676,6 +676,9 @@ class Idefics2RMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 class Idefics2PerceiverAttention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None) -> None:
......
@@ -178,6 +178,9 @@ class JambaRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
......
@@ -374,6 +374,9 @@ class JetMoeRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->JetMoe
 class JetMoeRotaryEmbedding(nn.Module):
......
@@ -71,6 +71,9 @@ class LlamaRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
......
@@ -73,6 +73,9 @@ class MistralRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 class MistralRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
......
@@ -163,6 +163,9 @@ class MixtralRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
 # TODO @longjie no longer copied from Mistral after static cache
......
@@ -74,6 +74,9 @@ class Phi3RMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
 class Phi3RotaryEmbedding(nn.Module):
......
@@ -78,6 +78,9 @@ class Qwen2RMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
 class Qwen2RotaryEmbedding(nn.Module):
......
@@ -154,6 +154,9 @@ class Qwen2MoeRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
 class Qwen2MoeRotaryEmbedding(nn.Module):
......
@@ -59,6 +59,9 @@ class RecurrentGemmaRMSNorm(nn.Module):
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 ALL_LAYERNORM_LAYERS.append(RecurrentGemmaRMSNorm)
......