Commit 4e95ec11 (unverified), authored by Ajay Anubolu, committed by GitHub
Browse files

[Bugfix] Fix Qwen3-Next in_proj_ba weight sharding with TP > 1 (#36242)


Signed-off-by: AjAnubolu <anuboluajay@gmail.com>
parent 179547d6
...@@ -145,6 +145,24 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): ...@@ -145,6 +145,24 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
prefix=prefix, prefix=prefix,
) )
def create_ba_proj(
    self,
    hidden_size: int,
    num_v_heads: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Build the fused ``in_proj_ba`` projection for Qwen3.5.

    The Qwen3.5 checkpoint ships ``in_proj_b`` and ``in_proj_a`` as two
    separate weights. They are loaded into the single fused ``in_proj_ba``
    parameter through ``stacked_params_mapping``, using shard_id 0 for
    ``in_proj_b`` and shard_id 1 for ``in_proj_a``.

    Args:
        hidden_size: Passed to the layer as ``input_size``.
        num_v_heads: Width of each of the two output shards.
        quant_config: Quantization config forwarded to the layer, or None.
        prefix: Parameter-name prefix forwarded to the layer.

    Returns:
        A bias-free ``MergedColumnParallelLinear`` with two output shards
        of ``num_v_heads`` columns each.
    """
    # One shard per checkpoint tensor (b first, then a), both equally wide.
    shard_widths = [num_v_heads, num_v_heads]
    ba_proj = MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=shard_widths,
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
    return ba_proj
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -412,12 +412,11 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -412,12 +412,11 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
prefix=f"{prefix}.in_proj_qkvz", prefix=f"{prefix}.in_proj_qkvz",
) )
# ba_proj doesn't support blockwise fp8 quantization. # ba_proj doesn't support blockwise fp8 quantization.
# # in_proj_ba is defined as MergedColumnParallelLinear for # Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint
# compatibility with Qwen3_5. # layouts, so we use a factory method to create the projection.
self.in_proj_ba = MergedColumnParallelLinear( self.in_proj_ba = self.create_ba_proj(
input_size=self.hidden_size, hidden_size=self.hidden_size,
output_sizes=[self.num_v_heads] * 2, num_v_heads=self.num_v_heads,
bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.in_proj_ba", prefix=f"{prefix}.in_proj_ba",
) )
...@@ -497,6 +496,28 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -497,6 +496,28 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
prefix=prefix, prefix=prefix,
) )
def create_ba_proj(
    self,
    hidden_size: int,
    num_v_heads: int,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> MergedColumnParallelLinear:
    """Build the fused ``in_proj_ba`` projection for Qwen3-Next.

    Qwen3-Next stores ``in_proj_ba`` in the checkpoint as one fused weight
    with an interleaved GQA layout, ``[b_g0, a_g0, b_g1, a_g1, ...]``,
    where every group corresponds to a key-head group. The layer is
    therefore declared with a single output shard, so that ColumnParallel
    sharding keeps the interleaved structure intact across TP ranks.

    Qwen3.5 overrides this factory with ``[num_v_heads, num_v_heads]``
    because its checkpoint carries separate ``in_proj_b`` and
    ``in_proj_a`` weights.

    Args:
        hidden_size: Passed to the layer as ``input_size``.
        num_v_heads: Number of value heads; the single output shard is
            twice this wide.
        quant_config: Quantization config forwarded to the layer, or None.
        prefix: Parameter-name prefix forwarded to the layer.

    Returns:
        A bias-free ``MergedColumnParallelLinear`` with one output shard
        of ``2 * num_v_heads`` columns.
    """
    # b and a are interleaved inside one tensor, so expose exactly one
    # shard that covers both of them.
    fused_width = 2 * num_v_heads
    ba_proj = MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[fused_width],
        bias=False,
        quant_config=quant_config,
        prefix=prefix,
    )
    return ba_proj
def fix_query_key_value_ordering( def fix_query_key_value_ordering(
self, self,
mixed_qkvz: torch.Tensor, mixed_qkvz: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment