Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8fbe3f30
Unverified
Commit
8fbe3f30
authored
Mar 20, 2026
by
Jee Jee Li
Committed by
GitHub
Mar 20, 2026
Browse files
[Bugfix][LoRA] Fix Qwen35 LoRA (#36976)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
ea2c148f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
257 additions
and
46 deletions
+257
-46
.buildkite/test_areas/lora.yaml
.buildkite/test_areas/lora.yaml
+3
-2
tests/lora/conftest.py
tests/lora/conftest.py
+5
-0
tests/lora/test_qwen35_densemoel_lora.py
tests/lora/test_qwen35_densemoel_lora.py
+132
-0
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+100
-23
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+17
-21
No files found.
.buildkite/test_areas/lora.yaml
View file @
8fbe3f30
...
@@ -8,7 +8,7 @@ steps:
...
@@ -8,7 +8,7 @@ steps:
-
vllm/lora
-
vllm/lora
-
tests/lora
-
tests/lora
commands
:
commands
:
-
pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py
-
pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py
--ignore=lora/test_qwen35_densemoel_lora.py
parallelism
:
4
parallelism
:
4
...
@@ -31,3 +31,4 @@ steps:
...
@@ -31,3 +31,4 @@ steps:
-
pytest -v -s -x lora/test_llm_with_multi_loras.py
-
pytest -v -s -x lora/test_llm_with_multi_loras.py
-
pytest -v -s -x lora/test_olmoe_tp.py
-
pytest -v -s -x lora/test_olmoe_tp.py
-
pytest -v -s -x lora/test_gptoss_tp.py
-
pytest -v -s -x lora/test_gptoss_tp.py
-
pytest -v -s -x lora/test_qwen35_densemoel_lora.py
\ No newline at end of file
tests/lora/conftest.py
View file @
8fbe3f30
...
@@ -294,6 +294,11 @@ def whisper_lora_files():
...
@@ -294,6 +294,11 @@ def whisper_lora_files():
return
snapshot_download
(
repo_id
=
"chengyili2005/whisper-small-mandarin-lora"
)
return
snapshot_download
(
repo_id
=
"chengyili2005/whisper-small-mandarin-lora"
)
@pytest.fixture(scope="session")
def qwen35_dense_model_lora_files():
    """Session-scoped fixture for the Qwen3.5-4B text-only SQL LoRA adapter.

    Returns whatever ``snapshot_download`` yields for the adapter repository;
    the tests pass that value on as the LoRA path.
    """
    adapter_repo = "jeeejeee/qwen35-4b-text-only-sql-lora"
    return snapshot_download(repo_id=adapter_repo)
@
pytest
.
fixture
@
pytest
.
fixture
def
reset_default_device
():
def
reset_default_device
():
"""
"""
...
...
tests/lora/test_qwen35_densemoel_lora.py
0 → 100644
View file @
8fbe3f30
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
transformers
import
AutoTokenizer
import
vllm
import
vllm.config
from
vllm.lora.request
import
LoRARequest
from
..utils
import
create_new_process_for_each_test
,
multi_gpu_test
# Base model identifier passed to both ``vllm.LLM`` and the tokenizer below.
MODEL_PATH
=
"Qwen/Qwen3.5-4B"
# Prompt skeleton; ``{query}`` is substituted via ``str.format`` in ``do_sample``.
PROMPT_TEMPLATE
=
"""Write a SQL query for the given database.
\n
Schema:
\n
Tables:
\n
- stadium(Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average)
\n
- singer(Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male)
\n
- concert(concert_ID, concert_Name, Theme, Stadium_ID, Year)
\n
- singer_in_concert(concert_ID, Singer_ID)
\n\n
Question:
\n
{query}"""
# noqa: E501
# Expected generations, compared index-by-index with ``do_sample``'s return
# value in every test of this module.
EXPECTED_LORA_OUTPUT
=
[
"SELECT count(*) FROM singer"
,
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'"
,
"SELECT name FROM stadium WHERE stadium_id NOT IN (SELECT stadium_id FROM concert)"
,
]
# Module-level tokenizer, created at import time; ``do_sample`` uses it to
# apply the chat template to each prompt.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_PATH
,
trust_remote_code
=
True
)
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
    """Generate completions for the three fixed SQL questions.

    Each question is inserted into ``PROMPT_TEMPLATE``, wrapped as a single
    user message, and rendered through the tokenizer's chat template with
    ``enable_thinking=False``. Generation uses ``temperature=0`` and
    ``max_tokens=512``.

    Args:
        llm: Engine used for generation.
        lora_path: Adapter location handed to ``LoRARequest``.
        lora_id: Adapter id; a falsy value (e.g. ``0``) sends no LoRA request.

    Returns:
        The stripped text of the first output of every request, in the order
        returned by ``llm.generate``.
    """
    questions = (
        "How many singers do we have?",
        "What is the average, minimum, and maximum "
        "age of all singers from France?",
        "What are the names of the stadiums without any concerts?",
    )

    # Render every question through the chat template (thinking disabled).
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": PROMPT_TEMPLATE.format(query=question)}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        for question in questions
    ]

    # Only attach an adapter when a non-zero id was requested.
    adapter_request = None
    if lora_id:
        adapter_request = LoRARequest(str(lora_id), lora_id, lora_path)

    request_outputs = llm.generate(
        chat_prompts,
        vllm.SamplingParams(temperature=0, max_tokens=512),
        lora_request=adapter_request,
    )

    completions: list[str] = []
    for request_output in request_outputs:
        completion = request_output.outputs[0].text.strip()
        completions.append(completion)
        print(f"Prompt: {request_output.prompt!r}, Generated text: {completion!r}")
    return completions
@create_new_process_for_each_test()
def test_qwen35_dense_model_lora(qwen35_dense_model_lora_files):
    """Check LoRA generations against ``EXPECTED_LORA_OUTPUT`` on the default setup.

    The same adapter files are sampled under two different LoRA ids (1, then
    2) and every generation must equal the expected SQL string at the same
    index.
    """
    engine_kwargs = dict(
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_num_seqs=16,
        max_lora_rank=8,
        trust_remote_code=True,
    )
    llm = vllm.LLM(MODEL_PATH, **engine_kwargs)

    for adapter_id in (1, 2):
        generated = do_sample(
            llm, qwen35_dense_model_lora_files, lora_id=adapter_id
        )
        for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert generated[idx] == expected
@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4(qwen35_dense_model_lora_files):
    """Check LoRA generations with ``tensor_parallel_size=4`` and unsharded LoRAs.

    The adapter is sampled under LoRA ids 1 and 2; each run must reproduce
    ``EXPECTED_LORA_OUTPUT`` entry by entry.
    """
    # Avoid OOM (original author's note on disabling LoRA cudagraph specialization).
    compile_cfg = vllm.config.CompilationConfig(cudagraph_specialize_lora=False)
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        max_num_seqs=16,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=False,
        compilation_config=compile_cfg,
    )

    first_run = do_sample(llm, qwen35_dense_model_lora_files, lora_id=1)
    print(first_run)
    second_run = None
    for run_index in (0, 1):
        if run_index == 1:
            second_run = do_sample(llm, qwen35_dense_model_lora_files, lora_id=2)
        generated = first_run if run_index == 0 else second_run
        for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert generated[idx] == expected
@multi_gpu_test(num_gpus=4)
def test_qwen35_dense_model_lora_tp4_fully_sharded_loras(
    qwen35_dense_model_lora_files,
):
    """Check LoRA generations with ``tensor_parallel_size=4`` and fully sharded LoRAs.

    The adapter is sampled under LoRA ids 1 and 2; each run must reproduce
    ``EXPECTED_LORA_OUTPUT`` entry by entry.
    """
    # Avoid OOM (original author's note on disabling LoRA cudagraph specialization).
    compile_cfg = vllm.config.CompilationConfig(cudagraph_specialize_lora=False)
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=512,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=True,
        gpu_memory_utilization=0.8,
        compilation_config=compile_cfg,
    )

    for adapter_id in (1, 2):
        generated = do_sample(
            llm, qwen35_dense_model_lora_files, lora_id=adapter_id
        )
        for idx, expected in enumerate(EXPECTED_LORA_OUTPUT):
            assert generated[idx] == expected
vllm/model_executor/models/qwen3_5.py
View file @
8fbe3f30
...
@@ -32,9 +32,7 @@ from einops import rearrange
...
@@ -32,9 +32,7 @@ from einops import rearrange
from
torch
import
nn
from
torch
import
nn
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
from
vllm.config
import
VllmConfig
VllmConfig
,
)
from
vllm.distributed
import
(
from
vllm.distributed
import
(
get_pp_group
,
get_pp_group
,
)
)
...
@@ -42,7 +40,10 @@ from vllm.logger import init_logger
...
@@ -42,7 +40,10 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.layernorm
import
(
from
vllm.model_executor.layers.layernorm
import
(
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
)
)
from
vllm.model_executor.layers.linear
import
MergedColumnParallelLinear
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateCopyFunc
,
MambaStateCopyFunc
,
...
@@ -130,6 +131,40 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
...
@@ -130,6 +131,40 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
"Qwen3.5 Series dont need to fix query key value ordering"
"Qwen3.5 Series dont need to fix query key value ordering"
)
)
def __init__(
    self,
    config: Qwen3_5Config,
    vllm_config: VllmConfig,
    prefix: str = "",
) -> None:
    """Build the GDN layer, splitting the qkvz input projection when LoRA is on.

    Args:
        config: Model configuration forwarded to the parent constructor.
        vllm_config: Engine configuration; ``lora_config`` decides which
            projection layout is built and ``quant_config`` is forwarded to
            the projections created here.
        prefix: Module-name prefix used for the created sub-layers.
    """
    lora_enabled = vllm_config.lora_config is not None

    # The parent only builds the fused in_proj_qkvz when LoRA is disabled.
    super().__init__(
        config,
        vllm_config=vllm_config,
        prefix=prefix,
        create_in_proj_qkvz=not lora_enabled,
    )
    if not lora_enabled:
        return

    # Separate in_proj_qkv (Q,K,V) and in_proj_z for LoRA compatibility.
    # Use MergedColumnParallelLinear for in_proj_qkv because GDN can have
    # linear_num_key_heads != linear_num_value_heads (e.g. 16 vs 32), so
    # output sizes [key_dim, key_dim, value_dim] are not representable
    # with a single QKVParallelLinear (which ties K and V head counts).
    shared_kwargs = dict(
        input_size=self.hidden_size,
        bias=False,
        quant_config=vllm_config.quant_config,
    )
    self.in_proj_qkv = MergedColumnParallelLinear(
        output_sizes=[self.key_dim, self.key_dim, self.value_dim],
        prefix=f"{prefix}.in_proj_qkv",
        **shared_kwargs,
    )
    self.in_proj_z = ColumnParallelLinear(
        output_size=self.value_dim,
        prefix=f"{prefix}.in_proj_z",
        **shared_kwargs,
    )
def
create_qkvz_proj
(
def
create_qkvz_proj
(
self
,
self
,
hidden_size
:
int
,
hidden_size
:
int
,
...
@@ -180,6 +215,12 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
...
@@ -180,6 +215,12 @@ class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
# ============================================================
# ============================================================
# Part 1: Input Projection
# Part 1: Input Projection
# ============================================================
# ============================================================
if
hasattr
(
self
,
"in_proj_qkv"
):
# LoRA path: separate in_proj_qkv and in_proj_z
mixed_qkv
,
_
=
self
.
in_proj_qkv
(
hidden_states
)
ba
,
_
=
self
.
in_proj_ba
(
hidden_states
)
z
,
_
=
self
.
in_proj_z
(
hidden_states
)
else
:
mixed_qkvz
,
ba
=
torch
.
ops
.
vllm
.
gdn_in_proj
(
mixed_qkvz
,
ba
=
torch
.
ops
.
vllm
.
gdn_in_proj
(
hidden_states
,
hidden_states
,
sum
(
self
.
in_proj_qkvz
.
output_sizes
)
//
self
.
tp_size
,
sum
(
self
.
in_proj_qkvz
.
output_sizes
)
//
self
.
tp_size
,
...
@@ -240,18 +281,14 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
...
@@ -240,18 +281,14 @@ class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
speculative_config
=
vllm_config
.
speculative_config
self
.
layer_type
=
layer_type
self
.
layer_type
=
layer_type
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
layer_idx
=
extract_layer_index
(
prefix
)
if
self
.
layer_type
==
"linear_attention"
:
if
self
.
layer_type
==
"linear_attention"
:
self
.
linear_attn
=
Qwen3_5GatedDeltaNet
(
self
.
linear_attn
=
Qwen3_5GatedDeltaNet
(
config
,
config
=
config
,
model_config
=
model_config
,
vllm_config
=
vllm_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
speculative_config
=
speculative_config
,
prefix
=
f
"
{
prefix
}
.linear_attn"
,
prefix
=
f
"
{
prefix
}
.linear_attn"
,
)
)
elif
self
.
layer_type
==
"full_attention"
:
elif
self
.
layer_type
==
"full_attention"
:
...
@@ -331,6 +368,7 @@ class Qwen3_5Model(Qwen3NextModel):
...
@@ -331,6 +368,7 @@ class Qwen3_5Model(Qwen3NextModel):
self
.
num_redundant_experts
=
eplb_config
.
num_redundant_experts
self
.
num_redundant_experts
=
eplb_config
.
num_redundant_experts
self
.
config
=
config
self
.
config
=
config
self
.
enable_lora
=
vllm_config
.
lora_config
is
not
None
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
...
@@ -396,13 +434,25 @@ class Qwen3_5Model(Qwen3NextModel):
...
@@ -396,13 +434,25 @@ class Qwen3_5Model(Qwen3NextModel):
# mlp
# mlp
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
# GDN
(
"in_proj_qkvz"
,
"in_proj_qkv"
,
(
0
,
1
,
2
)),
(
"in_proj_qkvz"
,
"in_proj_z"
,
3
),
(
"in_proj_ba"
,
"in_proj_b"
,
0
),
(
"in_proj_ba"
,
"in_proj_b"
,
0
),
(
"in_proj_ba"
,
"in_proj_a"
,
1
),
(
"in_proj_ba"
,
"in_proj_a"
,
1
),
]
]
if
self
.
enable_lora
:
stacked_params_mapping
.
extend
(
[
(
"in_proj_qkv"
,
"in_proj_qkv"
,
(
0
,
1
,
2
)),
(
"in_proj_z"
,
"in_proj_z"
,
0
),
]
)
else
:
stacked_params_mapping
.
extend
(
[
(
"in_proj_qkvz"
,
"in_proj_qkv"
,
(
0
,
1
,
2
)),
(
"in_proj_qkvz"
,
"in_proj_z"
,
3
),
]
)
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
expert_params_mapping
=
self
.
get_expert_mapping
()
...
@@ -450,6 +500,9 @@ class Qwen3_5Model(Qwen3NextModel):
...
@@ -450,6 +500,9 @@ class Qwen3_5Model(Qwen3NextModel):
continue
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
if
param_name
==
"in_proj_z"
and
self
.
enable_lora
:
weight_loader
(
param
,
loaded_weight
)
else
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
...
@@ -580,6 +633,15 @@ class Qwen3_5ForCausalLMBase(
...
@@ -580,6 +633,15 @@ class Qwen3_5ForCausalLMBase(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
)
# When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
# instead of merged in_proj_qkvz; pack mapping must match.
if
vllm_config
.
lora_config
:
base
=
getattr
(
Qwen3_5ForCausalLMBase
,
"packed_modules_mapping"
,
{})
self
.
packed_modules_mapping
=
{
k
:
list
(
v
)
for
k
,
v
in
base
.
items
()}
self
.
packed_modules_mapping
.
pop
(
"in_proj_qkvz"
,
None
)
self
.
packed_modules_mapping
[
"in_proj_qkv"
]
=
[
"in_proj_qkv"
]
self
.
packed_modules_mapping
[
"in_proj_z"
]
=
[
"in_proj_z"
]
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
if
config
.
tie_word_embeddings
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
self
.
lm_head
=
self
.
model
.
embed_tokens
...
@@ -672,6 +734,7 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
...
@@ -672,6 +734,7 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
"model"
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
"model"
):
# Protocols have no __init__ method, so we need to call nn.Module.__init__ directly
# Protocols have no __init__ method, so we need to call nn.Module.__init__ directly
nn
.
Module
.
__init__
(
self
)
nn
.
Module
.
__init__
(
self
)
self
.
update_packed_mapping
(
enable_lora
=
vllm_config
.
lora_config
is
not
None
)
config
:
Qwen3_5Config
=
vllm_config
.
model_config
.
hf_config
config
:
Qwen3_5Config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
...
@@ -699,6 +762,16 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
...
@@ -699,6 +762,16 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
self
.
language_model
.
make_empty_intermediate_tensors
self
.
language_model
.
make_empty_intermediate_tensors
)
)
def update_packed_mapping(self, enable_lora: bool):
    """Install a per-instance ``packed_modules_mapping`` for the LoRA layout.

    When LoRA is enabled, GDN uses separate ``in_proj_qkv`` and ``in_proj_z``
    projections instead of the fused ``in_proj_qkvz``. The class-level
    mapping is copied (so the class attribute is never mutated), the fused
    entry is dropped, and both split projections are registered. This matches
    the mapping ``Qwen3_5ForCausalLMBase`` builds for the same case.

    Args:
        enable_lora: Whether LoRA is configured. When ``False`` nothing is
            changed and the class-level mapping stays in effect.
    """
    if not enable_lora:
        return

    base = getattr(Qwen3_5ForConditionalGeneration, "packed_modules_mapping", {})
    # Copy the lists as well, so later edits cannot leak into the class attribute.
    mapping = {k: list(v) for k, v in base.items()}
    mapping.pop("in_proj_qkvz", None)
    mapping["in_proj_qkv"] = ["in_proj_qkv"]
    # Register in_proj_z too, consistent with Qwen3_5ForCausalLMBase.
    mapping["in_proj_z"] = ["in_proj_z"]
    self.packed_modules_mapping = mapping
def
embed_input_ids
(
def
embed_input_ids
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -879,9 +952,13 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
...
@@ -879,9 +952,13 @@ class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
class
Qwen3_5MoeForConditionalGeneration
(
class
Qwen3_5MoeForConditionalGeneration
(
Qwen3_5ForConditionalGeneration
,
Qwen3_5_MoeMixtureOfExperts
Qwen3_5ForConditionalGeneration
,
Qwen3_5_MoeMixtureOfExperts
):
):
# For MoE LoRA weights loading
is_3d_moe_weight
:
bool
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
"model"
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
"model"
):
# Protocols have no __init__ method, so we need to call nn.Module.__init__ directly
# Protocols have no __init__ method, so we need to call nn.Module.__init__ directly
nn
.
Module
.
__init__
(
self
)
nn
.
Module
.
__init__
(
self
)
self
.
update_packed_mapping
(
enable_lora
=
vllm_config
.
lora_config
is
not
None
)
config
:
Qwen3_5MoeConfig
=
vllm_config
.
model_config
.
hf_config
config
:
Qwen3_5MoeConfig
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
...
...
vllm/model_executor/models/qwen3_next.py
View file @
8fbe3f30
...
@@ -15,7 +15,6 @@ from vllm.compilation.decorators import support_torch_compile
...
@@ -15,7 +15,6 @@ from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
(
from
vllm.config
import
(
CacheConfig
,
CacheConfig
,
ModelConfig
,
ModelConfig
,
SpeculativeConfig
,
VllmConfig
,
VllmConfig
,
get_current_vllm_config
,
get_current_vllm_config
,
)
)
...
@@ -401,11 +400,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -401,11 +400,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
Qwen3NextConfig
,
config
:
Qwen3NextConfig
,
model_config
:
ModelConfig
|
None
=
None
,
vllm_config
:
VllmConfig
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
speculative_config
:
SpeculativeConfig
|
None
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
create_in_proj_qkvz
:
bool
=
True
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -432,10 +429,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -432,10 +429,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
)
)
self
.
config
=
config
self
.
config
=
config
self
.
model_config
=
model_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
quant_config
=
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
speculative_config
=
speculative_config
self
.
speculative_config
=
vllm_config
.
speculative_config
self
.
num_spec
=
(
self
.
num_spec
=
(
self
.
speculative_config
.
num_speculative_tokens
self
.
speculative_config
.
num_speculative_tokens
if
self
.
speculative_config
if
self
.
speculative_config
...
@@ -455,6 +452,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -455,6 +452,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
# projection of the input hidden states
# projection of the input hidden states
# Qwen3-Next and Qwen3.5 have different qkv_proj layouts,
# Qwen3-Next and Qwen3.5 have different qkv_proj layouts,
# we need to create qkvz_proj adaptively here.
# we need to create qkvz_proj adaptively here.
# When create_in_proj_qkvz is False (e.g. LoRA enabled in Qwen3.5),
# the subclass creates in_proj_qkv and in_proj_z separately.
if
create_in_proj_qkvz
:
self
.
in_proj_qkvz
=
self
.
create_qkvz_proj
(
self
.
in_proj_qkvz
=
self
.
create_qkvz_proj
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
key_dim
=
self
.
key_dim
,
key_dim
=
self
.
key_dim
,
...
@@ -1207,7 +1207,6 @@ class Qwen3NextDecoderLayer(nn.Module):
...
@@ -1207,7 +1207,6 @@ class Qwen3NextDecoderLayer(nn.Module):
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
speculative_config
=
vllm_config
.
speculative_config
self
.
layer_type
=
layer_type
self
.
layer_type
=
layer_type
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
layer_idx
=
extract_layer_index
(
prefix
)
...
@@ -1215,10 +1214,7 @@ class Qwen3NextDecoderLayer(nn.Module):
...
@@ -1215,10 +1214,7 @@ class Qwen3NextDecoderLayer(nn.Module):
if
self
.
layer_type
==
"linear_attention"
:
if
self
.
layer_type
==
"linear_attention"
:
self
.
linear_attn
=
Qwen3NextGatedDeltaNet
(
self
.
linear_attn
=
Qwen3NextGatedDeltaNet
(
config
,
config
,
model_config
=
model_config
,
vllm_config
=
vllm_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
speculative_config
=
speculative_config
,
prefix
=
f
"
{
prefix
}
.linear_attn"
,
prefix
=
f
"
{
prefix
}
.linear_attn"
,
)
)
elif
self
.
layer_type
==
"full_attention"
:
elif
self
.
layer_type
==
"full_attention"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment