Unverified Commit a372f3f4 authored by Vadim Gimpelson's avatar Vadim Gimpelson Committed by GitHub
Browse files

[MISC] Fix Tensor Parallelism for Quantized Mamba Models with n_groups=1 (#33257)


Signed-off-by: default avatarVadim Gimpelson <vadim.gimpelson@gmail.com>
parent 61e632ae
...@@ -17,7 +17,6 @@ from vllm.forward_context import ForwardContext, get_forward_context ...@@ -17,7 +17,6 @@ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
...@@ -40,6 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -40,6 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, composed_weight_loader,
sharded_weight_loader, sharded_weight_loader,
) )
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
...@@ -280,13 +280,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -280,13 +280,6 @@ class MambaMixer2(MambaBase, CustomOp):
"then num_groups must equal 1." "then num_groups must equal 1."
) )
assert (
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
), (
"Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)
self.ssm_state_size = ssm_state_size self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size self.conv_kernel_size = conv_kernel_size
self.activation = activation self.activation = activation
...@@ -308,35 +301,10 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -308,35 +301,10 @@ class MambaMixer2(MambaBase, CustomOp):
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
if n_groups % self.tp_size == 0: # Use ColumnParallelLinear with custom weight loaders for both cases:
self.conv1d = MergedColumnParallelLinear( # - When n_groups % tp_size == 0: standard sharding without duplication
input_size=conv_kernel_size, # - When n_groups == 1: groups are duplicated across TP ranks
output_sizes=[ # The custom weight loader handles both cases correctly.
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
else:
# This is the n_groups == 1 case,
# where we need to duplicate groups if TP>1.
self.conv1d = ColumnParallelLinear( self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size, input_size=conv_kernel_size,
...@@ -354,24 +322,18 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -354,24 +322,18 @@ class MambaMixer2(MambaBase, CustomOp):
prefix=f"{prefix}.in_proj", prefix=f"{prefix}.in_proj",
) )
# - because in_proj is a concatenation of 3 weights, we # Configure shard settings for the custom weight loader:
# need to interleave them before sharding # - group_shard_settings handles group duplication when n_groups == 1
# - use the custom weight loader mamba_v2_sharded_weight_loader # - When n_groups % tp_size == 0, extra=0 and duplicate_groups=False
# for conv1d.bias, covn1d.weight and in_proj.weight
# - need to set these settings, to assign the groups
# to the head shards
group_shard_settings = ( group_shard_settings = (
self.groups_ssm_state_size, # expected model size self.groups_ssm_state_size, # expected model size
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
n_groups == 1, # if there was only one group n_groups == 1, # duplicate groups when n_groups == 1
) )
intermediate_settings = (intermediate_size, 0, False) intermediate_settings = (intermediate_size, 0, False)
head_settings = (self.num_heads, 0, False) head_settings = (self.num_heads, 0, False)
# - the weight already has a "weight_loader" attribute # Apply custom weight loaders for conv1d (bias and weight)
# which set_weight_attrs will raise if we do not
# delete before trying to override it
# - ditto for the other two weights below
delattr(self.conv1d.bias, "weight_loader") delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs( set_weight_attrs(
self.conv1d.bias, self.conv1d.bias,
...@@ -404,13 +366,8 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -404,13 +366,8 @@ class MambaMixer2(MambaBase, CustomOp):
}, },
) )
if quant_config is None: # Create the custom weight loader for in_proj
# - quant layers do not have a weight loader mamba_loader = mamba_v2_sharded_weight_loader(
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(
self.in_proj.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[ [
intermediate_settings, # for gate intermediate_settings, # for gate
intermediate_settings, intermediate_settings,
...@@ -421,8 +378,17 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -421,8 +378,17 @@ class MambaMixer2(MambaBase, CustomOp):
self.tp_size, self.tp_size,
tp_rank, tp_rank,
) )
},
) # Apply the custom weight loader to in_proj.weight
# Works for both non-quantized (Parameter) and quantized
# (ModelWeightParameter which extends BasevLLMParameter)
if isinstance(self.in_proj.weight, BasevLLMParameter):
# For BasevLLMParameter subclasses (quantized layers like FP8)
self.in_proj.weight.weight_loader = mamba_loader
else:
# For standard Parameter (non-quantized layers)
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(self.in_proj.weight, {"weight_loader": mamba_loader})
# unsqueeze to fit conv1d weights shape into the linear weights shape. # unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in # Can't do this in `weight_loader` since it already exists in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment