Unverified commit a372f3f4, authored by Vadim Gimpelson, committed by GitHub
Browse files

[MISC] Fix Tensor Parallelism for Quantized Mamba Models with n_groups=1 (#33257)


Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
parent 61e632ae
...@@ -17,7 +17,6 @@ from vllm.forward_context import ForwardContext, get_forward_context ...@@ -17,7 +17,6 @@ from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
...@@ -40,6 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -40,6 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, composed_weight_loader,
sharded_weight_loader, sharded_weight_loader,
) )
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
...@@ -280,13 +280,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -280,13 +280,6 @@ class MambaMixer2(MambaBase, CustomOp):
"then num_groups must equal 1." "then num_groups must equal 1."
) )
assert (
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
), (
"Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)
self.ssm_state_size = ssm_state_size self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size self.conv_kernel_size = conv_kernel_size
self.activation = activation self.activation = activation
...@@ -308,121 +301,94 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -308,121 +301,94 @@ class MambaMixer2(MambaBase, CustomOp):
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
if n_groups % self.tp_size == 0: # Use ColumnParallelLinear with custom weight loaders for both cases:
self.conv1d = MergedColumnParallelLinear( # - When n_groups % tp_size == 0: standard sharding without duplication
input_size=conv_kernel_size, # - When n_groups == 1: groups are duplicated across TP ranks
output_sizes=[ # The custom weight loader handles both cases correctly.
intermediate_size,
self.groups_ssm_state_size, self.conv1d = ColumnParallelLinear(
self.groups_ssm_state_size, input_size=conv_kernel_size,
], output_size=self.conv_dim,
bias=use_conv_bias, bias=use_conv_bias,
quant_config=None, quant_config=None,
prefix=f"{prefix}.conv1d", prefix=f"{prefix}.conv1d",
) )
self.in_proj = MergedColumnParallelLinear( self.in_proj = ColumnParallelLinear(
input_size=hidden_size, input_size=hidden_size,
output_sizes=[ output_size=intermediate_size + self.conv_dim + self.num_heads,
intermediate_size, bias=use_bias,
intermediate_size, quant_config=quant_config,
self.groups_ssm_state_size, prefix=f"{prefix}.in_proj",
self.groups_ssm_state_size, )
self.num_heads,
],
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
else:
# This is the n_groups == 1 case,
# where we need to duplicate groups if TP>1.
self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = ColumnParallelLinear( # Configure shard settings for the custom weight loader:
input_size=hidden_size, # - group_shard_settings handles group duplication when n_groups == 1
output_size=intermediate_size + self.conv_dim + self.num_heads, # - When n_groups % tp_size == 0, extra=0 and duplicate_groups=False
bias=use_bias, group_shard_settings = (
quant_config=quant_config, self.groups_ssm_state_size, # expected model size
prefix=f"{prefix}.in_proj", (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
) n_groups == 1, # duplicate groups when n_groups == 1
)
intermediate_settings = (intermediate_size, 0, False)
head_settings = (self.num_heads, 0, False)
# Apply custom weight loaders for conv1d (bias and weight)
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)
# - because in_proj is a concatenation of 3 weights, we delattr(self.conv1d.weight, "weight_loader")
# need to interleave them before sharding set_weight_attrs(
# - use the custom weight loader mamba_v2_sharded_weight_loader self.conv1d.weight,
# for conv1d.bias, conv1d.weight and in_proj.weight {
# - need to set these settings, to assign the groups "weight_loader": mamba_v2_sharded_weight_loader(
# to the head shards [
group_shard_settings = ( intermediate_settings,
self.groups_ssm_state_size, # expected model size group_shard_settings,
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned group_shard_settings,
n_groups == 1, # if there was only one group ],
) self.tp_size,
intermediate_settings = (intermediate_size, 0, False) tp_rank,
head_settings = (self.num_heads, 0, False) )
},
# - the weight already has a "weight_loader" attribute )
# which set_weight_attrs will raise if we do not
# delete before trying to override it
# - ditto for the other two weights below
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)
delattr(self.conv1d.weight, "weight_loader") # Create the custom weight loader for in_proj
set_weight_attrs( mamba_loader = mamba_v2_sharded_weight_loader(
self.conv1d.weight, [
{ intermediate_settings, # for gate
"weight_loader": mamba_v2_sharded_weight_loader( intermediate_settings,
[ group_shard_settings,
intermediate_settings, group_shard_settings,
group_shard_settings, head_settings, # for dt
group_shard_settings, ],
], self.tp_size,
self.tp_size, tp_rank,
tp_rank, )
)
},
)
if quant_config is None: # Apply the custom weight loader to in_proj.weight
# - quant layers do not have a weight loader # Works for both non-quantized (Parameter) and quantized
delattr(self.in_proj.weight, "weight_loader") # (ModelWeightParameter which extends BasevLLMParameter)
set_weight_attrs( if isinstance(self.in_proj.weight, BasevLLMParameter):
self.in_proj.weight, # For BasevLLMParameter subclasses (quantized layers like FP8)
{ self.in_proj.weight.weight_loader = mamba_loader
"weight_loader": mamba_v2_sharded_weight_loader( else:
[ # For standard Parameter (non-quantized layers)
intermediate_settings, # for gate delattr(self.in_proj.weight, "weight_loader")
intermediate_settings, set_weight_attrs(self.in_proj.weight, {"weight_loader": mamba_loader})
group_shard_settings,
group_shard_settings,
head_settings, # for dt
],
self.tp_size,
tp_rank,
)
},
)
# unsqueeze to fit conv1d weights shape into the linear weights shape. # unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in # Can't do this in `weight_loader` since it already exists in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment