Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6cd43ae5
Commit
6cd43ae5
authored
Jan 23, 2026
by
zhuwenwen
Browse files
fix qwen3-next nn layout
parent
88411543
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
10 deletions
+55
-10
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+48
-8
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+7
-2
No files found.
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
6cd43ae5
...
@@ -36,6 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -36,6 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionMetadata
from
vllm.v1.attention.backends.mamba2_attn
import
Mamba2AttentionMetadata
import
vllm.envs
as
envs
# Added by the IBM Team, 2024
# Added by the IBM Team, 2024
...
@@ -160,6 +161,15 @@ def mamba_v2_sharded_weight_loader(
...
@@ -160,6 +161,15 @@ def mamba_v2_sharded_weight_loader(
# - track boundary of (sharded) param, and loaded_weight, respectively
# - track boundary of (sharded) param, and loaded_weight, respectively
boundary
,
loaded_boundary
=
0
,
0
boundary
,
loaded_boundary
=
0
,
0
if
envs
.
VLLM_USE_NN
:
loaded_total_dim
=
sum
(
full_dim
-
extra
for
full_dim
,
extra
,
_
in
shard_spec
)
param_out_axis
=
0
if
param
.
dim
()
==
1
else
(
param
.
dim
()
-
1
)
loaded_out_axis
=
0
if
(
loaded_weight
.
dim
()
>
1
and
loaded_weight
.
shape
[
-
1
]
==
loaded_total_dim
and
loaded_weight
.
shape
[
0
]
!=
loaded_total_dim
):
loaded_out_axis
=
loaded_weight
.
dim
()
-
1
# - iterate over the shard specs
# - iterate over the shard specs
for
full_dim
,
extra
,
duplicate_groups
in
shard_spec
:
for
full_dim
,
extra
,
duplicate_groups
in
shard_spec
:
# - full dim is the model dim (before TP).
# - full dim is the model dim (before TP).
...
@@ -190,12 +200,38 @@ def mamba_v2_sharded_weight_loader(
...
@@ -190,12 +200,38 @@ def mamba_v2_sharded_weight_loader(
# - the ignore is for a mundane mypy error as it does not
# - the ignore is for a mundane mypy error as it does not
# seem to handle slices well.
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
# https://github.com/python/mypy/issues/2410
param
.
data
[
if
envs
.
VLLM_USE_NN
:
boundary
:(
boundary
+
take
),
if
take
>
0
:
...
# type: ignore[misc]
param_slice
=
param
.
data
.
narrow
(
param_out_axis
,
boundary
,
take
)
]
=
loaded_weight
[
loaded_start_idx
:(
loaded_start_idx
+
loaded_slice
=
loaded_weight
.
narrow
(
loaded_out_axis
,
take
)
# type: ignore[misc]
loaded_start_idx
,
take
)
]
# type: ignore[misc]
if
(
param_slice
.
dim
()
==
loaded_slice
.
dim
()
+
1
and
param_slice
.
shape
[
1
]
==
1
):
loaded_slice
=
loaded_slice
.
unsqueeze
(
1
)
elif
(
loaded_slice
.
dim
()
==
param_slice
.
dim
()
+
1
and
loaded_slice
.
shape
[
1
]
==
1
):
loaded_slice
=
loaded_slice
.
squeeze
(
1
)
if
param_slice
.
shape
!=
loaded_slice
.
shape
:
loaded_slice
=
loaded_slice
.
permute
(
*
reversed
(
range
(
loaded_slice
.
dim
())))
if
param_slice
.
shape
!=
loaded_slice
.
shape
:
raise
RuntimeError
(
"mamba_v2_sharded_weight_loader shape mismatch: "
f
"param_slice=
{
tuple
(
param_slice
.
shape
)
}
"
f
"loaded_slice=
{
tuple
(
loaded_slice
.
shape
)
}
"
f
"(param_out_axis=
{
param_out_axis
}
, "
f
"loaded_out_axis=
{
loaded_out_axis
}
)"
)
param_slice
.
copy_
(
loaded_slice
)
else
:
param
.
data
[
boundary
:(
boundary
+
take
),
...
# type: ignore[misc]
]
=
loaded_weight
[
loaded_start_idx
:(
loaded_start_idx
+
take
)
# type: ignore[misc]
]
# type: ignore[misc]
# move indexing boundaries
# move indexing boundaries
boundary
+=
shard_size
boundary
+=
shard_size
...
@@ -522,8 +558,12 @@ class MambaMixer2(MambaBase, CustomOp):
...
@@ -522,8 +558,12 @@ class MambaMixer2(MambaBase, CustomOp):
dim
=-
1
,
dim
=-
1
,
)
)
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
if
envs
.
VLLM_USE_NN
:
self
.
conv1d
.
weight
.
size
(
2
))
conv_weights
=
self
.
conv1d
.
weight
.
squeeze
(
1
).
transpose
(
0
,
1
).
contiguous
()
else
:
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
# - get hidden_states, B and C after depthwise convolution.
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn
=
lambda
hidden_states_B_C
:
torch
.
split
(
split_hidden_states_B_C_fn
=
lambda
hidden_states_B_C
:
torch
.
split
(
...
...
vllm/model_executor/models/qwen3_next.py
View file @
6cd43ae5
...
@@ -63,6 +63,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
...
@@ -63,6 +63,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
import
vllm.envs
as
envs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -432,8 +433,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -432,8 +433,12 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
mixed_qkv
=
torch
.
cat
((
query
,
key
,
value
),
dim
=-
1
)
mixed_qkv
=
torch
.
cat
((
query
,
key
,
value
),
dim
=-
1
)
# 2. Convolution sequence transformation
# 2. Convolution sequence transformation
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
if
envs
.
VLLM_USE_NN
:
self
.
conv1d
.
weight
.
size
(
2
))
conv_weights
=
self
.
conv1d
.
weight
.
squeeze
(
1
).
transpose
(
0
,
1
).
contiguous
()
else
:
conv_weights
=
self
.
conv1d
.
weight
.
view
(
self
.
conv1d
.
weight
.
size
(
0
),
self
.
conv1d
.
weight
.
size
(
2
))
if
spec_sequence_masks
is
not
None
:
if
spec_sequence_masks
is
not
None
:
if
(
attn_metadata
.
num_prefills
==
0
if
(
attn_metadata
.
num_prefills
==
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment