Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fca20841
Unverified
Commit
fca20841
authored
Feb 22, 2025
by
Yu Chin Fabian Lim
Committed by
GitHub
Feb 22, 2025
Browse files
Correction to TP logic for Mamba Mixer 2 when Num Groups not divisible by TP Size (#13660)
parent
da31b533
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
7 deletions
+21
-7
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+21
-7
No files found.
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
fca20841
...
@@ -133,7 +133,8 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
...
@@ -133,7 +133,8 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
if
ngroups
%
tp_size
==
0
:
if
ngroups
%
tp_size
==
0
:
return
0
return
0
return
tp_size
-
ngroups
%
tp_size
# for n_groups == 1, this is exactly tp_size - n_groups
return
tp_size
-
ngroups
def
mamba_v2_sharded_weight_loader
(
def
mamba_v2_sharded_weight_loader
(
...
@@ -153,7 +154,7 @@ def mamba_v2_sharded_weight_loader(
...
@@ -153,7 +154,7 @@ def mamba_v2_sharded_weight_loader(
boundary
,
loaded_boundary
=
0
,
0
boundary
,
loaded_boundary
=
0
,
0
# - iterate over the shard specs
# - iterate over the shard specs
for
full_dim
,
extra
,
ratio
in
shard_spec
:
for
full_dim
,
extra
,
duplicate_groups
in
shard_spec
:
# - full dim is the model dim (before TP).
# - full dim is the model dim (before TP).
# - extra > 0, means there is expected overall increase
# - extra > 0, means there is expected overall increase
# of dimensions. This is so because of replication.
# of dimensions. This is so because of replication.
...
@@ -167,7 +168,12 @@ def mamba_v2_sharded_weight_loader(
...
@@ -167,7 +168,12 @@ def mamba_v2_sharded_weight_loader(
# - compute the rank into the loaded shard.
# - compute the rank into the loaded shard.
# - if there is replication, different TP shards will
# - if there is replication, different TP shards will
# take from the same rank.
# take from the same rank.
rank
=
tp_rank
//
ratio
if
duplicate_groups
:
# NOTE: currently we only support duplication
# in the case where num_groups == 1
rank
=
0
else
:
rank
=
tp_rank
# - leftmost boundary index into loaded weight.
# - leftmost boundary index into loaded weight.
loaded_skip
=
rank
*
shard_size
loaded_skip
=
rank
*
shard_size
...
@@ -233,12 +239,21 @@ class MambaMixer2(CustomOp):
...
@@ -233,12 +239,21 @@ class MambaMixer2(CustomOp):
# - HOWEVER IF, world_size DOES NOT divide groups, then we need
# - HOWEVER IF, world_size DOES NOT divide groups, then we need
# to allocate extra space in the shard, such that groups
# to allocate extra space in the shard, such that groups
# may be replicated to follow the head shard.
# may be replicated to follow the head shard.
# - NOTE: currently, when world_size DOES NOT divide groups,
# we only support the case where n_groups == 1
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
assert
num_heads
%
self
.
tp_size
==
0
,
\
assert
num_heads
%
self
.
tp_size
==
0
,
\
"Tensor parallel world size must divide num heads."
"Tensor parallel world size must divide num heads."
assert
(
n_groups
%
self
.
tp_size
)
==
0
or
n_groups
==
1
,
\
(
"If tensor parallel world size does not divide num_groups, "
"then num_groups must equal 1."
)
self
.
ssm_state_size
=
ssm_state_size
self
.
ssm_state_size
=
ssm_state_size
self
.
activation
=
activation
self
.
activation
=
activation
...
@@ -284,11 +299,10 @@ class MambaMixer2(CustomOp):
...
@@ -284,11 +299,10 @@ class MambaMixer2(CustomOp):
self
.
n_groups
*
self
.
ssm_state_size
,
# expected model size
self
.
n_groups
*
self
.
ssm_state_size
,
# expected model size
(
self
.
n_groups
-
n_groups
)
*
(
self
.
n_groups
-
n_groups
)
*
self
.
ssm_state_size
,
# extra dims assigned
self
.
ssm_state_size
,
# extra dims assigned
self
.
num_heads
//
n_groups
==
1
,
# if there was only one group
n_groups
,
# ratio for mapping back to original group
)
)
intermediate_settings
=
(
intermediate_size
,
0
,
1
)
intermediate_settings
=
(
intermediate_size
,
0
,
False
)
head_setings
=
(
self
.
num_heads
,
0
,
1
)
head_setings
=
(
self
.
num_heads
,
0
,
False
)
# - the weight already has a "weight_loader" attribute
# - the weight already has a "weight_loader" attribute
# which set_weight_attrs will raise if we do not
# which set_weight_attrs will raise if we do not
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment