Unverified Commit 1958bda9 authored by Mengqing Cao's avatar Mengqing Cao Committed by GitHub
Browse files

[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)


Signed-off-by: default avatarMengqingCao <cmq0113@163.com>
parent 7bdb42b2
......@@ -62,14 +62,23 @@ from .utils import (
class PersimmonMLP(nn.Module):
def __init__(
self, config: PersimmonConfig, quant_config: QuantizationConfig | None = None
self,
config: PersimmonConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size, config.intermediate_size, quant_config=quant_config
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size, config.hidden_size, quant_config=quant_config
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
self.act = get_act_fn(config.hidden_act)
......@@ -110,12 +119,14 @@ class PersimmonAttention(nn.Module):
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
self.is_qk_layernorm = config.qk_layernorm
......@@ -192,7 +203,11 @@ class PersimmonDecoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = PersimmonMLP(config, quant_config=quant_config)
self.mlp = PersimmonMLP(
config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
......
......@@ -99,11 +99,13 @@ class PhiAttention(nn.Module):
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
scaling = self.head_size**-0.5
......@@ -148,7 +150,10 @@ class PhiAttention(nn.Module):
class PhiMLP(nn.Module):
def __init__(
self, config: PhiConfig, quant_config: QuantizationConfig | None = None
self,
config: PhiConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
......@@ -159,11 +164,13 @@ class PhiMLP(nn.Module):
config.hidden_size,
n_inner,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
n_inner,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
self.act = get_act_fn(config.hidden_act)
......@@ -189,7 +196,7 @@ class PhiLayer(nn.Module):
self.self_attn = PhiAttention(
config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
)
self.mlp = PhiMLP(config, quant_config)
self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,
......
......@@ -343,12 +343,14 @@ class PhiMoEAttention(nn.Module):
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
......
......@@ -567,12 +567,14 @@ class Plamo2AttentionMixer(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000
......
......@@ -102,12 +102,14 @@ class QWenAttention(nn.Module):
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_attn",
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
self.scaling = self.head_dim**-0.5
......
......@@ -75,7 +75,12 @@ class Zamba2LoRA(nn.Module):
super().__init__()
self.A = ColumnParallelLinear(
input_dim, rank, bias=False, quant_config=quant_config, gather_output=True
input_dim,
rank,
bias=False,
quant_config=quant_config,
gather_output=True,
prefix=f"{prefix}.A",
)
if isinstance(output_dim, list):
......@@ -150,12 +155,14 @@ class Zamba2Attention(nn.Module):
self.total_num_attention_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.attention_hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
# Even though in Zamba2 weights are shared between attention layers, KV
......@@ -197,18 +204,21 @@ class Zamba2Attention(nn.Module):
config.adapter_rank,
self.attention_hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.linear_q_adapter",
)
linear_k_adapter = Zamba2LoRA(
self.attention_hidden_size,
config.adapter_rank,
self.attention_hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.linear_k_adapter",
)
linear_v_adapter = Zamba2LoRA(
self.attention_hidden_size,
config.adapter_rank,
self.attention_hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.linear_v_adapter",
)
else:
linear_q_adapter = nn.Identity()
......@@ -312,6 +322,7 @@ class Zamba2MLP(nn.Module):
2 * [self.intermediate_size], # 2x for gate and input projections
bias=self.config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
......@@ -319,6 +330,7 @@ class Zamba2MLP(nn.Module):
self.hidden_size,
bias=self.config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
# Only allow GELU activations
......@@ -418,6 +430,7 @@ class Zamba2AttentionDecoderLayer(nn.Module):
bare_block_idx=bare_block_idx,
num_hybrid_layers=num_hybrid_layers,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
# Initialize layer normalizations
......@@ -599,6 +612,7 @@ class Zamba2HybridLayer(nn.Module):
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.linear",
)
self.mamba_decoder = Zamba2MambaDecoderLayer(
config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment