Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f3505904
Commit
f3505904
authored
Mar 27, 2026
by
yuanyuan
Browse files
porting some qwen3.5 fp8 block quant bugfix
parent
f28b6574
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
139 additions
and
207 deletions
+139
-207
.gitignore
.gitignore
+2
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+46
-14
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+62
-176
vllm/model_executor/models/qwen3_5_mtp.py
vllm/model_executor/models/qwen3_5_mtp.py
+2
-4
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+27
-13
No files found.
.gitignore
View file @
f3505904
...
...
@@ -238,3 +238,5 @@ ep_kernels_workspace/
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi
vllm/version.py
\ No newline at end of file
vllm/model_executor/layers/linear.py
View file @
f3505904
...
...
@@ -748,8 +748,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
int
|
None
=
None
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
):
if
isinstance
(
loaded_shard_id
,
tuple
):
raise
NotImplementedError
(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
...
@@ -781,7 +786,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
or
isinstance
(
loaded_shard_id
,
tuple
)
:
# Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj).
if
output_dim
is
None
:
...
...
@@ -793,10 +798,20 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
return
output_sizes
=
(
self
.
output_sizes
[
loaded_shard_id
[
0
]
:
loaded_shard_id
[
-
1
]
+
1
]
if
loaded_shard_id
is
not
None
else
self
.
output_sizes
)
current_shard_offset
=
0
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
if
use_bitsandbytes_4bit
and
isinstance
(
loaded_shard_id
,
tuple
):
raise
NotImplementedError
(
"Shard id with multiple indices is not supported "
"for BNB quantization yet."
)
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
for
i
,
output_size
in
enumerate
(
output_sizes
):
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
current_shard_offset
+=
output_size
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
...
...
@@ -838,15 +853,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
shard_offset
//=
self
.
tp_size
shard_size
//=
self
.
tp_size
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
)
shard_offset
//=
self
.
tp_size
shard_size
//=
self
.
tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
...
...
@@ -901,7 +916,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
def
_load_fused_module_from_checkpoint
(
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
output_sizes
:
list
[
int
]
|
None
=
None
,
):
"""
Handle special case for models where MLP layers are already
...
...
@@ -915,7 +933,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset
=
0
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
output_sizes
=
output_sizes
or
self
.
output_sizes
for
i
,
output_size
in
enumerate
(
output_sizes
):
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
current_shard_offset
+=
output_size
...
...
@@ -940,17 +959,30 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
int
|
None
=
None
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
):
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
or
isinstance
(
loaded_shard_id
,
tuple
)
:
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
return
elif
type
(
param
)
in
(
RowvLLMParameter
,
BasevLLMParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
)
return
output_sizes
=
(
[
self
.
output_sizes
[
idx
]
for
idx
in
loaded_shard_id
]
if
loaded_shard_id
else
None
)
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
output_sizes
=
[
adjust_block_scale_shard
(
weight_block_size
,
size
,
0
)[
0
]
for
size
in
(
output_sizes
or
self
.
output_sizes
)
]
# TODO: @dsikka - move to parameter.py
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
)
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
,
output_sizes
=
output_sizes
)
return
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
...
...
@@ -958,15 +990,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
shard_offset
//=
self
.
tp_size
shard_size
//=
self
.
tp_size
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
)
shard_offset
//=
self
.
tp_size
shard_size
//=
self
.
tp_size
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
...
...
vllm/model_executor/models/qwen3_5.py
View file @
f3505904
...
...
@@ -30,44 +30,20 @@ from collections.abc import Callable, Iterable
import
torch
from
einops
import
rearrange
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
transformers.models.qwen3_5.configuration_qwen3_5
import
(
Qwen3_5Config
,
Qwen3_5TextConfig
,
)
from
transformers.models.qwen3_5_moe.configuration_qwen3_5_moe
import
(
Qwen3_5MoeConfig
,
Qwen3_5MoeTextConfig
,
)
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SpeculativeConfig
,
VllmConfig
,
get_current_vllm_config
,
)
from
vllm.distributed
import
(
divide
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
(
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.linear
import
MergedColumnParallelLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
mamba_v2_sharded_weight_loader
,
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateCopyFunc
,
MambaStateCopyFuncCalculator
,
...
...
@@ -81,12 +57,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
sharded_weight_loader
,
maybe_remap_kv_scale_name
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
transformers.models.qwen3_5.configuration_qwen3_5
import
(
Qwen3_5Config
,
Qwen3_5TextConfig
,
)
from
transformers.models.qwen3_5_moe.configuration_qwen3_5_moe
import
(
Qwen3_5MoeConfig
,
Qwen3_5MoeTextConfig
,
)
from
.interfaces
import
(
HasInnerState
,
...
...
@@ -138,151 +120,29 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
class
Qwen3_5GatedDeltaNet
(
Qwen3NextGatedDeltaNet
):
def
__init__
(
self
,
config
:
Qwen3_5TextConfig
|
Qwen3_5MoeTextConfig
,
model_config
:
ModelConfig
|
None
=
None
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
speculative_config
:
SpeculativeConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
(
Qwen3NextGatedDeltaNet
,
self
).
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_v_heads
=
config
.
linear_num_value_heads
self
.
num_k_heads
=
config
.
linear_num_key_heads
self
.
head_k_dim
=
config
.
linear_key_head_dim
self
.
head_v_dim
=
config
.
linear_value_head_dim
self
.
key_dim
=
self
.
head_k_dim
*
self
.
num_k_heads
self
.
value_dim
=
self
.
head_v_dim
*
self
.
num_v_heads
self
.
conv_kernel_size
=
config
.
linear_conv_kernel_dim
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
activation
=
config
.
hidden_act
self
.
act
=
ACT2FN
[
config
.
hidden_act
]
self
.
layer_norm_epsilon
=
config
.
rms_norm_eps
self
.
prefix
=
prefix
self
.
config
=
config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
speculative_config
=
speculative_config
self
.
num_spec
=
(
self
.
speculative_config
.
num_speculative_tokens
if
self
.
speculative_config
else
0
)
# QKV
self
.
conv_dim
=
self
.
key_dim
*
2
+
self
.
value_dim
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
self
.
conv_kernel_size
,
output_size
=
self
.
conv_dim
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.conv1d"
,
)
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj_qkv
=
MergedColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_sizes
=
[
self
.
key_dim
,
self
.
key_dim
,
self
.
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkv"
,
)
self
.
in_proj_z
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
value_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_z"
,
)
self
.
in_proj_b
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
)
self
.
in_proj_a
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_a"
,
)
query_key_settings
=
(
self
.
key_dim
,
0
,
False
)
value_settings
=
(
self
.
value_dim
,
0
,
False
)
delattr
(
self
.
conv1d
.
weight
,
"weight_loader"
)
set_weight_attrs
(
self
.
conv1d
.
weight
,
{
"weight_loader"
:
mamba_v2_sharded_weight_loader
(
[
query_key_settings
,
query_key_settings
,
value_settings
,
],
self
.
tp_size
,
self
.
tp_rank
,
)
},
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self
.
dt_bias
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_v_heads
//
self
.
tp_size
),
)
self
.
A_log
=
nn
.
Parameter
(
torch
.
empty
(
divide
(
self
.
num_v_heads
,
self
.
tp_size
),
)
)
set_weight_attrs
(
self
.
A_log
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
set_weight_attrs
(
self
.
dt_bias
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
self
.
norm
=
RMSNormGated
(
self
.
head_v_dim
,
eps
=
self
.
layer_norm_epsilon
,
group_size
=
None
,
norm_before_gate
=
True
,
device
=
current_platform
.
current_device
(),
dtype
=
config
.
dtype
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
value_dim
,
self
.
hidden_size
,
bias
=
False
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
fix_query_key_value_ordering
(
self
,
mixed_qkv
,
z
,
b
,
a
,
mixed_qkvz
:
torch
.
Tensor
,
mixed_ba
:
torch
.
Tensor
,
):
raise
NotImplementedError
(
"Qwen3.5 Series dont need to fix query key value ordering"
)
def
create_qkvz_proj
(
self
,
hidden_size
:
int
,
key_dim
:
int
,
value_dim
:
int
,
quant_config
:
QuantizationConfig
|
None
,
prefix
:
str
,
)
->
MergedColumnParallelLinear
:
return
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
key_dim
,
key_dim
,
value_dim
,
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
prefix
,
)
def
forward
(
self
,
...
...
@@ -300,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================
# Part 1: Input Projection
# ============================================================
mixed_qkv
,
_
=
self
.
in_proj_qkv
(
hidden_states
)
z
,
_
=
self
.
in_proj_z
(
hidden_states
)
mixed_qkvz
,
_
=
self
.
in_proj_qkvz
(
hidden_states
)
qkv_size
=
(
self
.
key_dim
*
2
+
self
.
value_dim
)
//
self
.
tp_size
z_size
=
self
.
value_dim
//
self
.
tp_size
mixed_qkv
,
z
=
mixed_qkvz
.
split
([
qkv_size
,
z_size
],
dim
=-
1
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
b
,
_
=
self
.
in_proj_b
(
hidden_states
)
a
,
_
=
self
.
in_proj_a
(
hidden_states
)
b
a
,
_
=
self
.
in_proj_b
a
(
hidden_states
)
b
,
a
=
ba
.
chunk
(
2
,
dim
=-
1
)
b
=
b
.
contiguous
()
a
=
a
.
contiguous
()
...
...
@@ -411,7 +273,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1
,
1
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
)
self
.
ffn_layer_scale
=
torch
.
nn
.
Parameter
(
...
...
@@ -419,7 +280,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1
,
1
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
)
...
...
@@ -503,11 +363,18 @@ class Qwen3_5Model(Qwen3NextModel):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# self attention
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
# mlp
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
# GDN
(
"in_proj_qkvz"
,
"in_proj_qkv"
,
(
0
,
1
,
2
)),
(
"in_proj_qkvz"
,
"in_proj_z"
,
3
),
(
"in_proj_ba"
,
"in_proj_b"
,
0
),
(
"in_proj_ba"
,
"in_proj_a"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
...
...
@@ -528,6 +395,12 @@ class Qwen3_5Model(Qwen3NextModel):
if
name
.
startswith
(
"mtp."
):
continue
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"scale"
):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
"experts.gate_up_proj"
in
name
or
"experts.down_proj"
in
name
:
is_fused_expert
=
True
...
...
@@ -654,6 +527,9 @@ class Qwen3_5ForCausalLMBase(
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
# GDN fused projections.
"in_proj_qkvz"
:
[
"in_proj_qkv"
,
"in_proj_z"
],
"in_proj_ba"
:
[
"in_proj_b"
,
"in_proj_a"
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -751,7 +627,15 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
dummy_inputs
=
Qwen3VLDummyInputsBuilder
,
)
class
Qwen3_5ForConditionalGeneration
(
Qwen3VLForConditionalGeneration
,
IsHybrid
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
# Qwen3.5 does not support multimodal pruning (EVS).
supports_multimodal_pruning
=
False
packed_modules_mapping
=
Qwen3VLForConditionalGeneration
.
packed_modules_mapping
|
{
"in_proj_qkvz"
:
[
"in_proj_qkv"
,
"in_proj_z"
],
"in_proj_ba"
:
[
"in_proj_b"
,
"in_proj_a"
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
"model"
):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn
.
Module
.
__init__
(
self
)
config
:
Qwen3_5Config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -868,7 +752,9 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
vllm_config
:
"VllmConfig"
,
)
->
tuple
[
torch
.
dtype
,
torch
.
dtype
]:
return
MambaStateDtypeCalculator
.
gated_delta_net_state_dtype
(
vllm_config
.
model_config
.
dtype
,
vllm_config
.
cache_config
.
mamba_cache_dtype
vllm_config
.
model_config
.
dtype
,
vllm_config
.
cache_config
.
mamba_cache_dtype
,
vllm_config
.
cache_config
.
mamba_ssm_cache_dtype
,
)
@
classmethod
...
...
@@ -957,7 +843,7 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
class
Qwen3_5MoeForConditionalGeneration
(
Qwen3_5ForConditionalGeneration
,
Qwen3_5_MoeMixtureOfExperts
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
"
model
"
):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn
.
Module
.
__init__
(
self
)
config
:
Qwen3_5MoeConfig
=
vllm_config
.
model_config
.
hf_config
...
...
vllm/model_executor/models/qwen3_5_mtp.py
View file @
f3505904
...
...
@@ -7,10 +7,6 @@ from collections.abc import Callable, Iterable
import
torch
from
torch
import
nn
from
transformers.models.qwen3_5.configuration_qwen3_5
import
Qwen3_5TextConfig
from
transformers.models.qwen3_5_moe.configuration_qwen3_5_moe
import
(
Qwen3_5MoeTextConfig
,
)
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
...
...
@@ -27,6 +23,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.qwen3_5
import
Qwen3_5DecoderLayer
,
Qwen3_5RMSNorm
from
vllm.model_executor.models.qwen3_next
import
QwenNextMixtureOfExperts
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.qwen3_5
import
Qwen3_5TextConfig
from
vllm.transformers_utils.configs.qwen3_5_moe
import
Qwen3_5MoeTextConfig
from
.interfaces
import
(
MultiModalEmbeddings
,
...
...
vllm/model_executor/models/qwen3_next.py
View file @
f3505904
...
...
@@ -40,6 +40,7 @@ from vllm.model_executor.layers.layernorm import (
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
...
...
@@ -292,19 +293,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
# projection of the input hidden states
self
.
projection_size_qkvz
=
self
.
key_dim
*
2
+
self
.
value_dim
*
2
self
.
projection_size_ba
=
self
.
num_v_heads
*
2
self
.
in_proj_qkvz
=
ColumnParallelLinear
(
input
_size
=
self
.
hidden_size
,
output_size
=
self
.
projection_size_qkvz
,
bias
=
False
,
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
# we need to create qkvz_proj adaptively here.
self
.
in_proj_qkvz
=
self
.
create_qkvz_proj
(
hidden
_size
=
self
.
hidden_size
,
key_dim
=
self
.
key_dim
,
value_dim
=
self
.
value_dim
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
# ba_proj doesn't support blockwise fp8 quantization.
self
.
in_proj_ba
=
ColumnParallelLinear
(
self
.
in_proj_ba
=
Merged
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
projection_size_ba
,
output_size
s
=
[
self
.
num_v_heads
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
...
...
@@ -351,7 +352,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size
=
None
,
norm_before_gate
=
True
,
device
=
current_platform
.
current_device
(),
dtype
=
config
.
dtype
,
)
self
.
out_proj
=
RowParallelLinear
(
...
...
@@ -368,10 +368,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
create_qkvz_proj
(
self
,
hidden_size
:
int
,
key_dim
:
int
,
value_dim
:
int
,
quant_config
:
QuantizationConfig
|
None
,
prefix
:
str
,
)
->
MergedColumnParallelLinear
:
return
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
sum
((
key_dim
,
key_dim
,
value_dim
)),
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
def
fix_query_key_value_ordering
(
self
,
mixed_qkvz
,
mixed_ba
,
mixed_qkvz
:
torch
.
Tensor
,
mixed_ba
:
torch
.
Tensor
,
):
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
...
...
@@ -893,7 +909,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1
,
1
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
)
self
.
ffn_layer_scale
=
torch
.
nn
.
Parameter
(
...
...
@@ -901,7 +916,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1
,
1
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment