Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1d65283e
Unverified
Commit
1d65283e
authored
Feb 17, 2026
by
Jiangyun Zhu
Committed by
GitHub
Feb 17, 2026
Browse files
Revert "[Models] Fuse Qwen3.5 GDN's qkvz_proj and ba_proj" (#34683)
parent
c464b573
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
182 additions
and
87 deletions
+182
-87
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+6
-28
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+166
-32
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+10
-27
No files found.
vllm/model_executor/layers/linear.py
View file @
1d65283e
...
...
@@ -685,13 +685,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
loaded_shard_id
:
int
|
None
=
None
,
):
if
isinstance
(
loaded_shard_id
,
tuple
):
raise
NotImplementedError
(
"Shard id with multiple indices is not supported in weight_loader, "
"please use weight_loader_v2 instead."
)
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
...
@@ -830,10 +825,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
def
_load_fused_module_from_checkpoint
(
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
output_sizes
:
list
[
int
]
|
None
=
None
,
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
):
"""
Handle special case for models where MLP layers are already
...
...
@@ -847,8 +839,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset
=
0
shard_offsets
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
output_sizes
=
output_sizes
or
self
.
output_sizes
for
i
,
output_size
in
enumerate
(
output_sizes
):
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
current_shard_offset
+=
output_size
...
...
@@ -873,30 +864,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
tuple
[
int
,
...]
|
int
|
None
=
None
,
loaded_shard_id
:
int
|
None
=
None
,
):
if
loaded_shard_id
is
None
or
isinstance
(
loaded_shard_id
,
tuple
)
:
if
loaded_shard_id
is
None
:
if
isinstance
(
param
,
PerTensorScaleParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
,
shard_id
=
0
)
return
elif
type
(
param
)
in
(
RowvLLMParameter
,
BasevLLMParameter
):
param
.
load_merged_column_weight
(
loaded_weight
=
loaded_weight
)
return
output_sizes
=
(
[
self
.
output_sizes
[
idx
]
for
idx
in
loaded_shard_id
]
if
loaded_shard_id
else
None
)
if
isinstance
(
param
,
BlockQuantScaleParameter
):
weight_block_size
=
getattr
(
self
,
"weight_block_size"
,
None
)
output_sizes
=
[
adjust_block_scale_shard
(
weight_block_size
,
size
,
0
)[
0
]
for
size
in
(
output_sizes
or
self
.
output_sizes
)
]
# TODO: @dsikka - move to parameter.py
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
,
output_sizes
=
output_sizes
)
self
.
_load_fused_module_from_checkpoint
(
param
,
loaded_weight
)
return
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
...
...
vllm/model_executor/models/qwen3_5.py
View file @
1d65283e
...
...
@@ -30,20 +30,36 @@ from collections.abc import Callable, Iterable
import
torch
from
einops
import
rearrange
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SpeculativeConfig
,
VllmConfig
,
get_current_vllm_config
,
)
from
vllm.distributed
import
(
divide
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
(
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
)
from
vllm.model_executor.layers.linear
import
MergedColumnParallelLinear
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
mamba_v2_sharded_weight_loader
,
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateCopyFunc
,
MambaStateCopyFuncCalculator
,
...
...
@@ -57,8 +73,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
sharded_weight_loader
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.qwen3_5
import
(
Qwen3_5Config
,
...
...
@@ -80,6 +99,7 @@ from .interfaces import (
)
from
.qwen2_moe
import
Qwen2MoeMLP
as
Qwen3NextMLP
from
.qwen3_next
import
(
ChunkGatedDeltaRule
,
Qwen3NextAttention
,
Qwen3NextDecoderLayer
,
Qwen3NextGatedDeltaNet
,
...
...
@@ -119,29 +139,152 @@ class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
class
Qwen3_5GatedDeltaNet
(
Qwen3NextGatedDeltaNet
):
def
fix_query_key_value_ordering
(
def
__init__
(
self
,
mixed_qkvz
:
torch
.
Tensor
,
mixed_ba
:
torch
.
Tensor
,
):
raise
NotImplementedError
(
"Qwen3.5 Series dont need to fix query key value ordering"
config
:
Qwen3_5TextConfig
|
Qwen3_5MoeTextConfig
,
model_config
:
ModelConfig
|
None
=
None
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
speculative_config
:
SpeculativeConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
(
Qwen3NextGatedDeltaNet
,
self
).
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_v_heads
=
config
.
linear_num_value_heads
self
.
num_k_heads
=
config
.
linear_num_key_heads
self
.
head_k_dim
=
config
.
linear_key_head_dim
self
.
head_v_dim
=
config
.
linear_value_head_dim
self
.
key_dim
=
self
.
head_k_dim
*
self
.
num_k_heads
self
.
value_dim
=
self
.
head_v_dim
*
self
.
num_v_heads
self
.
conv_kernel_size
=
config
.
linear_conv_kernel_dim
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
activation
=
config
.
hidden_act
self
.
act
=
ACT2FN
[
config
.
hidden_act
]
self
.
layer_norm_epsilon
=
config
.
rms_norm_eps
self
.
prefix
=
prefix
self
.
config
=
config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
speculative_config
=
speculative_config
self
.
num_spec
=
(
self
.
speculative_config
.
num_speculative_tokens
if
self
.
speculative_config
else
0
)
def
create_qkvz_proj
(
self
,
hidden_size
:
int
,
key_dim
:
int
,
value_dim
:
int
,
quant_config
:
QuantizationConfig
|
None
,
prefix
:
str
,
)
->
MergedColumnParallelLinear
:
return
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
key_dim
,
key_dim
,
value_dim
,
value_dim
],
# QKV
self
.
conv_dim
=
self
.
key_dim
*
2
+
self
.
value_dim
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
self
.
conv_kernel_size
,
output_size
=
self
.
conv_dim
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.conv1d"
,
)
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj_qkv
=
MergedColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_sizes
=
[
self
.
key_dim
,
self
.
key_dim
,
self
.
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkv"
,
)
self
.
in_proj_z
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
value_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_z"
,
)
self
.
in_proj_b
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_b"
,
)
self
.
in_proj_a
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
prefix
,
prefix
=
f
"
{
prefix
}
.in_proj_a"
,
)
query_key_settings
=
(
self
.
key_dim
,
0
,
False
)
value_settings
=
(
self
.
value_dim
,
0
,
False
)
delattr
(
self
.
conv1d
.
weight
,
"weight_loader"
)
set_weight_attrs
(
self
.
conv1d
.
weight
,
{
"weight_loader"
:
mamba_v2_sharded_weight_loader
(
[
query_key_settings
,
query_key_settings
,
value_settings
,
],
self
.
tp_size
,
self
.
tp_rank
,
)
},
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self
.
dt_bias
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_v_heads
//
self
.
tp_size
),
)
self
.
A_log
=
nn
.
Parameter
(
torch
.
empty
(
divide
(
self
.
num_v_heads
,
self
.
tp_size
),
)
)
set_weight_attrs
(
self
.
A_log
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
set_weight_attrs
(
self
.
dt_bias
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
self
.
norm
=
RMSNormGated
(
self
.
head_v_dim
,
eps
=
self
.
layer_norm_epsilon
,
group_size
=
None
,
norm_before_gate
=
True
,
device
=
current_platform
.
current_device
(),
dtype
=
config
.
dtype
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
value_dim
,
self
.
hidden_size
,
bias
=
False
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
self
.
chunk_gated_delta_rule
=
ChunkGatedDeltaRule
()
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
fix_query_key_value_ordering
(
self
,
mixed_qkv
,
z
,
b
,
a
,
):
raise
NotImplementedError
(
"Qwen3.5 Series dont need to fix query key value ordering"
)
def
forward
(
...
...
@@ -160,13 +303,11 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================
# Part 1: Input Projection
# ============================================================
mixed_qkvz
,
_
=
self
.
in_proj_qkvz
(
hidden_states
)
qkv_size
=
(
self
.
key_dim
*
2
+
self
.
value_dim
)
//
self
.
tp_size
z_size
=
self
.
value_dim
//
self
.
tp_size
mixed_qkv
,
z
=
mixed_qkvz
.
split
([
qkv_size
,
z_size
],
dim
=-
1
)
mixed_qkv
,
_
=
self
.
in_proj_qkv
(
hidden_states
)
z
,
_
=
self
.
in_proj_z
(
hidden_states
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
b
a
,
_
=
self
.
in_proj_b
a
(
hidden_states
)
b
,
a
=
ba
.
chunk
(
2
,
dim
=-
1
)
b
,
_
=
self
.
in_proj_b
(
hidden_states
)
a
,
_
=
self
.
in_proj_a
(
hidden_states
)
b
=
b
.
contiguous
()
a
=
a
.
contiguous
()
...
...
@@ -365,18 +506,11 @@ class Qwen3_5Model(Qwen3NextModel):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# self attention
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
# mlp
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
# GDN
(
"in_proj_qkvz"
,
"in_proj_qkv"
,
(
0
,
1
,
2
)),
(
"in_proj_qkvz"
,
"in_proj_z"
,
3
),
(
"in_proj_ba"
,
"in_proj_b"
,
0
),
(
"in_proj_ba"
,
"in_proj_a"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
...
...
vllm/model_executor/models/qwen3_next.py
View file @
1d65283e
...
...
@@ -44,7 +44,6 @@ from vllm.model_executor.layers.layernorm import (
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
...
...
@@ -407,19 +406,19 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
# projection of the input hidden states
# Qwen3-Next and Qwen3.5 has a different qkv_proj layout,
# we need to create qkvz_proj adaptively here.
self
.
in_proj_qkvz
=
self
.
create_qkvz_proj
(
hidden
_size
=
self
.
hidden_size
,
key_dim
=
self
.
key_dim
,
value_dim
=
self
.
value_dim
,
self
.
projection_size_qkvz
=
self
.
key_dim
*
2
+
self
.
value_dim
*
2
self
.
projection_size_ba
=
self
.
num_v_heads
*
2
self
.
in_proj_qkvz
=
ColumnParallelLinear
(
input
_size
=
self
.
hidden_size
,
output_size
=
self
.
projection_size_qkvz
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
# ba_proj doesn't support blockwise fp8 quantization.
self
.
in_proj_ba
=
Merged
ColumnParallelLinear
(
self
.
in_proj_ba
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
s
=
[
self
.
num_v_heads
]
*
2
,
output_size
=
self
.
projection_size_ba
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
...
...
@@ -485,26 +484,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
create_qkvz_proj
(
self
,
hidden_size
:
int
,
key_dim
:
int
,
value_dim
:
int
,
quant_config
:
QuantizationConfig
|
None
,
prefix
:
str
,
)
->
MergedColumnParallelLinear
:
return
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
sum
((
key_dim
,
key_dim
,
value_dim
)),
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkvz"
,
)
def
fix_query_key_value_ordering
(
self
,
mixed_qkvz
:
torch
.
Tensor
,
mixed_ba
:
torch
.
Tensor
,
mixed_qkvz
,
mixed_ba
,
):
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment