Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
35006c0f
Commit
35006c0f
authored
Feb 24, 2026
by
zhuwenwen
Browse files
Merge branch '0.15.1-dev-qwen3_5' into 'v0.15.1-dev'
Support qwen3 5 See merge request dcutoolkit/deeplearing/vllm!438
parents
133e783f
4dc838d3
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1971 additions
and
27 deletions
+1971
-27
csrc/sampler.cu
csrc/sampler.cu
+18
-5
docs/models/supported_models.md
docs/models/supported_models.md
+10
-8
tests/models/registry.py
tests/models/registry.py
+21
-1
vllm/config/model.py
vllm/config/model.py
+6
-4
vllm/config/speculative.py
vllm/config/speculative.py
+11
-0
vllm/model_executor/layers/mamba/abstract.py
vllm/model_executor/layers/mamba/abstract.py
+2
-1
vllm/model_executor/layers/mamba/mamba_utils.py
vllm/model_executor/layers/mamba/mamba_utils.py
+6
-2
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+27
-0
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+993
-0
vllm/model_executor/models/qwen3_5_mtp.py
vllm/model_executor/models/qwen3_5_mtp.py
+447
-0
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+6
-6
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+10
-0
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+8
-0
vllm/transformers_utils/configs/qwen3_5.py
vllm/transformers_utils/configs/qwen3_5.py
+193
-0
vllm/transformers_utils/configs/qwen3_5_moe.py
vllm/transformers_utils/configs/qwen3_5_moe.py
+205
-0
vllm/transformers_utils/model_arch_config_convertor.py
vllm/transformers_utils/model_arch_config_convertor.py
+6
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+2
-0
No files found.
csrc/sampler.cu
View file @
35006c0f
...
@@ -215,7 +215,11 @@ __device__ bool processHistogramStep(
...
@@ -215,7 +215,11 @@ __device__ bool processHistogramStep(
// Compute the prefix sum.
// Compute the prefix sum.
int
prefixSum
{
0
},
totalSum
{
0
};
int
prefixSum
{
0
},
totalSum
{
0
};
#ifndef USE_ROCM
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#else:
using
Scan
=
hipcub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#endif
Scan
(
smemFinal
.
histo
.
scan
).
ExclusiveSum
(
binCount
,
prefixSum
,
totalSum
);
Scan
(
smemFinal
.
histo
.
scan
).
ExclusiveSum
(
binCount
,
prefixSum
,
totalSum
);
// Update the histogram with the prefix sums.
// Update the histogram with the prefix sums.
...
@@ -334,13 +338,22 @@ static __device__ void topKPerRowJob(const int* indices, const float* logits,
...
@@ -334,13 +338,22 @@ static __device__ void topKPerRowJob(const int* indices, const float* logits,
static
constexpr
int
kNumFinalItemsPerThread
=
static
constexpr
int
kNumFinalItemsPerThread
=
kNumFinalItems
/
kNumThreadsPerBlock
;
kNumFinalItems
/
kNumThreadsPerBlock
;
// The class to sort the elements during the final pass.
// The class to sort the elements during the final pass.
#ifndef USE_ROCM
using
FinalSort
=
cub
::
BlockRadixSort
<
float
,
kNumThreadsPerBlock
,
using
FinalSort
=
cub
::
BlockRadixSort
<
float
,
kNumThreadsPerBlock
,
kNumFinalItemsPerThread
,
int
>
;
kNumFinalItemsPerThread
,
int
>
;
#else
using
FinalSort
=
hipcub
::
BlockRadixSort
<
float
,
kNumThreadsPerBlock
,
kNumFinalItemsPerThread
,
int
>
;
#endif
using
FinalSortTempStorage
=
using
FinalSortTempStorage
=
std
::
conditional_t
<
useRadixSort
,
typename
FinalSort
::
TempStorage
,
int
>
;
std
::
conditional_t
<
useRadixSort
,
typename
FinalSort
::
TempStorage
,
int
>
;
// The class to compute the inclusive prefix-sum over the histogram.
// The class to compute the inclusive prefix-sum over the histogram.
#ifndef USE_ROCM
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
using
Scan
=
cub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#else
using
Scan
=
hipcub
::
BlockScan
<
int
,
kNumThreadsPerBlock
>
;
#endif
// The structure to store the final items (for the final pass).
// The structure to store the final items (for the final pass).
struct
FinalItems
{
struct
FinalItems
{
// Shared memory to store the indices for the final pass.
// Shared memory to store the indices for the final pass.
...
...
docs/models/supported_models.md
View file @
35006c0f
...
@@ -717,6 +717,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
...
@@ -717,6 +717,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
`Qwen2VLForConditionalGeneration`
| QVQ, Qwen2-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/QVQ-72B-Preview`
,
`Qwen/Qwen2-VL-7B-Instruct`
,
`Qwen/Qwen2-VL-72B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen2VLForConditionalGeneration`
| QVQ, Qwen2-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/QVQ-72B-Preview`
,
`Qwen/Qwen2-VL-7B-Instruct`
,
`Qwen/Qwen2-VL-72B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen2_5_VLForConditionalGeneration`
| Qwen2.5-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen2.5-VL-3B-Instruct`
,
`Qwen/Qwen2.5-VL-72B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen2_5_VLForConditionalGeneration`
| Qwen2.5-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen2.5-VL-3B-Instruct`
,
`Qwen/Qwen2.5-VL-72B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-3B`
,
`Qwen/Qwen2.5-Omni-7B`
| ✅︎ | ✅︎ |
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-3B`
,
`Qwen/Qwen2.5-Omni-7B`
| ✅︎ | ✅︎ |
|
`Qwen3_5ForConditionalGeneration`
| Qwen3.5 | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3.5-9B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3_5MoeForConditionalGeneration`
| Qwen3.5-MOE | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3.5-35B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3VLForConditionalGeneration`
| Qwen3-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-4B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3VLForConditionalGeneration`
| Qwen3-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-4B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3VLMoeForConditionalGeneration`
| Qwen3-VL-MOE | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-30B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3VLMoeForConditionalGeneration`
| Qwen3-VL-MOE | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-30B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3OmniMoeThinkerForConditionalGeneration`
| Qwen3-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen3-Omni-30B-A3B-Instruct`
,
`Qwen/Qwen3-Omni-30B-A3B-Thinking`
| ✅︎ | ✅︎ |
|
`Qwen3OmniMoeThinkerForConditionalGeneration`
| Qwen3-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen3-Omni-30B-A3B-Instruct`
,
`Qwen/Qwen3-Omni-30B-A3B-Thinking`
| ✅︎ | ✅︎ |
...
...
tests/models/registry.py
View file @
35006c0f
...
@@ -943,6 +943,26 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -943,6 +943,26 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_model_len
=
4096
,
max_model_len
=
4096
,
min_transformers_version
=
"4.57"
,
min_transformers_version
=
"4.57"
,
),
),
"Qwen3_5ForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-9B-Instruct"
),
max_model_len
=
4096
,
min_transformers_version
=
"5.1.0"
,
),
"Qwen3_5MoeForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-35B-A3B-Instruct"
),
max_model_len
=
4096
,
min_transformers_version
=
"5.1.0"
,
),
"Qwen3_5MTP"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-9B-Instruct"
),
speculative_model
=
"Qwen/Qwen3.5-9B-Instruct"
,
min_transformers_version
=
"5.1.0"
,
),
"Qwen3_5MoeMTP"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3.5-35B-A3B-Instruct"
),
speculative_model
=
"Qwen/Qwen3.5-35B-A3B-Instruct"
,
min_transformers_version
=
"5.1.0"
,
),
"Qwen3OmniMoeForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen3OmniMoeForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-Omni-30B-A3B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-Omni-30B-A3B-Instruct"
),
max_model_len
=
4096
,
max_model_len
=
4096
,
...
...
vllm/config/model.py
View file @
35006c0f
...
@@ -1250,7 +1250,9 @@ class ModelConfig:
...
@@ -1250,7 +1250,9 @@ class ModelConfig:
return
sum
(
t
==
1
for
t
in
attn_type_list
[
start
:
end
])
return
sum
(
t
==
1
for
t
in
attn_type_list
[
start
:
end
])
# Hybrid model Qwen3Next
# Hybrid model Qwen3Next
layer_types_value
=
getattr
(
self
.
hf_config
,
"layer_types"
,
None
)
# layer_types_value = getattr(self.hf_config, "layer_types", None)
# Hybrid model Qwen3Next Qwen3.5 Series
layer_types_value
=
getattr
(
self
.
hf_text_config
,
"layer_types"
,
None
)
if
layer_types_value
is
not
None
:
if
layer_types_value
is
not
None
:
if
block_type
==
"attention"
:
if
block_type
==
"attention"
:
return
sum
(
return
sum
(
...
...
vllm/config/speculative.py
View file @
35006c0f
...
@@ -37,6 +37,7 @@ MTPModelTypes = Literal[
...
@@ -37,6 +37,7 @@ MTPModelTypes = Literal[
"ernie_mtp"
,
"ernie_mtp"
,
"exaone_moe_mtp"
,
"exaone_moe_mtp"
,
"qwen3_next_mtp"
,
"qwen3_next_mtp"
,
"qwen3_5_mtp"
,
"longcat_flash_mtp"
,
"longcat_flash_mtp"
,
"mtp"
,
"mtp"
,
"pangu_ultra_moe_mtp"
,
"pangu_ultra_moe_mtp"
,
...
@@ -246,6 +247,16 @@ class SpeculativeConfig:
...
@@ -246,6 +247,16 @@ class SpeculativeConfig:
{
"n_predict"
:
n_predict
,
"architectures"
:
[
"ExaoneMoeMTP"
]}
{
"n_predict"
:
n_predict
,
"architectures"
:
[
"ExaoneMoeMTP"
]}
)
)
if
hf_config
.
model_type
in
(
"qwen3_5"
,
"qwen3_5_moe"
):
is_moe
=
hf_config
.
model_type
==
"qwen3_5_moe"
hf_config
.
model_type
=
"qwen3_5_mtp"
n_predict
=
getattr
(
hf_config
,
"mtp_num_hidden_layers"
,
None
)
hf_config
.
update
(
{
"n_predict"
:
n_predict
,
"architectures"
:
[
"Qwen3_5MoeMTP"
if
is_moe
else
"Qwen3_5MTP"
],
}
)
if
hf_config
.
model_type
==
"longcat_flash"
:
if
hf_config
.
model_type
==
"longcat_flash"
:
hf_config
.
model_type
=
"longcat_flash_mtp"
hf_config
.
model_type
=
"longcat_flash_mtp"
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
1
)
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
1
)
...
...
vllm/model_executor/layers/mamba/abstract.py
View file @
35006c0f
...
@@ -43,7 +43,8 @@ class MambaBase(AttentionLayerBase):
...
@@ -43,7 +43,8 @@ class MambaBase(AttentionLayerBase):
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
|
None
:
def
get_kv_cache_spec
(
self
,
vllm_config
:
VllmConfig
)
->
KVCacheSpec
|
None
:
if
(
if
(
vllm_config
.
speculative_config
is
not
None
vllm_config
.
speculative_config
is
not
None
and
vllm_config
.
model_config
.
hf_config
.
model_type
not
in
[
"qwen3_next"
]
and
vllm_config
.
model_config
.
hf_config
.
model_type
not
in
[
"qwen3_next"
,
"qwen3_5"
,
"qwen3_5_moe"
]
):
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Mamba with speculative decoding is not supported yet."
"Mamba with speculative decoding is not supported yet."
...
...
vllm/model_executor/layers/mamba/mamba_utils.py
View file @
35006c0f
...
@@ -80,9 +80,13 @@ class MambaStateDtypeCalculator:
...
@@ -80,9 +80,13 @@ class MambaStateDtypeCalculator:
cls
,
cls
,
model_dtype
:
ModelDType
|
torch
.
dtype
,
model_dtype
:
ModelDType
|
torch
.
dtype
,
mamba_cache_dtype
:
MambaDType
,
mamba_cache_dtype
:
MambaDType
,
mamba_ssm_cache_dtype
:
MambaDType
=
"auto"
,
)
->
tuple
[
torch
.
dtype
,
torch
.
dtype
]:
)
->
tuple
[
torch
.
dtype
,
torch
.
dtype
]:
state_dtype
=
get_kv_cache_torch_dtype
(
mamba_cache_dtype
,
model_dtype
)
# state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return
(
state_dtype
,
state_dtype
)
# return (state_dtype, state_dtype)
return
cls
.
_mamba_state_dtype
(
model_dtype
,
mamba_cache_dtype
,
mamba_ssm_cache_dtype
)
@
classmethod
@
classmethod
def
kda_state_dtype
(
def
kda_state_dtype
(
...
...
vllm/model_executor/models/config.py
View file @
35006c0f
...
@@ -581,6 +581,31 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
...
@@ -581,6 +581,31 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
)
)
cache_config
.
mamba_ssm_cache_dtype
=
mamba_ssm_cache_dtype
cache_config
.
mamba_ssm_cache_dtype
=
mamba_ssm_cache_dtype
class
Qwen3_5ForConditionalGenerationConfig
(
VerifyAndUpdateConfig
):
@
staticmethod
def
verify_and_update_config
(
vllm_config
:
"VllmConfig"
)
->
None
:
"""Update mamba_ssm_cache_dtype for Qwen3.5 models when set to 'auto'
(or not explicitly set), to the value specified in the HF config's
mamba_ssm_dtype field. Warn if the user explicitly overrides it to a
different value.
"""
cache_config
=
vllm_config
.
cache_config
hf_text_config
=
vllm_config
.
model_config
.
hf_text_config
mamba_ssm_dtype
=
getattr
(
hf_text_config
,
"mamba_ssm_dtype"
,
None
)
if
cache_config
.
mamba_ssm_cache_dtype
==
"auto"
:
if
mamba_ssm_dtype
is
not
None
:
cache_config
.
mamba_ssm_cache_dtype
=
mamba_ssm_dtype
elif
(
mamba_ssm_dtype
is
not
None
and
cache_config
.
mamba_ssm_cache_dtype
!=
mamba_ssm_dtype
):
logger
.
warning
(
"Qwen3.5 model specifies mamba_ssm_dtype='%s' in its config, "
"but --mamba-ssm-cache-dtype='%s' was passed. "
"Using the user-specified value."
,
mamba_ssm_dtype
,
cache_config
.
mamba_ssm_cache_dtype
,
)
MODELS_CONFIG_MAP
:
dict
[
str
,
type
[
VerifyAndUpdateConfig
]]
=
{
MODELS_CONFIG_MAP
:
dict
[
str
,
type
[
VerifyAndUpdateConfig
]]
=
{
"GteModel"
:
SnowflakeGteNewModelConfig
,
"GteModel"
:
SnowflakeGteNewModelConfig
,
...
@@ -603,4 +628,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
...
@@ -603,4 +628,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"FalconMambaForCausalLM"
:
MambaModelConfig
,
"FalconMambaForCausalLM"
:
MambaModelConfig
,
"DeepseekV32ForCausalLM"
:
DeepseekV32ForCausalLM
,
"DeepseekV32ForCausalLM"
:
DeepseekV32ForCausalLM
,
"NemotronHForCausalLM"
:
NemotronHForCausalLMConfig
,
"NemotronHForCausalLM"
:
NemotronHForCausalLMConfig
,
"Qwen3_5ForConditionalGeneration"
:
Qwen3_5ForConditionalGenerationConfig
,
"Qwen3_5MoeForConditionalGeneration"
:
Qwen3_5ForConditionalGenerationConfig
,
}
}
vllm/model_executor/models/qwen3_5.py
0 → 100644
View file @
35006c0f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3.5 Series compatible with HuggingFace weights."""
import
typing
from
collections.abc
import
Callable
,
Iterable
import
torch
from
einops
import
rearrange
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
transformers.models.qwen3_5.configuration_qwen3_5
import
(
Qwen3_5Config
,
Qwen3_5TextConfig
,
)
from
transformers.models.qwen3_5_moe.configuration_qwen3_5_moe
import
(
Qwen3_5MoeConfig
,
Qwen3_5MoeTextConfig
,
)
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SpeculativeConfig
,
VllmConfig
,
get_current_vllm_config
,
)
from
vllm.distributed
import
(
divide
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
(
GemmaRMSNorm
as
Qwen3_5RMSNorm
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
mamba_v2_sharded_weight_loader
,
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateCopyFunc
,
MambaStateCopyFuncCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
sharded_weight_loader
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
MixtureOfExperts
,
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsPP
,
_require_is_multimodal
,
)
from
.qwen2_moe
import
Qwen2MoeMLP
as
Qwen3NextMLP
from
.qwen3_next
import
(
Qwen3NextAttention
,
Qwen3NextDecoderLayer
,
Qwen3NextGatedDeltaNet
,
Qwen3NextModel
,
Qwen3NextSparseMoeBlock
,
QwenNextMixtureOfExperts
,
)
from
.qwen3_vl
import
(
Qwen3_VisionTransformer
,
Qwen3VLDummyInputsBuilder
,
Qwen3VLForConditionalGeneration
,
Qwen3VLMultiModalProcessor
,
Qwen3VLProcessingInfo
,
)
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
_merge_multimodal_embeddings
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
)
logger
=
init_logger
(
__name__
)
class
Qwen3_5ProcessingInfo
(
Qwen3VLProcessingInfo
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
(
Qwen3_5Config
)
class
Qwen3_5MoeProcessingInfo
(
Qwen3VLProcessingInfo
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
(
Qwen3_5MoeConfig
)
class
Qwen3_5GatedDeltaNet
(
Qwen3NextGatedDeltaNet
):
def
__init__
(
self
,
config
:
Qwen3_5TextConfig
|
Qwen3_5MoeTextConfig
,
model_config
:
ModelConfig
|
None
=
None
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
speculative_config
:
SpeculativeConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
(
Qwen3NextGatedDeltaNet
,
self
).
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_v_heads
=
config
.
linear_num_value_heads
self
.
num_k_heads
=
config
.
linear_num_key_heads
self
.
head_k_dim
=
config
.
linear_key_head_dim
self
.
head_v_dim
=
config
.
linear_value_head_dim
self
.
key_dim
=
self
.
head_k_dim
*
self
.
num_k_heads
self
.
value_dim
=
self
.
head_v_dim
*
self
.
num_v_heads
self
.
conv_kernel_size
=
config
.
linear_conv_kernel_dim
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
activation
=
config
.
hidden_act
self
.
act
=
ACT2FN
[
config
.
hidden_act
]
self
.
layer_norm_epsilon
=
config
.
rms_norm_eps
self
.
prefix
=
prefix
self
.
config
=
config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
speculative_config
=
speculative_config
self
.
num_spec
=
(
self
.
speculative_config
.
num_speculative_tokens
if
self
.
speculative_config
else
0
)
# QKV
self
.
conv_dim
=
self
.
key_dim
*
2
+
self
.
value_dim
self
.
conv1d
=
ColumnParallelLinear
(
input_size
=
self
.
conv_kernel_size
,
output_size
=
self
.
conv_dim
,
bias
=
False
,
prefix
=
f
"
{
prefix
}
.conv1d"
,
)
self
.
conv1d
.
weight
.
data
=
self
.
conv1d
.
weight
.
data
.
unsqueeze
(
1
)
self
.
in_proj_qkv
=
MergedColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_sizes
=
[
self
.
key_dim
,
self
.
key_dim
,
self
.
value_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_qkv"
,
)
self
.
in_proj_z
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
value_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_z"
,
)
self
.
in_proj_b
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_ba"
,
)
self
.
in_proj_a
=
ColumnParallelLinear
(
input_size
=
self
.
hidden_size
,
output_size
=
self
.
num_v_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.in_proj_a"
,
)
query_key_settings
=
(
self
.
key_dim
,
0
,
False
)
value_settings
=
(
self
.
value_dim
,
0
,
False
)
delattr
(
self
.
conv1d
.
weight
,
"weight_loader"
)
set_weight_attrs
(
self
.
conv1d
.
weight
,
{
"weight_loader"
:
mamba_v2_sharded_weight_loader
(
[
query_key_settings
,
query_key_settings
,
value_settings
,
],
self
.
tp_size
,
self
.
tp_rank
,
)
},
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self
.
dt_bias
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_v_heads
//
self
.
tp_size
),
)
self
.
A_log
=
nn
.
Parameter
(
torch
.
empty
(
divide
(
self
.
num_v_heads
,
self
.
tp_size
),
)
)
set_weight_attrs
(
self
.
A_log
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
set_weight_attrs
(
self
.
dt_bias
,
{
"weight_loader"
:
sharded_weight_loader
(
0
)})
self
.
norm
=
RMSNormGated
(
self
.
head_v_dim
,
eps
=
self
.
layer_norm_epsilon
,
group_size
=
None
,
norm_before_gate
=
True
,
device
=
current_platform
.
current_device
(),
dtype
=
config
.
dtype
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
value_dim
,
self
.
hidden_size
,
bias
=
False
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
def
fix_query_key_value_ordering
(
self
,
mixed_qkv
,
z
,
b
,
a
,
):
raise
NotImplementedError
(
"Qwen3.5 Series dont need to fix query key value ordering"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
):
"""
Forward pass with three parts:
1. Input projection
2. Core attention (custom op)
3. Output projection
"""
num_tokens
=
hidden_states
.
size
(
0
)
# ============================================================
# Part 1: Input Projection
# ============================================================
mixed_qkv
,
_
=
self
.
in_proj_qkv
(
hidden_states
)
z
,
_
=
self
.
in_proj_z
(
hidden_states
)
z
=
z
.
reshape
(
z
.
size
(
0
),
-
1
,
self
.
head_v_dim
)
b
,
_
=
self
.
in_proj_b
(
hidden_states
)
a
,
_
=
self
.
in_proj_a
(
hidden_states
)
b
=
b
.
contiguous
()
a
=
a
.
contiguous
()
# ============================================================
# Part 2: Core Attention (Custom Op)
# ============================================================
# Note: we should not use torch.empty here like other attention backends,
# see discussions in https://github.com/vllm-project/vllm/pull/28182
core_attn_out
=
torch
.
zeros
(
(
num_tokens
,
self
.
num_v_heads
//
self
.
tp_size
,
self
.
head_v_dim
),
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
)
torch
.
ops
.
vllm
.
gdn_attention_core
(
mixed_qkv
,
b
,
a
,
core_attn_out
,
self
.
prefix
,
)
# ============================================================
# Part 3: Output Projection
# ============================================================
z_shape_og
=
z
.
shape
# Reshape input data into 2D tensor
core_attn_out
=
core_attn_out
.
reshape
(
-
1
,
core_attn_out
.
shape
[
-
1
])
z
=
z
.
reshape
(
-
1
,
z
.
shape
[
-
1
])
core_attn_out
=
self
.
norm
(
core_attn_out
,
z
)
core_attn_out
=
core_attn_out
.
reshape
(
z_shape_og
)
core_attn_out
=
rearrange
(
core_attn_out
,
"... h d -> ... (h d)"
)
output
[:
num_tokens
],
_
=
self
.
out_proj
(
core_attn_out
)
class
Qwen3_5DecoderLayer
(
Qwen3NextDecoderLayer
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
layer_type
:
str
,
prefix
:
str
=
""
,
)
->
None
:
super
(
Qwen3NextDecoderLayer
,
self
).
__init__
()
config
=
vllm_config
.
model_config
.
hf_text_config
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
speculative_config
=
vllm_config
.
speculative_config
self
.
layer_type
=
layer_type
self
.
layer_idx
=
extract_layer_index
(
prefix
)
if
self
.
layer_type
==
"linear_attention"
:
self
.
linear_attn
=
Qwen3_5GatedDeltaNet
(
config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
speculative_config
=
speculative_config
,
prefix
=
f
"
{
prefix
}
.linear_attn"
,
)
elif
self
.
layer_type
==
"full_attention"
:
self
.
self_attn
=
Qwen3NextAttention
(
config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
else
:
raise
ValueError
(
f
"Invalid layer_type
{
self
.
layer_type
}
"
)
# NOTE: Determine the MLP type based on the model type
# Qwen3.5 use all layers for MLP / Qwen3.5-MoE use sparse MoE blocks
if
config
.
model_type
==
"qwen3_5_moe_text"
:
self
.
mlp
=
Qwen3NextSparseMoeBlock
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
elif
config
.
model_type
==
"qwen3_5_text"
:
self
.
mlp
=
Qwen3NextMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
else
:
raise
ValueError
(
f
"Invalid model_type
{
config
.
model_type
}
"
)
self
.
input_layernorm
=
Qwen3_5RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
Qwen3_5RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
layer_scale
=
getattr
(
config
,
"layer_scale"
,
False
)
if
self
.
layer_scale
:
self
.
attn_layer_scale
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
,
1
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
)
self
.
ffn_layer_scale
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
,
1
,
config
.
hidden_size
,
dtype
=
config
.
dtype
,
),
)
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions"
:
-
1
,
"intermediate_tensors"
:
0
,
"inputs_embeds"
:
0
,
}
)
class
Qwen3_5Model
(
Qwen3NextModel
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
(
Qwen3NextModel
,
self
).
__init__
()
config
:
Qwen3_5TextConfig
|
Qwen3_5MoeTextConfig
=
(
vllm_config
.
model_config
.
hf_text_config
)
parallel_config
=
vllm_config
.
parallel_config
eplb_config
=
parallel_config
.
eplb_config
self
.
num_redundant_experts
=
eplb_config
.
num_redundant_experts
self
.
config
=
config
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
vocab_size
,
config
.
hidden_size
,
)
def
get_layer
(
prefix
:
str
):
return
Qwen3_5DecoderLayer
(
vllm_config
,
layer_type
=
config
.
layer_types
[
extract_layer_index
(
prefix
)],
prefix
=
prefix
,
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
get_layer
,
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
make_empty_intermediate_tensors
=
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
Qwen3_5RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
def
load_fused_expert_weights
(
self
,
name
:
str
,
params_dict
:
dict
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
str
,
num_experts
:
int
,
)
->
bool
:
param
=
params_dict
[
name
]
weight_loader
=
typing
.
cast
(
Callable
[...,
bool
],
param
.
weight_loader
)
loaded_local_expert
=
False
for
expert_id
in
range
(
num_experts
):
curr_expert_weight
=
loaded_weight
[
expert_id
]
success
=
weight_loader
(
param
,
curr_expert_weight
,
name
,
shard_id
,
expert_id
,
return_success
=
True
,
)
if
success
:
loaded_local_expert
=
True
return
loaded_local_expert
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
expert_params_mapping
=
self
.
get_expert_mapping
()
is_fused_expert
=
False
fused_expert_params_mapping
=
[
(
"experts.w13_weight"
,
"experts.gate_up_proj"
,
0
,
"w1"
),
(
"experts.w2_weight"
,
"experts.down_proj"
,
0
,
"w2"
),
]
num_experts
=
(
self
.
config
.
num_experts
if
hasattr
(
self
.
config
,
"num_experts"
)
else
0
)
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
name
.
startswith
(
"mtp."
):
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
"experts.gate_up_proj"
in
name
or
"experts.down_proj"
in
name
:
is_fused_expert
=
True
expert_params_mapping
=
fused_expert_params_mapping
if
weight_name
not
in
name
:
continue
if
"mlp.experts"
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
# name = apply_attn_prefix(name, params_dict)
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
is_expert_weight
=
False
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
is_expert_weight
=
True
name_mapped
=
name
.
replace
(
weight_name
,
param_name
)
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name_mapped
,
self
):
continue
if
is_fused_expert
:
# qwen3.5 no need to transpose
# loaded_weight = loaded_weight.transpose(-1, -2)
if
"experts.gate_up_proj"
in
name
:
loaded_weight
=
loaded_weight
.
chunk
(
2
,
dim
=-
2
)
success_w1
=
self
.
load_fused_expert_weights
(
name_mapped
,
params_dict
,
loaded_weight
[
0
],
"w1"
,
num_experts
,
)
success_w3
=
self
.
load_fused_expert_weights
(
name_mapped
,
params_dict
,
loaded_weight
[
1
],
"w3"
,
num_experts
,
)
success
=
success_w1
and
success_w3
else
:
# down_proj
success
=
self
.
load_fused_expert_weights
(
name_mapped
,
params_dict
,
loaded_weight
,
shard_id
,
num_experts
,
)
if
success
:
name
=
name_mapped
break
else
:
# Skip loading extra bias for GPTQ models.
if
(
name_mapped
.
endswith
(
".bias"
)
or
name_mapped
.
endswith
(
"_bias"
)
)
and
name_mapped
not
in
params_dict
:
continue
param
=
params_dict
[
name_mapped
]
weight_loader
=
param
.
weight_loader
success
=
weight_loader
(
param
,
loaded_weight
,
name_mapped
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
return_success
=
True
,
)
if
success
:
name
=
name_mapped
break
else
:
if
is_expert_weight
:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
name
not
in
params_dict
:
logger
.
warning_once
(
f
"Parameter
{
name
}
not found in params_dict, skip loading"
)
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
Qwen3_5ForCausalLMBase
(
nn
.
Module
,
HasInnerState
,
SupportsLoRA
,
SupportsPP
,
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_text_config
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
scheduler_config
=
vllm_config
.
scheduler_config
if
cache_config
.
mamba_cache_mode
==
"all"
:
raise
NotImplementedError
(
"Qwen3.5 currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
self
.
quant_config
=
vllm_config
.
quant_config
super
().
__init__
()
self
.
config
=
config
self
.
scheduler_config
=
scheduler_config
self
.
model
=
Qwen3_5Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
if
get_pp_group
().
is_last_rank
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_input_ids
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
:
object
,
):
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
return
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
[
"mtp."
],
)
return
loader
.
load_weights
(
weights
)
class
Qwen3_5ForCausalLM
(
Qwen3_5ForCausalLMBase
):
pass
class
Qwen3_5MoeForCausalLM
(
Qwen3_5ForCausalLMBase
,
QwenNextMixtureOfExperts
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
# set MoE hyperparameters
self
.
set_moe_parameters
()
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
########################################################
# Qwen3_5-Dense
########################################################
@
MULTIMODAL_REGISTRY
.
register_processor
(
Qwen3VLMultiModalProcessor
,
info
=
Qwen3_5ProcessingInfo
,
dummy_inputs
=
Qwen3VLDummyInputsBuilder
,
)
class
Qwen3_5ForConditionalGeneration
(
Qwen3VLForConditionalGeneration
,
IsHybrid
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn
.
Module
.
__init__
(
self
)
config
:
Qwen3_5Config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
video_pruning_rate
=
multimodal_config
.
video_pruning_rate
self
.
is_multimodal_pruning_enabled
=
(
multimodal_config
.
is_multimodal_pruning_enabled
()
)
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
self
.
visual
=
Qwen3_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
Qwen3_5ForCausalLM
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
)
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
self
.
language_model
.
embed_input_ids
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
return
inputs_embeds
is_multimodal
=
_require_is_multimodal
(
is_multimodal
)
inputs_embeds
=
_merge_multimodal_embeddings
(
inputs_embeds
=
inputs_embeds
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
)
return
inputs_embeds
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
:
object
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
"""Run forward pass for Qwen3.5.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen3VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
intermediate_tensors: Intermediate tensors from previous pipeline
stages.
inputs_embeds: Pre-computed input embeddings.
**kwargs: Additional keyword arguments including:
- pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
- image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in
LLM. `None` if no images are passed.
- pixel_values_videos: Pixel values of videos to be fed to a
model. `None` if no videos are passed.
- video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in
LLM. `None` if no videos are passed.
"""
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
[
"mtp."
],
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
@
classmethod
def
get_mamba_state_dtype_from_config
(
cls
,
vllm_config
:
"VllmConfig"
,
)
->
tuple
[
torch
.
dtype
,
torch
.
dtype
]:
return
MambaStateDtypeCalculator
.
gated_delta_net_state_dtype
(
vllm_config
.
model_config
.
dtype
,
vllm_config
.
cache_config
.
mamba_cache_dtype
)
@
classmethod
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
"VllmConfig"
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
parallel_config
=
vllm_config
.
parallel_config
hf_config
=
vllm_config
.
model_config
.
hf_text_config
tp_size
=
parallel_config
.
tensor_parallel_size
num_spec
=
(
vllm_config
.
speculative_config
.
num_speculative_tokens
if
vllm_config
.
speculative_config
else
0
)
return
MambaStateShapeCalculator
.
gated_delta_net_state_shape
(
tp_size
,
hf_config
.
linear_num_key_heads
,
hf_config
.
linear_num_value_heads
,
hf_config
.
linear_key_head_dim
,
hf_config
.
linear_value_head_dim
,
hf_config
.
linear_conv_kernel_dim
,
num_spec
,
)
@
classmethod
def
get_mamba_state_copy_func
(
cls
)
->
tuple
[
MambaStateCopyFunc
,
MambaStateCopyFunc
]:
return
MambaStateCopyFuncCalculator
.
gated_delta_net_state_copy_func
()
########################################################
# Qwen3_5-MoE
########################################################
class
Qwen3_5_MoeMixtureOfExperts
(
MixtureOfExperts
):
def
update_physical_experts_metadata
(
self
,
num_physical_experts
:
int
,
num_local_physical_experts
:
int
,
)
->
None
:
assert
self
.
num_local_physical_experts
==
num_local_physical_experts
self
.
num_physical_experts
=
num_physical_experts
self
.
num_local_physical_experts
=
num_local_physical_experts
self
.
num_redundant_experts
=
num_physical_experts
-
self
.
num_logical_experts
for
layer
in
self
.
language_model
.
model
.
layers
:
if
isinstance
(
layer
.
mlp
,
Qwen3NextSparseMoeBlock
):
moe
=
layer
.
mlp
moe
.
n_local_physical_experts
=
num_local_physical_experts
moe
.
n_physical_experts
=
num_physical_experts
moe
.
n_redundant_experts
=
self
.
num_redundant_experts
moe
.
experts
.
update_expert_map
()
def
set_moe_parameters
(
self
):
self
.
expert_weights
=
[]
self
.
moe_layers
=
[]
example_moe
=
None
for
layer
in
self
.
language_model
.
model
.
layers
:
if
isinstance
(
layer
,
Qwen3_5DecoderLayer
)
and
isinstance
(
layer
.
mlp
,
Qwen3NextSparseMoeBlock
):
example_moe
=
layer
.
mlp
self
.
moe_layers
.
append
(
layer
.
mlp
.
experts
)
if
example_moe
is
None
:
raise
RuntimeError
(
"No Qwen3_5 layer found in the language_model.model.layers."
)
# Set MoE hyperparameters
self
.
num_moe_layers
=
len
(
self
.
moe_layers
)
self
.
num_expert_groups
=
1
self
.
num_shared_experts
=
0
self
.
num_logical_experts
=
example_moe
.
n_logical_experts
self
.
num_physical_experts
=
example_moe
.
n_physical_experts
self
.
num_local_physical_experts
=
example_moe
.
n_local_physical_experts
self
.
num_routed_experts
=
example_moe
.
n_routed_experts
self
.
num_redundant_experts
=
example_moe
.
n_redundant_experts
@
MULTIMODAL_REGISTRY
.
register_processor
(
Qwen3VLMultiModalProcessor
,
info
=
Qwen3_5MoeProcessingInfo
,
dummy_inputs
=
Qwen3VLDummyInputsBuilder
,
)
class
Qwen3_5MoeForConditionalGeneration
(
Qwen3_5ForConditionalGeneration
,
Qwen3_5_MoeMixtureOfExperts
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
# protocols have not __init__ method, so we need to use nn.Module.__init__
nn
.
Module
.
__init__
(
self
)
config
:
Qwen3_5MoeConfig
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
video_pruning_rate
=
multimodal_config
.
video_pruning_rate
self
.
is_multimodal_pruning_enabled
=
(
multimodal_config
.
is_multimodal_pruning_enabled
()
)
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
self
.
visual
=
Qwen3_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
Qwen3_5MoeForCausalLM
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
)
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
# set MoE hyperparameters
self
.
set_moe_parameters
()
\ No newline at end of file
vllm/model_executor/models/qwen3_5_mtp.py
0 → 100644
View file @
35006c0f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3_5 MTP model."""
import
typing
from
collections.abc
import
Callable
,
Iterable
import
torch
from
torch
import
nn
from
transformers.models.qwen3_5.configuration_qwen3_5
import
Qwen3_5TextConfig
from
transformers.models.qwen3_5_moe.configuration_qwen3_5_moe
import
(
Qwen3_5MoeTextConfig
,
)
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.qwen3_5
import
Qwen3_5DecoderLayer
,
Qwen3_5RMSNorm
from
vllm.model_executor.models.qwen3_next
import
QwenNextMixtureOfExperts
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
_require_is_multimodal
,
)
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
_merge_multimodal_embeddings
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
maybe_prefix
,
)
logger
=
init_logger
(
__name__
)
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions"
:
-
1
,
"intermediate_tensors"
:
0
,
"inputs_embeds"
:
0
,
"hidden_states"
:
0
,
}
)
class
Qwen3_5MultiTokenPredictor
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
model_config
=
vllm_config
.
model_config
quant_config
=
vllm_config
.
quant_config
config
:
Qwen3_5TextConfig
|
Qwen3_5MoeTextConfig
=
model_config
.
hf_text_config
self
.
config
=
config
self
.
vocab_size
=
config
.
vocab_size
self
.
mtp_start_layer_idx
=
config
.
num_hidden_layers
self
.
num_mtp_layers
=
getattr
(
config
,
"mtp_num_hidden_layers"
,
1
)
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
vocab_size
,
config
.
hidden_size
,
)
self
.
fc
=
ColumnParallelLinear
(
self
.
config
.
hidden_size
*
2
,
self
.
config
.
hidden_size
,
gather_output
=
True
,
bias
=
False
,
return_bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc"
,
)
self
.
layers
=
torch
.
nn
.
ModuleList
(
Qwen3_5DecoderLayer
(
vllm_config
,
layer_type
=
"full_attention"
,
prefix
=
f
"
{
prefix
}
.layers.
{
idx
}
"
,
)
for
idx
in
range
(
self
.
num_mtp_layers
)
)
self
.
make_empty_intermediate_tensors
=
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
)
self
.
norm
=
Qwen3_5RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
pre_fc_norm_hidden
=
Qwen3_5RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
pre_fc_norm_embedding
=
Qwen3_5RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_input_ids
(
input_ids
)
assert
hidden_states
.
shape
[
-
1
]
==
inputs_embeds
.
shape
[
-
1
]
inputs_embeds
=
self
.
pre_fc_norm_embedding
(
inputs_embeds
)
hidden_states
=
self
.
pre_fc_norm_hidden
(
hidden_states
)
hidden_states
=
torch
.
cat
([
inputs_embeds
,
hidden_states
],
dim
=-
1
)
hidden_states
=
self
.
fc
(
hidden_states
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
hidden_states
,
residual
=
self
.
layers
[
current_step_idx
](
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
(
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
load_fused_expert_weights
(
self
,
name
:
str
,
params_dict
:
dict
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
str
,
num_experts
:
int
,
)
->
bool
:
param
=
params_dict
[
name
]
weight_loader
=
typing
.
cast
(
Callable
[...,
bool
],
param
.
weight_loader
)
loaded_local_expert
=
False
for
expert_id
in
range
(
num_experts
):
curr_expert_weight
=
loaded_weight
[
expert_id
]
success
=
weight_loader
(
param
,
curr_expert_weight
,
name
,
shard_id
,
expert_id
,
return_success
=
True
,
)
if
success
:
loaded_local_expert
=
True
return
loaded_local_expert
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
self
,
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
if
hasattr
(
self
.
config
,
"num_experts"
)
else
0
,
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
is_fused_expert
=
False
fused_expert_params_mapping
=
[
(
"experts.w13_weight"
,
"experts.gate_up_proj"
,
0
,
"w1"
),
(
"experts.w2_weight"
,
"experts.down_proj"
,
0
,
"w2"
),
]
num_experts
=
(
self
.
config
.
num_experts
if
hasattr
(
self
.
config
,
"num_experts"
)
else
0
)
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
"experts.gate_up_proj"
in
name
or
"experts.down_proj"
in
name
:
is_fused_expert
=
True
expert_params_mapping
=
fused_expert_params_mapping
if
weight_name
not
in
name
:
continue
if
"mlp.experts"
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
is_expert_weight
=
False
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
is_expert_weight
=
True
name_mapped
=
name
.
replace
(
weight_name
,
param_name
)
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name_mapped
,
self
):
continue
if
is_fused_expert
:
# qwen3.5 no need to transpose
# loaded_weight = loaded_weight.transpose(-1, -2)
if
"experts.gate_up_proj"
in
name
:
loaded_weight
=
loaded_weight
.
chunk
(
2
,
dim
=-
2
)
success_w1
=
self
.
load_fused_expert_weights
(
name_mapped
,
params_dict
,
loaded_weight
[
0
],
"w1"
,
num_experts
,
)
success_w3
=
self
.
load_fused_expert_weights
(
name_mapped
,
params_dict
,
loaded_weight
[
1
],
"w3"
,
num_experts
,
)
success
=
success_w1
and
success_w3
else
:
# down_proj
success
=
self
.
load_fused_expert_weights
(
name_mapped
,
params_dict
,
loaded_weight
,
shard_id
,
num_experts
,
)
if
success
:
name
=
name_mapped
break
else
:
# Skip loading extra bias for GPTQ models.
if
(
name_mapped
.
endswith
(
".bias"
)
or
name_mapped
.
endswith
(
"_bias"
)
)
and
name_mapped
not
in
params_dict
:
continue
param
=
params_dict
[
name_mapped
]
weight_loader
=
param
.
weight_loader
success
=
weight_loader
(
param
,
loaded_weight
,
name_mapped
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
return_success
=
True
,
)
if
success
:
name
=
name_mapped
break
else
:
if
is_expert_weight
:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
name
not
in
params_dict
:
logger
.
warning_once
(
f
"Parameter
{
name
}
not found in params_dict, skip loading"
)
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions"
:
-
1
,
"intermediate_tensors"
:
0
,
"inputs_embeds"
:
0
,
"hidden_states"
:
0
,
}
)
class
Qwen3_5MTP
(
nn
.
Module
,
SupportsMultiModal
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"up_proj"
,
"down_proj"
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_text_config
self
.
vllm_config
=
vllm_config
cache_config
=
vllm_config
.
cache_config
if
cache_config
.
mamba_cache_mode
==
"all"
:
raise
NotImplementedError
(
"Qwen3_5MTP currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
self
.
quant_config
=
vllm_config
.
quant_config
super
().
__init__
()
self
.
config
=
config
self
.
model
=
Qwen3_5MultiTokenPredictor
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"mtp"
)
)
if
get_pp_group
().
is_last_rank
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
self
.
model
.
embed_input_ids
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
return
inputs_embeds
is_multimodal
=
_require_is_multimodal
(
is_multimodal
)
inputs_embeds
=
_merge_multimodal_embeddings
(
inputs_embeds
=
inputs_embeds
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
)
return
inputs_embeds
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
:
object
,
):
hidden_states
=
self
.
model
(
input_ids
,
positions
,
hidden_states
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
|
None
:
return
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def
remap_weight_names
(
weights
):
for
name
,
weight
in
weights
:
if
name
.
startswith
(
"mtp."
):
name
=
name
.
replace
(
"mtp."
,
"model."
)
elif
any
(
key
in
name
for
key
in
[
"embed_tokens"
,
"lm_head"
]):
if
"embed_tokens"
in
name
:
name
=
name
.
replace
(
"language_model."
,
""
)
else
:
continue
yield
name
,
weight
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
remap_weight_names
(
weights
))
class
Qwen3_5MoeMTP
(
Qwen3_5MTP
,
QwenNextMixtureOfExperts
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
self
.
set_moe_parameters
()
\ No newline at end of file
vllm/model_executor/models/qwen3_next.py
View file @
35006c0f
...
@@ -106,7 +106,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
...
@@ -106,7 +106,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_
text_
config
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
...
@@ -177,7 +177,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
...
@@ -177,7 +177,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
renormalize
=
getattr
(
config
,
"
norm_topk_prob
"
,
True
)
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
prefix
=
f
"
{
prefix
}
.experts"
,
enable_eplb
=
self
.
enable_eplb
,
enable_eplb
=
self
.
enable_eplb
,
...
@@ -970,7 +970,7 @@ class Qwen3NextModel(nn.Module):
...
@@ -970,7 +970,7 @@ class Qwen3NextModel(nn.Module):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
:
Qwen3NextConfig
=
vllm_config
.
model_config
.
hf_config
config
:
Qwen3NextConfig
=
vllm_config
.
model_config
.
hf_
text_
config
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
eplb_config
=
parallel_config
.
eplb_config
eplb_config
=
parallel_config
.
eplb_config
...
@@ -1047,7 +1047,7 @@ class Qwen3NextModel(nn.Module):
...
@@ -1047,7 +1047,7 @@ class Qwen3NextModel(nn.Module):
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
,
num_experts
=
getattr
(
self
.
config
,
"
num_experts
"
,
0
)
,
num_redundant_experts
=
self
.
num_redundant_experts
,
num_redundant_experts
=
self
.
num_redundant_experts
,
)
)
...
@@ -1206,7 +1206,7 @@ class Qwen3NextForCausalLM(
...
@@ -1206,7 +1206,7 @@ class Qwen3NextForCausalLM(
}
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_
text_
config
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
...
@@ -1270,7 +1270,7 @@ class Qwen3NextForCausalLM(
...
@@ -1270,7 +1270,7 @@ class Qwen3NextForCausalLM(
cls
,
vllm_config
:
"VllmConfig"
cls
,
vllm_config
:
"VllmConfig"
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
)
->
tuple
[
tuple
[
int
,
int
],
tuple
[
int
,
int
]]:
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
hf_config
=
vllm_config
.
model_config
.
hf_config
hf_config
=
vllm_config
.
model_config
.
hf_
text_
config
tp_size
=
parallel_config
.
tensor_parallel_size
tp_size
=
parallel_config
.
tensor_parallel_size
num_spec
=
(
num_spec
=
(
vllm_config
.
speculative_config
.
num_speculative_tokens
vllm_config
.
speculative_config
.
num_speculative_tokens
...
...
vllm/model_executor/models/registry.py
View file @
35006c0f
...
@@ -438,6 +438,14 @@ _MULTIMODAL_MODELS = {
...
@@ -438,6 +438,14 @@ _MULTIMODAL_MODELS = {
"qwen3_vl_moe"
,
"qwen3_vl_moe"
,
"Qwen3VLMoeForConditionalGeneration"
,
"Qwen3VLMoeForConditionalGeneration"
,
),
),
"Qwen3_5ForConditionalGeneration"
:
(
"qwen3_5"
,
"Qwen3_5ForConditionalGeneration"
,
),
"Qwen3_5MoeForConditionalGeneration"
:
(
"qwen3_5"
,
"Qwen3_5MoeForConditionalGeneration"
,
),
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
"Step3VLForConditionalGeneration"
:
(
"step3_vl"
,
"Step3VLForConditionalGeneration"
),
# noqa: E501
"Step3VLForConditionalGeneration"
:
(
"step3_vl"
,
"Step3VLForConditionalGeneration"
),
# noqa: E501
"TarsierForConditionalGeneration"
:
(
"tarsier"
,
"TarsierForConditionalGeneration"
),
# noqa: E501
"TarsierForConditionalGeneration"
:
(
"tarsier"
,
"TarsierForConditionalGeneration"
),
# noqa: E501
...
@@ -480,6 +488,8 @@ _SPECULATIVE_DECODING_MODELS = {
...
@@ -480,6 +488,8 @@ _SPECULATIVE_DECODING_MODELS = {
"OpenPanguMTPModel"
:
(
"openpangu_mtp"
,
"OpenPanguMTP"
),
"OpenPanguMTPModel"
:
(
"openpangu_mtp"
,
"OpenPanguMTP"
),
"Qwen3NextMTP"
:
(
"qwen3_next_mtp"
,
"Qwen3NextMTP"
),
"Qwen3NextMTP"
:
(
"qwen3_next_mtp"
,
"Qwen3NextMTP"
),
"Step3p5MTP"
:
(
"step3p5_mtp"
,
"Step3p5MTP"
),
"Step3p5MTP"
:
(
"step3p5_mtp"
,
"Step3p5MTP"
),
"Qwen3_5MTP"
:
(
"qwen3_5_mtp"
,
"Qwen3_5MTP"
),
"Qwen3_5MoeMTP"
:
(
"qwen3_5_mtp"
,
"Qwen3_5MoeMTP"
),
# Temporarily disabled.
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
...
...
vllm/transformers_utils/configs/__init__.py
View file @
35006c0f
...
@@ -53,6 +53,10 @@ _CLASS_TO_MODULE: dict[str, str] = {
...
@@ -53,6 +53,10 @@ _CLASS_TO_MODULE: dict[str, str] = {
"Step3p5Config"
:
"vllm.transformers_utils.configs.step3p5"
,
"Step3p5Config"
:
"vllm.transformers_utils.configs.step3p5"
,
"Qwen3ASRConfig"
:
"vllm.transformers_utils.configs.qwen3_asr"
,
"Qwen3ASRConfig"
:
"vllm.transformers_utils.configs.qwen3_asr"
,
"Qwen3NextConfig"
:
"vllm.transformers_utils.configs.qwen3_next"
,
"Qwen3NextConfig"
:
"vllm.transformers_utils.configs.qwen3_next"
,
"Qwen3_5Config"
:
"vllm.transformers_utils.configs.qwen3_5"
,
"Qwen3_5TextConfig"
:
"vllm.transformers_utils.configs.qwen3_5"
,
"Qwen3_5MoeConfig"
:
"vllm.transformers_utils.configs.qwen3_5_moe"
,
"Qwen3_5MoeTextConfig"
:
"vllm.transformers_utils.configs.qwen3_5_moe"
,
"Tarsier2Config"
:
"vllm.transformers_utils.configs.tarsier2"
,
"Tarsier2Config"
:
"vllm.transformers_utils.configs.tarsier2"
,
# Special case: DeepseekV3Config is from HuggingFace Transformers
# Special case: DeepseekV3Config is from HuggingFace Transformers
"DeepseekV3Config"
:
"transformers"
,
"DeepseekV3Config"
:
"transformers"
,
...
@@ -95,6 +99,10 @@ __all__ = [
...
@@ -95,6 +99,10 @@ __all__ = [
"Step3p5Config"
,
"Step3p5Config"
,
"Qwen3ASRConfig"
,
"Qwen3ASRConfig"
,
"Qwen3NextConfig"
,
"Qwen3NextConfig"
,
"Qwen3_5Config"
,
"Qwen3_5TextConfig"
,
"Qwen3_5MoeConfig"
,
"Qwen3_5MoeTextConfig"
,
"Tarsier2Config"
,
"Tarsier2Config"
,
]
]
...
...
vllm/transformers_utils/configs/qwen3_5.py
0 → 100644
View file @
35006c0f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3.5 model configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
,
layer_type_validation
class
Qwen3_5TextConfig
(
PretrainedConfig
):
model_type
=
"qwen3_5_text"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
base_model_tp_plan
=
{
"layers.*.self_attn.q_proj"
:
"colwise"
,
"layers.*.self_attn.k_proj"
:
"colwise"
,
"layers.*.self_attn.v_proj"
:
"colwise"
,
"layers.*.self_attn.o_proj"
:
"rowwise"
,
"layers.*.mlp.gate_proj"
:
"colwise"
,
"layers.*.mlp.up_proj"
:
"colwise"
,
"layers.*.mlp.down_proj"
:
"rowwise"
,
}
base_model_pp_plan
=
{
"embed_tokens"
:
([
"input_ids"
],
[
"inputs_embeds"
]),
"layers"
:
([
"hidden_states"
,
"attention_mask"
],
[
"hidden_states"
]),
"norm"
:
([
"hidden_states"
],
[
"hidden_states"
]),
}
base_config_key
=
"text_config"
def
__init__
(
self
,
vocab_size
=
248320
,
hidden_size
=
4096
,
intermediate_size
=
12288
,
num_hidden_layers
=
32
,
num_attention_heads
=
16
,
num_key_value_heads
=
4
,
hidden_act
=
"silu"
,
max_position_embeddings
=
32768
,
initializer_range
=
0.02
,
rms_norm_eps
=
1e-6
,
use_cache
=
True
,
tie_word_embeddings
=
False
,
rope_parameters
=
None
,
attention_bias
=
False
,
attention_dropout
=
0.0
,
head_dim
=
256
,
linear_conv_kernel_dim
=
4
,
linear_key_head_dim
=
128
,
linear_value_head_dim
=
128
,
linear_num_key_heads
=
16
,
linear_num_value_heads
=
32
,
layer_types
=
None
,
pad_token_id
=
None
,
bos_token_id
=
None
,
eos_token_id
=
None
,
**
kwargs
,
):
kwargs
[
"ignore_keys_at_rope_validation"
]
=
[
"mrope_section"
,
"mrope_interleaved"
,
]
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
rms_norm_eps
=
rms_norm_eps
self
.
use_cache
=
use_cache
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
head_dim
=
head_dim
self
.
rope_parameters
=
rope_parameters
kwargs
.
setdefault
(
"partial_rotary_factor"
,
0.25
)
self
.
layer_types
=
layer_types
if
self
.
layer_types
is
None
:
interval_pattern
=
kwargs
.
get
(
"full_attention_interval"
,
4
)
self
.
layer_types
=
[
"linear_attention"
if
bool
((
i
+
1
)
%
interval_pattern
)
else
"full_attention"
for
i
in
range
(
self
.
num_hidden_layers
)
]
layer_type_validation
(
self
.
layer_types
,
self
.
num_hidden_layers
)
# linear attention part
self
.
linear_conv_kernel_dim
=
linear_conv_kernel_dim
self
.
linear_key_head_dim
=
linear_key_head_dim
self
.
linear_value_head_dim
=
linear_value_head_dim
self
.
linear_num_key_heads
=
linear_num_key_heads
self
.
linear_num_value_heads
=
linear_num_value_heads
super
().
__init__
(
**
kwargs
)
# Set these AFTER super().__init__() because transformers v4's
# PretrainedConfig.__init__ has these as explicit params with different
# defaults (e.g. tie_word_embeddings=True) that would overwrite our values.
self
.
pad_token_id
=
pad_token_id
self
.
bos_token_id
=
bos_token_id
self
.
eos_token_id
=
eos_token_id
self
.
tie_word_embeddings
=
tie_word_embeddings
class
Qwen3_5VisionConfig
(
PretrainedConfig
):
model_type
=
"qwen3_5"
base_config_key
=
"vision_config"
def
__init__
(
self
,
depth
=
27
,
hidden_size
=
1152
,
hidden_act
=
"gelu_pytorch_tanh"
,
intermediate_size
=
4304
,
num_heads
=
16
,
in_channels
=
3
,
patch_size
=
16
,
spatial_merge_size
=
2
,
temporal_patch_size
=
2
,
out_hidden_size
=
3584
,
num_position_embeddings
=
2304
,
initializer_range
=
0.02
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
depth
=
depth
self
.
hidden_size
=
hidden_size
self
.
hidden_act
=
hidden_act
self
.
intermediate_size
=
intermediate_size
self
.
num_heads
=
num_heads
self
.
in_channels
=
in_channels
self
.
patch_size
=
patch_size
self
.
spatial_merge_size
=
spatial_merge_size
self
.
temporal_patch_size
=
temporal_patch_size
self
.
out_hidden_size
=
out_hidden_size
self
.
num_position_embeddings
=
num_position_embeddings
self
.
initializer_range
=
initializer_range
class
Qwen3_5Config
(
PretrainedConfig
):
model_type
=
"qwen3_5"
sub_configs
=
{
"vision_config"
:
Qwen3_5VisionConfig
,
"text_config"
:
Qwen3_5TextConfig
,
}
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
text_config
=
None
,
vision_config
=
None
,
image_token_id
=
248056
,
video_token_id
=
248057
,
vision_start_token_id
=
248053
,
vision_end_token_id
=
248054
,
tie_word_embeddings
=
False
,
**
kwargs
,
):
if
isinstance
(
vision_config
,
dict
):
self
.
vision_config
=
self
.
sub_configs
[
"vision_config"
](
**
vision_config
)
elif
vision_config
is
None
:
self
.
vision_config
=
self
.
sub_configs
[
"vision_config"
]()
if
isinstance
(
text_config
,
dict
):
self
.
text_config
=
self
.
sub_configs
[
"text_config"
](
**
text_config
)
elif
text_config
is
None
:
self
.
text_config
=
self
.
sub_configs
[
"text_config"
]()
self
.
image_token_id
=
image_token_id
self
.
video_token_id
=
video_token_id
self
.
vision_start_token_id
=
vision_start_token_id
self
.
vision_end_token_id
=
vision_end_token_id
super
().
__init__
(
**
kwargs
)
# Set after super().__init__() to avoid v4 PretrainedConfig overwrite
self
.
tie_word_embeddings
=
tie_word_embeddings
__all__
=
[
"Qwen3_5Config"
,
"Qwen3_5TextConfig"
]
vllm/transformers_utils/configs/qwen3_5_moe.py
0 → 100644
View file @
35006c0f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3.5-MoE model configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
,
layer_type_validation
class
Qwen3_5MoeTextConfig
(
PretrainedConfig
):
model_type
=
"qwen3_5_moe_text"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
base_model_tp_plan
=
{
"layers.*.self_attn.q_proj"
:
"colwise"
,
"layers.*.self_attn.k_proj"
:
"colwise"
,
"layers.*.self_attn.v_proj"
:
"colwise"
,
"layers.*.self_attn.o_proj"
:
"rowwise"
,
"layers.*.mlp.experts.gate_up_proj"
:
"packed_colwise"
,
"layers.*.mlp.experts.down_proj"
:
"rowwise"
,
"layers.*.mlp.shared_expert.gate_proj"
:
"colwise"
,
"layers.*.mlp.shared_expert.up_proj"
:
"colwise"
,
"layers.*.mlp.shared_expert.down_proj"
:
"rowwise"
,
}
base_model_pp_plan
=
{
"embed_tokens"
:
([
"input_ids"
],
[
"inputs_embeds"
]),
"layers"
:
([
"hidden_states"
,
"attention_mask"
],
[
"hidden_states"
]),
"norm"
:
([
"hidden_states"
],
[
"hidden_states"
]),
}
base_config_key
=
"text_config"
def
__init__
(
self
,
vocab_size
=
248320
,
hidden_size
=
2048
,
num_hidden_layers
=
40
,
num_attention_heads
=
16
,
num_key_value_heads
=
2
,
hidden_act
=
"silu"
,
max_position_embeddings
=
32768
,
initializer_range
=
0.02
,
rms_norm_eps
=
1e-6
,
use_cache
=
True
,
tie_word_embeddings
=
False
,
rope_parameters
=
None
,
attention_bias
=
False
,
attention_dropout
=
0.0
,
head_dim
=
256
,
linear_conv_kernel_dim
=
4
,
linear_key_head_dim
=
128
,
linear_value_head_dim
=
128
,
linear_num_key_heads
=
16
,
linear_num_value_heads
=
32
,
moe_intermediate_size
=
512
,
shared_expert_intermediate_size
=
512
,
num_experts_per_tok
=
8
,
num_experts
=
256
,
output_router_logits
=
False
,
router_aux_loss_coef
=
0.001
,
layer_types
=
None
,
pad_token_id
=
None
,
bos_token_id
=
None
,
eos_token_id
=
None
,
**
kwargs
,
):
kwargs
[
"ignore_keys_at_rope_validation"
]
=
[
"mrope_section"
,
"mrope_interleaved"
,
]
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
rms_norm_eps
=
rms_norm_eps
self
.
use_cache
=
use_cache
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
head_dim
=
head_dim
self
.
rope_parameters
=
rope_parameters
kwargs
.
setdefault
(
"partial_rotary_factor"
,
0.25
)
self
.
layer_types
=
layer_types
if
self
.
layer_types
is
None
:
interval_pattern
=
kwargs
.
get
(
"full_attention_interval"
,
4
)
self
.
layer_types
=
[
"linear_attention"
if
bool
((
i
+
1
)
%
interval_pattern
)
else
"full_attention"
for
i
in
range
(
self
.
num_hidden_layers
)
]
layer_type_validation
(
self
.
layer_types
,
self
.
num_hidden_layers
)
# linear attention part
self
.
linear_conv_kernel_dim
=
linear_conv_kernel_dim
self
.
linear_key_head_dim
=
linear_key_head_dim
self
.
linear_value_head_dim
=
linear_value_head_dim
self
.
linear_num_key_heads
=
linear_num_key_heads
self
.
linear_num_value_heads
=
linear_num_value_heads
self
.
moe_intermediate_size
=
moe_intermediate_size
self
.
shared_expert_intermediate_size
=
shared_expert_intermediate_size
self
.
num_experts_per_tok
=
num_experts_per_tok
self
.
num_experts
=
num_experts
self
.
output_router_logits
=
output_router_logits
self
.
router_aux_loss_coef
=
router_aux_loss_coef
super
().
__init__
(
**
kwargs
)
# Set these AFTER super().__init__() because transformers v4's
# PretrainedConfig.__init__ has these as explicit params with different
# defaults (e.g. tie_word_embeddings=True) that would overwrite our values.
self
.
pad_token_id
=
pad_token_id
self
.
bos_token_id
=
bos_token_id
self
.
eos_token_id
=
eos_token_id
self
.
tie_word_embeddings
=
tie_word_embeddings
class
Qwen3_5MoeVisionConfig
(
PretrainedConfig
):
model_type
=
"qwen3_5_moe"
base_config_key
=
"vision_config"
def
__init__
(
self
,
depth
=
27
,
hidden_size
=
1152
,
hidden_act
=
"gelu_pytorch_tanh"
,
intermediate_size
=
4304
,
num_heads
=
16
,
in_channels
=
3
,
patch_size
=
16
,
spatial_merge_size
=
2
,
temporal_patch_size
=
2
,
out_hidden_size
=
3584
,
num_position_embeddings
=
2304
,
initializer_range
=
0.02
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
depth
=
depth
self
.
hidden_size
=
hidden_size
self
.
hidden_act
=
hidden_act
self
.
intermediate_size
=
intermediate_size
self
.
num_heads
=
num_heads
self
.
in_channels
=
in_channels
self
.
patch_size
=
patch_size
self
.
spatial_merge_size
=
spatial_merge_size
self
.
temporal_patch_size
=
temporal_patch_size
self
.
out_hidden_size
=
out_hidden_size
self
.
num_position_embeddings
=
num_position_embeddings
self
.
initializer_range
=
initializer_range
class
Qwen3_5MoeConfig
(
PretrainedConfig
):
model_type
=
"qwen3_5_moe"
sub_configs
=
{
"vision_config"
:
Qwen3_5MoeVisionConfig
,
"text_config"
:
Qwen3_5MoeTextConfig
,
}
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
text_config
=
None
,
vision_config
=
None
,
image_token_id
=
248056
,
video_token_id
=
248057
,
vision_start_token_id
=
248053
,
vision_end_token_id
=
248054
,
tie_word_embeddings
=
False
,
**
kwargs
,
):
if
isinstance
(
vision_config
,
dict
):
self
.
vision_config
=
self
.
sub_configs
[
"vision_config"
](
**
vision_config
)
elif
vision_config
is
None
:
self
.
vision_config
=
self
.
sub_configs
[
"vision_config"
]()
if
isinstance
(
text_config
,
dict
):
self
.
text_config
=
self
.
sub_configs
[
"text_config"
](
**
text_config
)
elif
text_config
is
None
:
self
.
text_config
=
self
.
sub_configs
[
"text_config"
]()
self
.
image_token_id
=
image_token_id
self
.
video_token_id
=
video_token_id
self
.
vision_start_token_id
=
vision_start_token_id
self
.
vision_end_token_id
=
vision_end_token_id
super
().
__init__
(
**
kwargs
)
# Set after super().__init__() to avoid v4 PretrainedConfig overwrite
self
.
tie_word_embeddings
=
tie_word_embeddings
__all__
=
[
"Qwen3_5MoeConfig"
,
"Qwen3_5MoeTextConfig"
]
vllm/transformers_utils/model_arch_config_convertor.py
View file @
35006c0f
...
@@ -371,6 +371,11 @@ class Qwen3NextMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
...
@@ -371,6 +371,11 @@ class Qwen3NextMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
0
)
return
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
0
)
class
Qwen3_5MTPModelArchConfigConvertor
(
ModelArchConfigConvertorBase
):
def
get_num_hidden_layers
(
self
)
->
int
:
return
getattr
(
self
.
hf_text_config
,
"mtp_num_hidden_layers"
,
0
)
class
PanguUltraMoeMTPModelArchConfigConvertor
(
ModelArchConfigConvertorBase
):
class
PanguUltraMoeMTPModelArchConfigConvertor
(
ModelArchConfigConvertorBase
):
def
get_num_hidden_layers
(
self
)
->
int
:
def
get_num_hidden_layers
(
self
)
->
int
:
return
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
0
)
return
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
0
)
...
@@ -396,6 +401,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
...
@@ -396,6 +401,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"nemotron-nas"
:
NemotronNasModelArchConfigConvertor
,
"nemotron-nas"
:
NemotronNasModelArchConfigConvertor
,
"deepseek_mtp"
:
DeepSeekMTPModelArchConfigConvertor
,
"deepseek_mtp"
:
DeepSeekMTPModelArchConfigConvertor
,
"qwen3_next_mtp"
:
Qwen3NextMTPModelArchConfigConvertor
,
"qwen3_next_mtp"
:
Qwen3NextMTPModelArchConfigConvertor
,
"qwen3_5_mtp"
:
Qwen3_5MTPModelArchConfigConvertor
,
"mimo_mtp"
:
MimoMTPModelArchConfigConvertor
,
"mimo_mtp"
:
MimoMTPModelArchConfigConvertor
,
"glm4_moe_mtp"
:
GLM4MoeMTPModelArchConfigConvertor
,
"glm4_moe_mtp"
:
GLM4MoeMTPModelArchConfigConvertor
,
"ernie_mtp"
:
ErnieMTPModelArchConfigConvertor
,
"ernie_mtp"
:
ErnieMTPModelArchConfigConvertor
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
35006c0f
...
@@ -1147,6 +1147,8 @@ class SpecDecodeBaseProposer:
...
@@ -1147,6 +1147,8 @@ class SpecDecodeBaseProposer:
"Qwen3VLForConditionalGeneration"
,
"Qwen3VLForConditionalGeneration"
,
"Qwen3VLMoeForConditionalGeneration"
,
"Qwen3VLMoeForConditionalGeneration"
,
"GlmOcrForConditionalGeneration"
,
"GlmOcrForConditionalGeneration"
,
"Qwen3_5ForConditionalGeneration"
,
"Qwen3_5MoeForConditionalGeneration"
,
]:
]:
self
.
model
.
config
.
image_token_index
=
target_model
.
config
.
image_token_id
self
.
model
.
config
.
image_token_index
=
target_model
.
config
.
image_token_id
elif
self
.
get_model_name
(
target_model
)
==
"PixtralForConditionalGeneration"
:
elif
self
.
get_model_name
(
target_model
)
==
"PixtralForConditionalGeneration"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment