Unverified Commit 1958bda9 authored by Mengqing Cao's avatar Mengqing Cao Committed by GitHub
Browse files

[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)


Signed-off-by: default avatarMengqingCao <cmq0113@163.com>
parent 7bdb42b2
......@@ -75,7 +75,11 @@ class ArcticMLP(nn.Module):
)
self.w13 = MergedColumnParallelLinear(
self.hidden_size, [self.ffn_dim] * 2, bias=False, quant_config=quant_config
self.hidden_size,
[self.ffn_dim] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.w13",
)
self.w2 = RowParallelLinear(
self.ffn_dim,
......@@ -83,6 +87,7 @@ class ArcticMLP(nn.Module):
bias=False,
reduce_results=reduce_results,
quant_config=quant_config,
prefix=f"{prefix}.w2",
)
if config.hidden_act != "silu":
raise ValueError(
......@@ -297,6 +302,7 @@ class ArcticAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
......@@ -304,6 +310,7 @@ class ArcticAttention(nn.Module):
bias=False,
reduce_results=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
......
......@@ -98,13 +98,22 @@ class BaiChuanMLP(nn.Module):
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
......@@ -152,12 +161,14 @@ class BaiChuanAttention(nn.Module):
self.total_num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.W_pack",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
# Create the alibi slopes and slice them.
if self.position_embedding == "ALIBI":
......@@ -235,6 +246,7 @@ class BaiChuanDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
......
......@@ -60,6 +60,7 @@ class BambaMLP(nn.Module):
config: BambaConfig,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
......@@ -67,12 +68,14 @@ class BambaMLP(nn.Module):
output_sizes=[config.intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if config.hidden_act != "silu":
raise ValueError(
......@@ -118,7 +121,9 @@ class BambaMixerDecoderLayer(nn.Module):
prefix=f"{prefix}.mixer",
)
self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.feed_forward = BambaMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -202,12 +207,14 @@ class BambaAttentionDecoderLayer(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.attn = Attention(
......@@ -219,7 +226,9 @@ class BambaAttentionDecoderLayer(nn.Module):
prefix=f"{prefix}.attn",
)
self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.feed_forward = BambaMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
......@@ -108,12 +108,14 @@ class BloomAttention(nn.Module):
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
# Create the alibi slopes and slice them.
......@@ -152,6 +154,7 @@ class BloomMLP(nn.Module):
self,
config: BloomConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
......@@ -159,12 +162,14 @@ class BloomMLP(nn.Module):
hidden_size,
4 * hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.gelu_impl = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -192,7 +197,7 @@ class BloomBlock(nn.Module):
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon
)
self.mlp = BloomMLP(config, quant_config)
self.mlp = BloomMLP(config, quant_config, prefix=f"{prefix}.mlp")
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm
)
......
......@@ -227,6 +227,7 @@ class ChameleonMLP(nn.Module):
hidden_act: str,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
......@@ -234,12 +235,14 @@ class ChameleonMLP(nn.Module):
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
......@@ -299,12 +302,14 @@ class ChameleonAttention(nn.Module):
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
......@@ -393,6 +398,7 @@ class ChameleonDecoderLayer(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
......@@ -462,6 +468,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
......
......@@ -209,12 +209,14 @@ class DbrxAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.Wqkv",
)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
......
......@@ -82,7 +82,11 @@ class DeepseekMLP(nn.Module):
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
......@@ -90,6 +94,7 @@ class DeepseekMLP(nn.Module):
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
......@@ -239,6 +244,7 @@ class DeepseekAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
......@@ -246,6 +252,7 @@ class DeepseekAttention(nn.Module):
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
......
......@@ -240,6 +240,7 @@ class Dots1Attention(nn.Module):
self.total_num_kv_heads,
bias=attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
......@@ -247,6 +248,7 @@ class Dots1Attention(nn.Module):
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
......
......@@ -137,6 +137,7 @@ class FalconAttention(nn.Module):
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
......@@ -153,6 +154,7 @@ class FalconAttention(nn.Module):
skip_bias_add=True,
quant_config=quant_config,
reduce_results=self.reduce_row_parallel_results,
prefix=f"{prefix}.dense",
)
self.use_rotary = config.rotary
......@@ -227,6 +229,7 @@ class FalconMLP(nn.Module):
self,
config: FalconConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
......@@ -237,6 +240,7 @@ class FalconMLP(nn.Module):
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.act = get_act_fn("gelu")
self.reduce_row_parallel_results = not (
......@@ -249,6 +253,7 @@ class FalconMLP(nn.Module):
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -275,7 +280,7 @@ class FalconDecoderLayer(nn.Module):
self.self_attention = FalconAttention(
config, cache_config, quant_config, prefix=f"{prefix}.self_attention"
)
self.mlp = FalconMLP(config, quant_config)
self.mlp = FalconMLP(config, quant_config, prefix=f"{prefix}.mlp")
self.config = config
if not hasattr(config, "num_ln_in_parallel_attn"):
......
......@@ -59,6 +59,7 @@ class FalconH1MLP(nn.Module):
config: FalconH1Config,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
......@@ -66,12 +67,14 @@ class FalconH1MLP(nn.Module):
output_sizes=[config.intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.intermediate_size = config.intermediate_size
......@@ -365,7 +368,7 @@ class FalconH1ParallelHybrid(nn.Module):
self.attention_in_multiplier = config.attention_in_multiplier
self.attn_out_multiplier = config.attention_out_multiplier
self.feed_forward = FalconH1MLP(config)
self.feed_forward = FalconH1MLP(config, prefix=f"{prefix}.feed_forward")
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
......@@ -66,13 +66,22 @@ class Gemma2MLP(nn.Module):
hidden_act: str,
hidden_activation: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
raise ValueError(
......@@ -134,12 +143,14 @@ class Gemma2Attention(nn.Module):
self.total_num_kv_heads,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=config.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -208,6 +219,7 @@ class Gemma2DecoderLayer(nn.Module):
hidden_act=config.hidden_act,
hidden_activation=config.hidden_activation,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(
......
......@@ -78,12 +78,14 @@ class GPTJAttention(nn.Module):
self.total_num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
tp_world_size = get_tensor_model_parallel_world_size()
......@@ -130,6 +132,7 @@ class GPTJMLP(nn.Module):
intermediate_size: int,
config: GPTJConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.n_embd
......@@ -137,11 +140,13 @@ class GPTJMLP(nn.Module):
hidden_size,
intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc_in",
)
self.fc_out = RowParallelLinear(
intermediate_size,
hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc_out",
)
self.act = get_act_fn(config.activation_function)
......@@ -166,7 +171,7 @@ class GPTJBlock(nn.Module):
self.attn = GPTJAttention(
config, cache_config, quant_config, prefix=f"{prefix}.attn"
)
self.mlp = GPTJMLP(inner_dim, config, quant_config)
self.mlp = GPTJMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,
......
......@@ -80,12 +80,14 @@ class GPTNeoXAttention(nn.Module):
self.total_num_heads,
bias=self.bias,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=self.bias,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
......@@ -125,17 +127,20 @@ class GPTNeoXMLP(nn.Module):
self,
config: GPTNeoXConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
self.act = get_act_fn(config.hidden_act)
......
......@@ -107,12 +107,14 @@ class JAISAttention(nn.Module):
total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_attn",
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
tp_rank = get_tensor_model_parallel_rank()
......@@ -147,6 +149,7 @@ class JAISMLP(nn.Module):
intermediate_size: int,
config: JAISConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
......@@ -156,6 +159,7 @@ class JAISMLP(nn.Module):
intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_fc",
)
self.c_fc2 = (
ColumnParallelLinear(
......@@ -163,6 +167,7 @@ class JAISMLP(nn.Module):
intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_fc2",
)
if self.swiglu
else None
......@@ -172,6 +177,7 @@ class JAISMLP(nn.Module):
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
self.act = SwiGLUActivation()
......@@ -206,7 +212,7 @@ class JAISBlock(nn.Module):
config, cache_config, quant_config, prefix=f"{prefix}.attn"
)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = JAISMLP(inner_dim, config, quant_config)
self.mlp = JAISMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,
......
......@@ -220,12 +220,14 @@ class JambaAttentionDecoderLayer(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.attn = Attention(
......
......@@ -191,13 +191,22 @@ class MiniCPMMLP(nn.Module):
hidden_act: str,
hidden_act_param: float,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act == "silu":
self.act_fn = SiluAndMul()
......@@ -259,12 +268,14 @@ class MiniCPMAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
......
......@@ -96,6 +96,7 @@ class MiniCPM3Attention(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
......@@ -103,6 +104,7 @@ class MiniCPM3Attention(nn.Module):
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
......@@ -110,6 +112,7 @@ class MiniCPM3Attention(nn.Module):
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
)
# O projection.
self.o_proj = RowParallelLinear(
......@@ -117,6 +120,7 @@ class MiniCPM3Attention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
......
......@@ -83,6 +83,7 @@ class MPTAttention(nn.Module):
self.total_num_kv_heads,
bias=not config.no_bias,
quant_config=quant_config,
prefix=f"{prefix}.Wqkv",
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
......@@ -92,6 +93,7 @@ class MPTAttention(nn.Module):
self.d_model,
bias=not config.no_bias,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
tp_world_size = get_tensor_model_parallel_world_size()
......@@ -152,6 +154,7 @@ class MPTMLP(nn.Module):
self,
config: MptConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.d_model
......@@ -162,6 +165,7 @@ class MPTMLP(nn.Module):
intermediate_size,
bias=not config.no_bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj",
)
self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear(
......@@ -169,6 +173,7 @@ class MPTMLP(nn.Module):
hidden_size,
bias=not config.no_bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -193,7 +198,7 @@ class MPTBlock(nn.Module):
config, cache_config, quant_config, prefix=f"{prefix}.attn"
)
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config, quant_config)
self.ffn = MPTMLP(config, quant_config, prefix=f"{prefix}.ffn")
def forward(
self,
......
......@@ -158,6 +158,7 @@ class OlmoeAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.tp_size = tp_size
self.tp_rank = get_tensor_model_parallel_rank()
......@@ -168,6 +169,7 @@ class OlmoeAttention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
......
......@@ -52,13 +52,22 @@ class OrionMLP(nn.Module):
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
......@@ -116,12 +125,14 @@ class OrionAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
......@@ -183,6 +194,7 @@ class OrionDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment