Unverified commit 604b9eae, authored by Vadim Gimpelson and committed by GitHub
Browse files

[BUGFIX] Fix accuracy regression for NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4 with TP>1 (#34476)


Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
parent 50dbd6c9
...@@ -17,6 +17,7 @@ from vllm.forward_context import ForwardContext, get_forward_context ...@@ -17,6 +17,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp, PluggableLayer from vllm.model_executor.custom_op import CustomOp, PluggableLayer
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
...@@ -301,10 +302,35 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -301,10 +302,35 @@ class MambaMixer2(MambaBase, PluggableLayer):
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
# Use ColumnParallelLinear with custom weight loaders for both cases: if n_groups % self.tp_size == 0:
# - When n_groups % tp_size == 0: standard sharding without duplication self.conv1d = MergedColumnParallelLinear(
# - When n_groups == 1: groups are duplicated across TP ranks input_size=conv_kernel_size,
# The custom weight loader handles both cases correctly. output_sizes=[
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
else:
# This is the n_groups == 1 case,
# where we need to duplicate groups if TP>1.
self.conv1d = ColumnParallelLinear( self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size, input_size=conv_kernel_size,
...@@ -322,18 +348,24 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -322,18 +348,24 @@ class MambaMixer2(MambaBase, PluggableLayer):
prefix=f"{prefix}.in_proj", prefix=f"{prefix}.in_proj",
) )
# Configure shard settings for the custom weight loader: # - because in_proj is a concatenation of 3 weights, we
# - group_shard_settings handles group duplication when n_groups == 1 # need to interleave them before sharding
# - When n_groups % tp_size == 0, extra=0 and duplicate_groups=False # - use the custom weight loader mamba_v2_sharded_weight_loader
# for conv1d.bias, conv1d.weight and in_proj.weight
# - need to set these settings, to assign the groups
# to the head shards
group_shard_settings = ( group_shard_settings = (
self.groups_ssm_state_size, # expected model size self.groups_ssm_state_size, # expected model size
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
n_groups == 1, # duplicate groups when n_groups == 1 n_groups == 1, # if there was only one group
) )
intermediate_settings = (intermediate_size, 0, False) intermediate_settings = (intermediate_size, 0, False)
head_settings = (self.num_heads, 0, False) head_settings = (self.num_heads, 0, False)
# Apply custom weight loaders for conv1d (bias and weight) # - the weight already has a "weight_loader" attribute
# which set_weight_attrs will raise if we do not
# delete before trying to override it
# - ditto for the other two weights below
delattr(self.conv1d.bias, "weight_loader") delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs( set_weight_attrs(
self.conv1d.bias, self.conv1d.bias,
...@@ -366,7 +398,8 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -366,7 +398,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
}, },
) )
# Create the custom weight loader for in_proj # Create the custom weight loader for Mamba sharding with group
# replication. This handles the interleaved projections correctly.
mamba_loader = mamba_v2_sharded_weight_loader( mamba_loader = mamba_v2_sharded_weight_loader(
[ [
intermediate_settings, # for gate intermediate_settings, # for gate
...@@ -384,6 +417,7 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -384,6 +417,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
# (ModelWeightParameter which extends BasevLLMParameter) # (ModelWeightParameter which extends BasevLLMParameter)
if isinstance(self.in_proj.weight, BasevLLMParameter): if isinstance(self.in_proj.weight, BasevLLMParameter):
# For BasevLLMParameter subclasses (quantized layers like FP8) # For BasevLLMParameter subclasses (quantized layers like FP8)
# These have a weight_loader property that can be directly set
self.in_proj.weight.weight_loader = mamba_loader self.in_proj.weight.weight_loader = mamba_loader
else: else:
# For standard Parameter (non-quantized layers) # For standard Parameter (non-quantized layers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment