Unverified commit 9ec82579, authored by cloud11665 and committed by GitHub
Browse files

[Model] Add module name prefixes to gemma3 (#15889)


Signed-off-by: Bartholomew Sabat <bartek@recursal.ai>
Co-authored-by: Bartholomew Sabat <bartek@recursal.ai>
parent 38327cf4
...@@ -59,16 +59,23 @@ class Gemma3MLP(nn.Module): ...@@ -59,16 +59,23 @@ class Gemma3MLP(nn.Module):
intermediate_size: int, intermediate_size: int,
hidden_activation: str, hidden_activation: str,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False, bias=False,
quant_config=quant_config) quant_config=quant_config,
self.down_proj = RowParallelLinear(intermediate_size, prefix=f"{prefix}.down_proj",
hidden_size, )
bias=False,
quant_config=quant_config)
if hidden_activation != "gelu_pytorch_tanh": if hidden_activation != "gelu_pytorch_tanh":
raise ValueError( raise ValueError(
"Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
...@@ -125,12 +132,14 @@ class Gemma3Attention(nn.Module): ...@@ -125,12 +132,14 @@ class Gemma3Attention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=config.attention_bias, bias=config.attention_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=config.attention_bias, bias=config.attention_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
...@@ -293,6 +302,7 @@ class Gemma3DecoderLayer(nn.Module): ...@@ -293,6 +302,7 @@ class Gemma3DecoderLayer(nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_activation=config.hidden_activation, hidden_activation=config.hidden_activation,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp",
) )
self.input_layernorm = GemmaRMSNorm(config.hidden_size, self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -344,6 +354,7 @@ class Gemma3Model(nn.Module): ...@@ -344,6 +354,7 @@ class Gemma3Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
prefix=f"{prefix}.embed_tokens",
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment