Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f3505904
Commit
f3505904
authored
Mar 27, 2026
by
yuanyuan
Browse files
porting some qwen3.5 fp8 block quant bugfix
parent
f28b6574
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
139 additions
and
207 deletions
+139
-207
.gitignore
.gitignore
+2
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+46
-14
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+62
-176
vllm/model_executor/models/qwen3_5_mtp.py
vllm/model_executor/models/qwen3_5_mtp.py
+2
-4
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+27
-13
No files found.
.gitignore
View file @
f3505904
...
@@ -238,3 +238,5 @@ ep_kernels_workspace/
...
@@ -238,3 +238,5 @@ ep_kernels_workspace/
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi
vllm/grpc/vllm_engine_pb2.pyi
vllm/version.py
\ No newline at end of file
vllm/model_executor/layers/linear.py
View file @
f3505904
...
@@ -748,8 +748,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -748,8 +748,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
self
,
param
:
Parameter
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
int
|
None
=
None
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
):
):
if
isinstance
(
loaded_shard_id
,
tuple
):
raise
NotImplementedError
(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF
# Special case for GGUF
# initialize GGUF param after we know the quantize type
# initialize GGUF param after we know the quantize type
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
@@ -781,7 +786,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -781,7 +786,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scale to load scalar into fused array.
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
or
isinstance
(
loaded_shard_id
,
tuple
)
:
# Loaded weight is already fused on disk (mlp).
# Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj).
# (e.g., Phi-3's gate_up_proj).
if
output_dim
is
None
:
if
output_dim
is
None
:
...
@@ -793,10 +798,20 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -793,10 +798,20 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
return
return
output_sizes
=
(
self
.
output_sizes
[
loaded_shard_id
[
0
]
:
loaded_shard_id
[
-
1
]
+
1
]
if
loaded_shard_id
is
not
None
else
self
.
output_sizes
)
current_shard_offset
=
0
current_shard_offset
=
0
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
if
use_bitsandbytes_4bit
and
isinstance
(
loaded_shard_id
,
tuple
):
raise
NotImplementedError
(
"Shard id with multiple indices is not supported "
"for BNB quantization yet."
)
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
for
i
,
output_size
in
enumerate
(
output_sizes
):
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
current_shard_offset
+=
output_size
current_shard_offset
+=
output_size
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
...
@@ -838,15 +853,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -838,15 +853,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
shard_offset
//=
self
.
tp_size
shard_size
//=
self
.
tp_size
if
isinstance
(
param
,
BlockQuantScaleParameter
):
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
weight_block_size
,
shard_size
,
shard_offset
)
)
shard_offset
//=
self
.
tp_size
shard_size
//=
self
.
tp_size
# Special case for quantization.
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# If quantized, we need to adjust the offset and size to account
# for the packing.
# for the packing.
...
@@ -901,7 +916,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -901,7 +916,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
def
_load_fused_module_from_checkpoint
(
def
_load_fused_module_from_checkpoint
(
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
output_sizes
:
list
[
int
]
|
None
=
None
,
):
):
"""
"""
Handle special case for models where MLP layers are already
Handle special case for models where MLP layers are already
...
@@ -915,7 +933,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -915,7 +933,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset
=
0
current_shard_offset
=
0
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
output_sizes
=
output_sizes
or
self
.
output_sizes
for
i
,
output_size
in
enumerate
(
output_sizes
):
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
current_shard_offset
+=
output_size
current_shard_offset
+=
output_size
...
@@ -940,17 +959,30 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -940,17 +959,30 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
self
,
param
:
BasevLLMParameter
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
int
|
None
=
None
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
):
):
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
or
isinstance
(
loaded_shard_id
,
tuple
)
:
if
isinstance
(
param
,
PerTensorScaleParameter
):
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
return
return
elif
type
(
param
)
in
(
RowvLLMParameter
,
BasevLLMParameter
):
elif
type
(
param
)
in
(
RowvLLMParameter
,
BasevLLMParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
)
return
return
output_sizes
=
(
[
self
.
output_sizes
[
idx
]
for
idx
in
loaded_shard_id
]
if
loaded_shard_id
else
None
)
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
output_sizes
=
[
adjust_block_scale_shard
(
weight_block_size
,
size
,
0
)[
0
]
for
size
in
(
output_sizes
or
self
.
output_sizes
)
]
# TODO: @dsikka - move to parameter.py
# TODO: @dsikka - move to parameter.py
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
)
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
,
output_sizes
=
output_sizes
)
return
return
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
...
@@ -958,15 +990,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -958,15 +990,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
shard_offset
//=
self
.
tp_size
shard_size
//=
self
.
tp_size
if
isinstance
(
param
,
BlockQuantScaleParameter
):
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
shard_size
,
shard_offset
=
adjust_block_scale_shard
(
weight_block_size
,
shard_size
,
shard_offset
weight_block_size
,
shard_size
,
shard_offset
)
)
shard_offset
//=
self
.
tp_size
shard_size
//=
self
.
tp_size
param
.
load_merged_column_weight
(
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
shard_id
=
loaded_shard_id
,
shard_id
=
loaded_shard_id
,
...
...
vllm/model_executor/models/qwen3_5.py
View file @
f3505904
...
@@ -30,44 +30,20 @@ from collections.abc import Callable, Iterable
...
@@ -30,44 +30,20 @@ from collections.abc import Callable, Iterable
import
torch
import
torch
from
einops
import
rearrange
from
einops
import
rearrange
from
torch
import
nn
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
transformers.models.qwen3_5.configuration_qwen3_5
import
(
Qwen3_5Config
,
Qwen3_5TextConfig
,
)
from
transformers.models.qwen3_5_moe.configuration_qwen3_5_moe
import
(
Qwen3_5MoeConfig
,
Qwen3_5MoeTextConfig
,
)
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SpeculativeConfig
,
VllmConfig
,
VllmConfig
,
get_current_vllm_config
,
)
)
from
vllm.distributed
import
(
from
vllm.distributed
import
(
divide
,
get_pp_group
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
(
from
vllm.model_executor.layers.layernorm
import
(
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
)
)
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
MergedColumnParallelLinear
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
mamba_v2_sharded_weight_loader
,
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateCopyFunc
,
MambaStateCopyFunc
,
MambaStateCopyFuncCalculator
,
MambaStateCopyFuncCalculator
,
...
@@ -81,12 +57,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -81,12 +57,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
default_weight_loader
,
sharded_weight_loader
,
maybe_remap_kv_scale_name
,
)
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
transformers.models.qwen3_5.configuration_qwen3_5
import
(
Qwen3_5Config
,
Qwen3_5TextConfig
,
)
from
transformers.models.qwen3_5_moe.configuration_qwen3_5_moe
import
(
Qwen3_5MoeConfig
,
Qwen3_5MoeTextConfig
,
)
from
.interfaces
import
(
from
.interfaces
import
(
HasInnerState
,
HasInnerState
,
...
@@ -138,151 +120,29 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
...
@@ -138,151 +120,29 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
class
Qwen3_5GatedDeltaNet
(
Qwen3NextGatedDeltaNet
):
class
Qwen3_5GatedDeltaNet
(
Qwen3NextGatedDeltaNet
):
def
__init__
(
self
,
config
:
Qwen3_5TextConfig
|
Qwen3_5MoeTextConfig
,
model_config
:
ModelConfig
|
None
=
None
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
speculative_config
:
SpeculativeConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
(
Qwen3NextGatedDeltaNet
,
self
).
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_v_heads
=
config
.
linear_num_value_heads
self
.
num_k_heads
=
config
.
linear_num_key_heads
self
.
head_k_dim
=
config
.
linear_key_head_dim
self
.
head_v_dim
=
config
.
linear_value_head_dim
self
.
key_dim
=
self
.
head_k_dim
*
self
.
num_k_heads
self
.
value_dim
=
self
.
head_v_dim
*
self
.
num_v_heads
self
.
conv_kernel_size
=
config
.
linear_conv_kernel_dim
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
activation
=
config
.
hidden_act
self
.
act
=
ACT2FN
[
config
.
hidden_act
]
self
.
layer_norm_epsilon
=
config
.
rms_norm_eps
self
.
prefix
=
prefix
self
.
config
=
config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
speculative_config
=
speculative_config
self
.
num_spec
=
(
self
.
speculative_config
.
num_speculative_tokens
if
self
.
speculative_config
else
0
)
# QKV
self
.
conv_dim
=
self
.
key_dim
*
2
+
self
.
value_dim
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
self
.
conv_kernel_size
,
output_size
=
self
.
conv_dim
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.conv1d"
,
)
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj_qkv
=
MergedColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_sizes
=
[
self
.
key_dim
,
self
.
key_dim
,
self
.
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkv"
,
)
self
.
in_proj_z
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
value_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_z"
,
)
self
.
in_proj_b
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
)
self
.
in_proj_a
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_a"
,
)
query_key_settings
=
(
self
.
key_dim
,
0
,
False
)
value_settings
=
(
self
.
value_dim
,
0
,
False
)
delattr
(
self
.
conv1d
.
weight
,
"weight_loader"
)
set_weight_attrs
(
self
.
conv1d
.
weight
,
{
"weight_loader"
:
mamba_v2_sharded_weight_loader
(
[
query_key_settings
,
query_key_settings
,
value_settings
,
],
self
.
tp_size
,
self
.
tp_rank
,
)
},
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self
.
dt_bias
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_v_heads
//
self
.
tp_size
),
)
self
.
A_log
=
nn
.
Parameter
(
torch
.
empty
(
divide
(
self
.
num_v_heads
,
self
.
tp_size
),
)
)
set_weight_attrs
(
self
.
A_log
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
set_weight_attrs
(
self
.
dt_bias
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
self
.
norm
=
RMSNormGated
(
self
.
head_v_dim
,
eps
=
self
.
layer_norm_epsilon
,
group_size
=
None
,
norm_before_gate
=
True
,
device
=
current_platform
.
current_device
(),
dtype
=
config
.
dtype
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
value_dim
,
self
.
hidden_size
,
bias
=
False
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
fix_query_key_value_ordering
(
def
fix_query_key_value_ordering
(
self
,
self
,
mixed_qkv
,
mixed_qkvz
:
torch
.
Tensor
,
z
,
mixed_ba
:
torch
.
Tensor
,
b
,
a
,
):
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Qwen3.5 Series dont need to fix query key value ordering"
"Qwen3.5 Series dont need to fix query key value ordering"
)
)
def
create_qkvz_proj
(
self
,
hidden_size
:
int
,
key_dim
:
int
,
value_dim
:
int
,
quant_config
:
QuantizationConfig
|
None
,
prefix
:
str
,
)
->
MergedColumnParallelLinear
:
return
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
key_dim
,
key_dim
,
value_dim
,
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
prefix
,
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -300,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
...
@@ -300,11 +160,13 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================
# ============================================================
# Part 1: Input Projection
# Part 1: Input Projection
# ============================================================
# ============================================================
mixed_qkv
,
_
=
self
.
in_proj_qkv
(
hidden_states
)
mixed_qkvz
,
_
=
self
.
in_proj_qkvz
(
hidden_states
)
z
,
_
=
self
.
in_proj_z
(
hidden_states
)
qkv_size
=
(
self
.
key_dim
*
2
+
self
.
value_dim
)
//
self
.
tp_size
z_size
=
self
.
value_dim
//
self
.
tp_size
mixed_qkv
,
z
=
mixed_qkvz
.
split
([
qkv_size
,
z_size
],
dim
=-
1
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
b
,
_
=
self
.
in_proj_b
(
hidden_states
)
b
a
,
_
=
self
.
in_proj_b
a
(
hidden_states
)
a
,
_
=
self
.
in_proj_a
(
hidden_states
)
b
,
a
=
ba
.
chunk
(
2
,
dim
=-
1
)
b
=
b
.
contiguous
()
b
=
b
.
contiguous
()
a
=
a
.
contiguous
()
a
=
a
.
contiguous
()
...
@@ -411,7 +273,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
...
@@ -411,7 +273,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1
,
1
,
1
,
1
,
config
.
hidden_size
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
),
)
)
self
.
ffn_layer_scale
=
torch
.
nn
.
Parameter
(
self
.
ffn_layer_scale
=
torch
.
nn
.
Parameter
(
...
@@ -419,7 +280,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
...
@@ -419,7 +280,6 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
1
,
1
,
1
,
1
,
config
.
hidden_size
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
),
)
)
...
@@ -503,11 +363,18 @@ class Qwen3_5Model(Qwen3NextModel):
...
@@ -503,11 +363,18 @@ class Qwen3_5Model(Qwen3NextModel):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
# self attention
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
# mlp
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
# GDN
(
"in_proj_qkvz"
,
"in_proj_qkv"
,
(
0
,
1
,
2
)),
(
"in_proj_qkvz"
,
"in_proj_z"
,
3
),
(
"in_proj_ba"
,
"in_proj_b"
,
0
),
(
"in_proj_ba"
,
"in_proj_a"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
...
@@ -528,6 +395,12 @@ class Qwen3_5Model(Qwen3NextModel):
...
@@ -528,6 +395,12 @@ class Qwen3_5Model(Qwen3NextModel):
if
name
.
startswith
(
"mtp."
):
if
name
.
startswith
(
"mtp."
):
continue
continue
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"scale"
):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
"experts.gate_up_proj"
in
name
or
"experts.down_proj"
in
name
:
if
"experts.gate_up_proj"
in
name
or
"experts.down_proj"
in
name
:
is_fused_expert
=
True
is_fused_expert
=
True
...
@@ -654,6 +527,9 @@ class Qwen3_5ForCausalLMBase(
...
@@ -654,6 +527,9 @@ class Qwen3_5ForCausalLMBase(
"v_proj"
,
"v_proj"
,
],
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
# GDN fused projections.
"in_proj_qkvz"
:
[
"in_proj_qkv"
,
"in_proj_z"
],
"in_proj_ba"
:
[
"in_proj_b"
,
"in_proj_a"
],
}
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
@@ -751,7 +627,15 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
...
@@ -751,7 +627,15 @@ class Qwen3_5MoeForCausalLM(Qwen3_5ForCausalLMBase, QwenNextMixtureOfExperts):
dummy_inputs
=
Qwen3VLDummyInputsBuilder
,
dummy_inputs
=
Qwen3VLDummyInputsBuilder
,
)
)
class
Qwen3_5ForConditionalGeneration
(
Qwen3VLForConditionalGeneration
,
IsHybrid
):
class
Qwen3_5ForConditionalGeneration
(
Qwen3VLForConditionalGeneration
,
IsHybrid
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
# Qwen3.5 does not support multimodal pruning (EVS).
supports_multimodal_pruning
=
False
packed_modules_mapping
=
Qwen3VLForConditionalGeneration
.
packed_modules_mapping
|
{
"in_proj_qkvz"
:
[
"in_proj_qkv"
,
"in_proj_z"
],
"in_proj_ba"
:
[
"in_proj_b"
,
"in_proj_a"
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
"model"
):
# protocols have not __init__ method, so we need to use nn.Module.__init__
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn
.
Module
.
__init__
(
self
)
nn
.
Module
.
__init__
(
self
)
config
:
Qwen3_5Config
=
vllm_config
.
model_config
.
hf_config
config
:
Qwen3_5Config
=
vllm_config
.
model_config
.
hf_config
...
@@ -868,7 +752,9 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
...
@@ -868,7 +752,9 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
vllm_config
:
"VllmConfig"
,
vllm_config
:
"VllmConfig"
,
)
->
tuple
[
torch
.
dtype
,
torch
.
dtype
]:
)
->
tuple
[
torch
.
dtype
,
torch
.
dtype
]:
return
MambaStateDtypeCalculator
.
gated_delta_net_state_dtype
(
return
MambaStateDtypeCalculator
.
gated_delta_net_state_dtype
(
vllm_config
.
model_config
.
dtype
,
vllm_config
.
cache_config
.
mamba_cache_dtype
vllm_config
.
model_config
.
dtype
,
vllm_config
.
cache_config
.
mamba_cache_dtype
,
vllm_config
.
cache_config
.
mamba_ssm_cache_dtype
,
)
)
@
classmethod
@
classmethod
...
@@ -957,7 +843,7 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
...
@@ -957,7 +843,7 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
class
Qwen3_5MoeForConditionalGeneration
(
class
Qwen3_5MoeForConditionalGeneration
(
Qwen3_5ForConditionalGeneration
,
Qwen3_5_MoeMixtureOfExperts
Qwen3_5ForConditionalGeneration
,
Qwen3_5_MoeMixtureOfExperts
):
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
"
model
"
):
# protocols have not __init__ method, so we need to use nn.Module.__init__
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn
.
Module
.
__init__
(
self
)
nn
.
Module
.
__init__
(
self
)
config
:
Qwen3_5MoeConfig
=
vllm_config
.
model_config
.
hf_config
config
:
Qwen3_5MoeConfig
=
vllm_config
.
model_config
.
hf_config
...
...
vllm/model_executor/models/qwen3_5_mtp.py
View file @
f3505904
...
@@ -7,10 +7,6 @@ from collections.abc import Callable, Iterable
...
@@ -7,10 +7,6 @@ from collections.abc import Callable, Iterable
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers.models.qwen3_5.configuration_qwen3_5
import
Qwen3_5TextConfig
from
transformers.models.qwen3_5_moe.configuration_qwen3_5_moe
import
(
Qwen3_5MoeTextConfig
,
)
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -27,6 +23,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -27,6 +23,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.qwen3_5
import
Qwen3_5DecoderLayer
,
Qwen3_5RMSNorm
from
vllm.model_executor.models.qwen3_5
import
Qwen3_5DecoderLayer
,
Qwen3_5RMSNorm
from
vllm.model_executor.models.qwen3_next
import
QwenNextMixtureOfExperts
from
vllm.model_executor.models.qwen3_next
import
QwenNextMixtureOfExperts
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.qwen3_5
import
Qwen3_5TextConfig
from
vllm.transformers_utils.configs.qwen3_5_moe
import
Qwen3_5MoeTextConfig
from
.interfaces
import
(
from
.interfaces
import
(
MultiModalEmbeddings
,
MultiModalEmbeddings
,
...
...
vllm/model_executor/models/qwen3_next.py
View file @
f3505904
...
@@ -40,6 +40,7 @@ from vllm.model_executor.layers.layernorm import (
...
@@ -40,6 +40,7 @@ from vllm.model_executor.layers.layernorm import (
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
,
RowParallelLinear
,
...
@@ -292,19 +293,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -292,19 +293,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
# projection of the input hidden states
# projection of the input hidden states
self
.
projection_size_qkvz
=
self
.
key_dim
*
2
+
self
.
value_dim
*
2
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
self
.
projection_size_ba
=
self
.
num_v_heads
*
2
# we need to create qkvz_proj adaptively here.
self
.
in_proj_qkvz
=
ColumnParallelLinear
(
self
.
in_proj_qkvz
=
self
.
create_qkvz_proj
(
input
_size
=
self
.
hidden_size
,
hidden
_size
=
self
.
hidden_size
,
output_size
=
self
.
projection_size_qkvz
,
key_dim
=
self
.
key_dim
,
bias
=
False
,
value_dim
=
self
.
value_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
)
# ba_proj doesn't support blockwise fp8 quantization.
# ba_proj doesn't support blockwise fp8 quantization.
self
.
in_proj_ba
=
ColumnParallelLinear
(
self
.
in_proj_ba
=
Merged
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
input_size
=
self
.
hidden_size
,
output_size
=
self
.
projection_size_ba
,
output_size
s
=
[
self
.
num_v_heads
]
*
2
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
...
@@ -351,7 +352,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -351,7 +352,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size
=
None
,
group_size
=
None
,
norm_before_gate
=
True
,
norm_before_gate
=
True
,
device
=
current_platform
.
current_device
(),
device
=
current_platform
.
current_device
(),
dtype
=
config
.
dtype
,
)
)
self
.
out_proj
=
RowParallelLinear
(
self
.
out_proj
=
RowParallelLinear
(
...
@@ -368,10 +368,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -368,10 +368,26 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
create_qkvz_proj
(
self
,
hidden_size
:
int
,
key_dim
:
int
,
value_dim
:
int
,
quant_config
:
QuantizationConfig
|
None
,
prefix
:
str
,
)
->
MergedColumnParallelLinear
:
return
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
sum
((
key_dim
,
key_dim
,
value_dim
)),
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
def
fix_query_key_value_ordering
(
def
fix_query_key_value_ordering
(
self
,
self
,
mixed_qkvz
,
mixed_qkvz
:
torch
.
Tensor
,
mixed_ba
,
mixed_ba
:
torch
.
Tensor
,
):
):
"""
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
...
@@ -893,7 +909,6 @@ class Qwen3NextDecoderLayer(nn.Module):
...
@@ -893,7 +909,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1
,
1
,
1
,
1
,
config
.
hidden_size
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
),
)
)
self
.
ffn_layer_scale
=
torch
.
nn
.
Parameter
(
self
.
ffn_layer_scale
=
torch
.
nn
.
Parameter
(
...
@@ -901,7 +916,6 @@ class Qwen3NextDecoderLayer(nn.Module):
...
@@ -901,7 +916,6 @@ class Qwen3NextDecoderLayer(nn.Module):
1
,
1
,
1
,
1
,
config
.
hidden_size
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
),
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment