Unverified Commit 4a0e0be2 authored by cao1zhg, committed by GitHub

[bugfix] fix norm type error in qwen3_next model (#10322)


Co-authored-by: caoyizhong.cyz <caoyizhong.cyz@alibaba-inc.com>
Co-authored-by: Yi Zhang <1109276519@qq.com>
parent 64f296f8
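
The practical difference between the two norm classes is what turns the wrong choice into a silent correctness bug: a plain RMSNorm scales the normalized activations by `weight`, while a Gemma-style RMSNorm scales by `(1 + weight)`, the convention used by checkpoints that store their norm scales zero-centered. Loading zero-centered weights into the plain class therefore collapses activations toward zero. The sketch below illustrates this under that assumption only; `PlainRMSNorm` and `GemmaStyleRMSNorm` are hypothetical stand-ins, not this repository's actual `RMSNorm`/`GemmaRMSNorm` implementations.

```python
# Illustrative sketch only -- not the repository's RMSNorm/GemmaRMSNorm kernels.
# Assumption: the Gemma convention stores the scale zero-centered and the layer
# multiplies the normalized input by (1 + weight).
import torch
import torch.nn as nn


class PlainRMSNorm(nn.Module):
    """Standard RMSNorm: y = x / rms(x) * weight."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class GemmaStyleRMSNorm(PlainRMSNorm):
    """Gemma-style RMSNorm: y = x / rms(x) * (1 + weight)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__(hidden_size, eps)
        # Checkpoints following this convention store weights near 0, not near 1.
        nn.init.zeros_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * (1.0 + self.weight)


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden = 8
    x = torch.randn(2, hidden)
    # Pretend this is a zero-centered scale loaded from a checkpoint.
    ckpt_weight = 0.05 * torch.randn(hidden)

    plain = PlainRMSNorm(hidden)
    gemma = GemmaStyleRMSNorm(hidden)
    with torch.no_grad():
        plain.weight.copy_(ckpt_weight)
        gemma.weight.copy_(ckpt_weight)

    print(plain(x).abs().mean())  # far too small: zero-centered weights shrink the output
    print(gemma(x).abs().mean())  # roughly unit scale, as intended
```

Running the script shows the plain norm producing outputs at a fraction of the Gemma-style magnitude for the same loaded weights, which is the kind of silent degradation the commit title's "norm type error" points at.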
@@ -518,24 +518,10 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
         )
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            self.input_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-            self.post_attention_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-        else:
-            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-            self.post_attention_layernorm = RMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
         self.layer_communicator = LayerCommunicator(
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
@@ -685,23 +671,10 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             quant_config=quant_config,
         )
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            self.input_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-            self.post_attention_layernorm = GemmaRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
-        else:
-            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-            self.post_attention_layernorm = RMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            )
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -844,13 +817,7 @@ class Qwen3NextModel(nn.Module):
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
         )
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once("Using Gemma RMSNorm for final normalization.")
-            self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        else:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.infer_count = 0

     def forward(
...
@@ -54,15 +54,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
         # (1) do not use_dedicated_mtp_embeddings provided in ckpt since not provided and directly use the target model embeddings
         # (2) hardcode bias=False since not provided
         self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
-        if getattr(
-            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
-        ):
-            logger.warning_once(
-                "Using Gemma RMSNorm for input normalization and post attn normalization."
-            )
-            RMSNorm_cls = GemmaRMSNorm
-        else:
-            RMSNorm_cls = RMSNorm
+        RMSNorm_cls = GemmaRMSNorm
         self.pre_fc_norm_embedding = RMSNorm_cls(
             config.hidden_size, config.rms_norm_eps
         )
...
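
For reference, the removed guard enabled GemmaRMSNorm only when the config exposed `use_gemma_rms_norm` or `apply_layernorm_1p`; a config that defines neither attribute silently falls through to the plain `RMSNorm` branch, which appears to be the norm type error this commit fixes. The snippet below simply re-runs the removed guard expression against a hypothetical bare config object to show that fall-through:

```python
from types import SimpleNamespace

# Hypothetical stand-in for a model config that defines neither flag.
config = SimpleNamespace(hidden_size=2048, rms_norm_eps=1e-6)

# The exact expression removed by this commit.
use_gemma = getattr(
    config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
)
print(use_gemma)  # False -> the old code would have constructed plain RMSNorm here
```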