Commit 27fcfe7b (unverified), authored by tomeras91, committed by GitHub
Browse files

[Mamba] Support TP>1 with quantization for mamba2 mixer in case `n_groups % tp_size == 0` (#24593)


Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 68dbde5d
...@@ -19,6 +19,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -19,6 +19,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
...@@ -261,12 +262,14 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -261,12 +262,14 @@ class MambaMixer2(MambaBase, CustomOp):
), "Tensor parallel world size must divide num heads." ), "Tensor parallel world size must divide num heads."
assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
"If tensor parallel world size does not divide num_heads, " "If tensor parallel world size does not divide num_groups, "
"then num_groups must equal 1.") "then num_groups must equal 1.")
assert ( assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \
self.tp_size == 1 or quant_config is None quant_config is None, (
), "Tensor parallel currently not supported for quantized models." "Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)
self.ssm_state_size = ssm_state_size self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size self.conv_kernel_size = conv_kernel_size
...@@ -285,94 +288,101 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -285,94 +288,101 @@ class MambaMixer2(MambaBase, CustomOp):
n_groups, self.tp_size) n_groups, self.tp_size)
self.n_groups = n_groups + groups self.n_groups = n_groups + groups
self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
self.conv1d = ColumnParallelLinear( self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
input_size=conv_kernel_size,
output_size=self.conv_dim, if n_groups % self.tp_size == 0:
bias=use_conv_bias, self.conv1d = MergedColumnParallelLinear(
quant_config=None, input_size=conv_kernel_size,
prefix=f"{prefix}.conv1d", output_sizes=[
) intermediate_size,
# unsqueeze to fit conv1d weights shape into the linear weights shape. self.groups_ssm_state_size,
# Can't do this in `weight_loader` since it already exists in self.groups_ssm_state_size,
# `ColumnParallelLinear` and `set_weight_attrs` ],
# doesn't allow to override it bias=use_conv_bias,
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = ColumnParallelLinear( self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size, input_size=hidden_size,
output_size=intermediate_size + self.conv_dim + self.num_heads, output_sizes=[
bias=use_bias, intermediate_size,
quant_config=quant_config, intermediate_size,
prefix=f"{prefix}.in_proj", self.groups_ssm_state_size,
) self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
else:
# This is the n_groups == 1 case,
# where we need to duplicate groups if TP>1.
self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
# - because in_proj is a concatenation of 3 weights, we self.in_proj = ColumnParallelLinear(
# need to interleave them before sharding input_size=hidden_size,
# - use the custom weight loader mamba_v2_sharded_weight_loader output_size=intermediate_size + self.conv_dim + self.num_heads,
# for conv1d.bias, covn1d.weight and in_proj.weight bias=use_bias,
# - need to set these settings, to assign the groups to the head shards quant_config=quant_config,
group_shard_settings = ( prefix=f"{prefix}.in_proj",
self.n_groups * self.ssm_state_size, # expected model size )
(self.n_groups - n_groups) *
self.ssm_state_size, # extra dims assigned
n_groups == 1, # if there was only one group
)
intermediate_settings = (intermediate_size, 0, False)
head_settings = (self.num_heads, 0, False)
# - the weight already has a "weight_loader" attribute
# which set_weight_attrs will raise if we do not
# delete before trying to override it
# - ditto for the other two weights below
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)
delattr(self.conv1d.weight, "weight_loader") # - because in_proj is a concatenation of 3 weights, we
set_weight_attrs( # need to interleave them before sharding
self.conv1d.weight, # - use the custom weight loader mamba_v2_sharded_weight_loader
{ # for conv1d.bias, covn1d.weight and in_proj.weight
"weight_loader": # - need to set these settings, to assign the groups
mamba_v2_sharded_weight_loader( # to the head shards
[ group_shard_settings = (
intermediate_settings, self.groups_ssm_state_size, # expected model size
group_shard_settings, (self.n_groups - n_groups) *
group_shard_settings, self.ssm_state_size, # extra dims assigned
], n_groups == 1, # if there was only one group
self.tp_size, )
tp_rank, intermediate_settings = (intermediate_size, 0, False)
) head_settings = (self.num_heads, 0, False)
},
) # - the weight already has a "weight_loader" attribute
# which set_weight_attrs will raise if we do not
# delete before trying to override it
# - ditto for the other two weights below
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)
if quant_config is None: delattr(self.conv1d.weight, "weight_loader")
# - quant layers do not have a weight loader
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs( set_weight_attrs(
self.in_proj.weight, self.conv1d.weight,
{ {
"weight_loader": "weight_loader":
mamba_v2_sharded_weight_loader( mamba_v2_sharded_weight_loader(
[ [
intermediate_settings, # for gate
intermediate_settings, intermediate_settings,
group_shard_settings, group_shard_settings,
group_shard_settings, group_shard_settings,
head_settings, # for dt
], ],
self.tp_size, self.tp_size,
tp_rank, tp_rank,
...@@ -380,6 +390,33 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -380,6 +390,33 @@ class MambaMixer2(MambaBase, CustomOp):
}, },
) )
if quant_config is None:
# - quant layers do not have a weight loader
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(
self.in_proj.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
group_shard_settings,
group_shard_settings,
head_settings, # for dt
],
self.tp_size,
tp_rank,
)
},
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `MergedColumnParallelLinear`,
# and `set_weight_attrs` doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
# - these are TPed by heads to reduce the size of the # - these are TPed by heads to reduce the size of the
# temporal shape # temporal shape
self.A = nn.Parameter( self.A = nn.Parameter(
...@@ -498,8 +535,6 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -498,8 +535,6 @@ class MambaMixer2(MambaBase, CustomOp):
chunk_indices_p = mamba2_metadata.chunk_indices chunk_indices_p = mamba2_metadata.chunk_indices
chunk_offsets_p = mamba2_metadata.chunk_offsets chunk_offsets_p = mamba2_metadata.chunk_offsets
groups_time_state_size = self.n_groups * self.ssm_state_size
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states) projected_states, _ = self.in_proj(hidden_states)
...@@ -524,8 +559,8 @@ class MambaMixer2(MambaBase, CustomOp): ...@@ -524,8 +559,8 @@ class MambaMixer2(MambaBase, CustomOp):
hidden_states_B_C, hidden_states_B_C,
[ [
self.intermediate_size // self.tp_size, self.intermediate_size // self.tp_size,
groups_time_state_size // self.tp_size, self.groups_ssm_state_size // self.tp_size,
groups_time_state_size // self.tp_size, self.groups_ssm_state_size // self.tp_size,
], ],
dim=-1, dim=-1,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment