Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef7eefe1
Unverified
Commit
ef7eefe1
authored
Sep 18, 2025
by
Tao He
Committed by
GitHub
Sep 18, 2025
Browse files
[Qwen] Add fp8 checkpoint support for qwen3-next. (#25079)
Signed-off-by:
Tao He
<
linzhu.ht@alibaba-inc.com
>
parent
350c94de
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
21 deletions
+22
-21
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+17
-18
vllm/model_executor/models/qwen3_next_mtp.py
vllm/model_executor/models/qwen3_next_mtp.py
+5
-3
No files found.
vllm/model_executor/models/qwen3_next.py
View file @
ef7eefe1
...
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.layernorm import (
...
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm
as
Qwen3NextRMSNorm
)
GemmaRMSNorm
as
Qwen3NextRMSNorm
)
# yapf: enable
# yapf: enable
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
...
@@ -254,12 +253,20 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -254,12 +253,20 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# projection of the input hidden states
# projection of the input hidden states
self
.
projection_size_qkvz
=
self
.
key_dim
*
2
+
self
.
value_dim
*
2
self
.
projection_size_qkvz
=
self
.
key_dim
*
2
+
self
.
value_dim
*
2
self
.
projection_size_ba
=
self
.
num_v_heads
*
2
self
.
projection_size_ba
=
self
.
num_v_heads
*
2
self
.
in_proj
=
Merged
ColumnParallelLinear
(
self
.
in_proj
_qkvz
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
input_size
=
self
.
hidden_size
,
output_size
s
=
[
self
.
projection_size_qkvz
,
self
.
projection_size_ba
],
output_size
=
self
.
projection_size_qkvz
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj"
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
# ba_proj doesn't support blockwise fp8 quantization.
self
.
in_proj_ba
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
projection_size_ba
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
)
)
query_key_settings
=
(
self
.
key_dim
,
0
,
False
)
query_key_settings
=
(
self
.
key_dim
,
0
,
False
)
...
@@ -420,19 +427,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -420,19 +427,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
ssm_state
=
self_kv_cache
[
1
]
ssm_state
=
self_kv_cache
[
1
]
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_accepted_tokens
=
attn_metadata
.
num_accepted_tokens
num_accepted_tokens
=
attn_metadata
.
num_accepted_tokens
# 1. Set up dimensions for reshapes later
projected_states
,
_
=
self
.
in_proj
(
hidden_states
[:
num_actual_tokens
])
if
spec_token_masks
is
not
None
:
if
spec_token_masks
is
not
None
:
spec_token_masks
=
spec_token_masks
[:
num_actual_tokens
]
spec_token_masks
=
spec_token_masks
[:
num_actual_tokens
]
projected_states_qkvz
,
projected_states_ba
=
torch
.
split
(
projected_states
,
# 1. Set up dimensions for reshapes later
[
projected_states_qkvz
,
_
=
self
.
in_proj_qkvz
(
self
.
projection_size_qkvz
//
self
.
tp_size
,
hidden_states
[:
num_actual_tokens
])
self
.
projection_size_ba
//
self
.
tp_size
projected_states_ba
,
_
=
self
.
in_proj_ba
(
],
hidden_states
[:
num_actual_tokens
])
dim
=-
1
,
)
query
,
key
,
value
,
z
,
b
,
a
=
self
.
fix_query_key_value_ordering
(
query
,
key
,
value
,
z
,
b
,
a
=
self
.
fix_query_key_value_ordering
(
projected_states_qkvz
,
projected_states_ba
)
projected_states_qkvz
,
projected_states_ba
)
query
,
key
,
value
=
map
(
lambda
x
:
rearrange
(
x
,
'l p d -> l (p d)'
),
query
,
key
,
value
=
map
(
lambda
x
:
rearrange
(
x
,
'l p d -> l (p d)'
),
...
@@ -976,8 +978,6 @@ class Qwen3NextModel(nn.Module):
...
@@ -976,8 +978,6 @@ class Qwen3NextModel(nn.Module):
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"in_proj"
,
"in_proj_qkvz"
,
0
),
(
"in_proj"
,
"in_proj_ba"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
...
@@ -1055,7 +1055,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
...
@@ -1055,7 +1055,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
"v_proj"
,
"v_proj"
,
],
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
"in_proj"
:
[
"in_proj_qkvz"
,
"in_proj_ba"
],
}
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/qwen3_next_mtp.py
View file @
ef7eefe1
...
@@ -63,7 +63,9 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
...
@@ -63,7 +63,9 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
self
.
config
.
hidden_size
,
self
.
config
.
hidden_size
,
gather_output
=
True
,
gather_output
=
True
,
bias
=
False
,
bias
=
False
,
return_bias
=
False
)
return_bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
'
{
prefix
}
.fc'
)
self
.
layers
=
torch
.
nn
.
ModuleList
(
self
.
layers
=
torch
.
nn
.
ModuleList
(
Qwen3NextDecoderLayer
(
Qwen3NextDecoderLayer
(
...
@@ -72,7 +74,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
...
@@ -72,7 +74,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
model_config
=
model_config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
'
{
prefix
}
.layers.
{
self
.
mtp_start_layer_idx
+
idx
}
'
,
prefix
=
f
'
{
prefix
}
.layers.
{
idx
}
'
,
)
for
idx
in
range
(
self
.
num_mtp_layers
))
)
for
idx
in
range
(
self
.
num_mtp_layers
))
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
...
@@ -233,7 +235,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
...
@@ -233,7 +235,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
self
.
config
=
config
self
.
config
=
config
self
.
model
=
Qwen3NextMultiTokenPredictor
(
vllm_config
=
vllm_config
,
self
.
model
=
Qwen3NextMultiTokenPredictor
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
=
maybe_prefix
(
prefix
,
"m
odel
"
))
prefix
,
"m
tp
"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment