Unverified commit 1958bda9, authored by Mengqing Cao and committed by GitHub
Browse files

[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)


Signed-off-by: MengqingCao <cmq0113@163.com>
parent 7bdb42b2
...@@ -62,14 +62,23 @@ from .utils import ( ...@@ -62,14 +62,23 @@ from .utils import (
class PersimmonMLP(nn.Module): class PersimmonMLP(nn.Module):
def __init__( def __init__(
self, config: PersimmonConfig, quant_config: QuantizationConfig | None = None self,
config: PersimmonConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.dense_h_to_4h = ColumnParallelLinear( self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size, config.intermediate_size, quant_config=quant_config config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
) )
self.dense_4h_to_h = RowParallelLinear( self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size, config.hidden_size, quant_config=quant_config config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
) )
self.act = get_act_fn(config.hidden_act) self.act = get_act_fn(config.hidden_act)
...@@ -110,12 +119,14 @@ class PersimmonAttention(nn.Module): ...@@ -110,12 +119,14 @@ class PersimmonAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense",
) )
self.is_qk_layernorm = config.qk_layernorm self.is_qk_layernorm = config.qk_layernorm
...@@ -192,7 +203,11 @@ class PersimmonDecoderLayer(nn.Module): ...@@ -192,7 +203,11 @@ class PersimmonDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
self.mlp = PersimmonMLP(config, quant_config=quant_config) self.mlp = PersimmonMLP(
config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = nn.LayerNorm( self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps config.hidden_size, eps=config.layer_norm_eps
) )
......
...@@ -99,11 +99,13 @@ class PhiAttention(nn.Module): ...@@ -99,11 +99,13 @@ class PhiAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.dense",
) )
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
...@@ -148,7 +150,10 @@ class PhiAttention(nn.Module): ...@@ -148,7 +150,10 @@ class PhiAttention(nn.Module):
class PhiMLP(nn.Module): class PhiMLP(nn.Module):
def __init__( def __init__(
self, config: PhiConfig, quant_config: QuantizationConfig | None = None self,
config: PhiConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -159,11 +164,13 @@ class PhiMLP(nn.Module): ...@@ -159,11 +164,13 @@ class PhiMLP(nn.Module):
config.hidden_size, config.hidden_size,
n_inner, n_inner,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc1",
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
n_inner, n_inner,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc2",
) )
self.act = get_act_fn(config.hidden_act) self.act = get_act_fn(config.hidden_act)
...@@ -189,7 +196,7 @@ class PhiLayer(nn.Module): ...@@ -189,7 +196,7 @@ class PhiLayer(nn.Module):
self.self_attn = PhiAttention( self.self_attn = PhiAttention(
config, cache_config, quant_config, prefix=f"{prefix}.self_attn" config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
) )
self.mlp = PhiMLP(config, quant_config) self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward( def forward(
self, self,
......
...@@ -343,12 +343,14 @@ class PhiMoEAttention(nn.Module): ...@@ -343,12 +343,14 @@ class PhiMoEAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
......
...@@ -567,12 +567,14 @@ class Plamo2AttentionMixer(nn.Module): ...@@ -567,12 +567,14 @@ class Plamo2AttentionMixer(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
config.hidden_size, config.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
self.rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000 self.rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000
......
...@@ -102,12 +102,14 @@ class QWenAttention(nn.Module): ...@@ -102,12 +102,14 @@ class QWenAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_attn",
) )
self.c_proj = RowParallelLinear( self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_proj",
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
......
...@@ -75,7 +75,12 @@ class Zamba2LoRA(nn.Module): ...@@ -75,7 +75,12 @@ class Zamba2LoRA(nn.Module):
super().__init__() super().__init__()
self.A = ColumnParallelLinear( self.A = ColumnParallelLinear(
input_dim, rank, bias=False, quant_config=quant_config, gather_output=True input_dim,
rank,
bias=False,
quant_config=quant_config,
gather_output=True,
prefix=f"{prefix}.A",
) )
if isinstance(output_dim, list): if isinstance(output_dim, list):
...@@ -150,12 +155,14 @@ class Zamba2Attention(nn.Module): ...@@ -150,12 +155,14 @@ class Zamba2Attention(nn.Module):
self.total_num_attention_heads, self.total_num_attention_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.attention_hidden_size, self.attention_hidden_size,
config.hidden_size, config.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj",
) )
# Even though in Zamba2 weights are shared between attention layers, KV # Even though in Zamba2 weights are shared between attention layers, KV
...@@ -197,18 +204,21 @@ class Zamba2Attention(nn.Module): ...@@ -197,18 +204,21 @@ class Zamba2Attention(nn.Module):
config.adapter_rank, config.adapter_rank,
self.attention_hidden_size, self.attention_hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.linear_q_adapter",
) )
linear_k_adapter = Zamba2LoRA( linear_k_adapter = Zamba2LoRA(
self.attention_hidden_size, self.attention_hidden_size,
config.adapter_rank, config.adapter_rank,
self.attention_hidden_size, self.attention_hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.linear_k_adapter",
) )
linear_v_adapter = Zamba2LoRA( linear_v_adapter = Zamba2LoRA(
self.attention_hidden_size, self.attention_hidden_size,
config.adapter_rank, config.adapter_rank,
self.attention_hidden_size, self.attention_hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.linear_v_adapter",
) )
else: else:
linear_q_adapter = nn.Identity() linear_q_adapter = nn.Identity()
...@@ -312,6 +322,7 @@ class Zamba2MLP(nn.Module): ...@@ -312,6 +322,7 @@ class Zamba2MLP(nn.Module):
2 * [self.intermediate_size], # 2x for gate and input projections 2 * [self.intermediate_size], # 2x for gate and input projections
bias=self.config.add_bias_linear, bias=self.config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
...@@ -319,6 +330,7 @@ class Zamba2MLP(nn.Module): ...@@ -319,6 +330,7 @@ class Zamba2MLP(nn.Module):
self.hidden_size, self.hidden_size,
bias=self.config.add_bias_linear, bias=self.config.add_bias_linear,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.down_proj",
) )
# Only allow GELU activations # Only allow GELU activations
...@@ -418,6 +430,7 @@ class Zamba2AttentionDecoderLayer(nn.Module): ...@@ -418,6 +430,7 @@ class Zamba2AttentionDecoderLayer(nn.Module):
bare_block_idx=bare_block_idx, bare_block_idx=bare_block_idx,
num_hybrid_layers=num_hybrid_layers, num_hybrid_layers=num_hybrid_layers,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
) )
# Initialize layer normalizations # Initialize layer normalizations
...@@ -599,6 +612,7 @@ class Zamba2HybridLayer(nn.Module): ...@@ -599,6 +612,7 @@ class Zamba2HybridLayer(nn.Module):
config.hidden_size, config.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.linear",
) )
self.mamba_decoder = Zamba2MambaDecoderLayer( self.mamba_decoder = Zamba2MambaDecoderLayer(
config, config,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment