Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca4eb82b
Unverified
Commit
ca4eb82b
authored
Jul 18, 2025
by
wang.yuqi
Committed by
GitHub
Jul 18, 2025
Browse files
[Model] Re-add the implicit conversion feature for as_seq_cls_model (#21103)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
ba2dfbb0
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
165 additions
and
75 deletions
+165
-75
tests/models/registry.py
tests/models/registry.py
+20
-12
tests/models/test_initialization.py
tests/models/test_initialization.py
+21
-8
tests/models/test_transformers.py
tests/models/test_transformers.py
+35
-0
vllm/config.py
vllm/config.py
+25
-21
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+26
-4
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+9
-6
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+0
-4
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+0
-4
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+0
-4
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+0
-4
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+29
-8
No files found.
tests/models/registry.py
View file @
ca4eb82b
...
@@ -265,7 +265,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -265,7 +265,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen2MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
),
"Qwen2MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
),
"Qwen3ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-8B"
),
"Qwen3ForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-8B"
),
"Qwen3MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-30B-A3B"
),
"Qwen3MoeForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-30B-A3B"
),
"Qwen3ForSequenceClassification"
:
_HfExamplesInfo
(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
),
# noqa: E501
"RWForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-40b"
),
"RWForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-40b"
),
"StableLMEpochForCausalLM"
:
_HfExamplesInfo
(
"stabilityai/stablelm-zephyr-3b"
),
# noqa: E501
"StableLMEpochForCausalLM"
:
_HfExamplesInfo
(
"stabilityai/stablelm-zephyr-3b"
),
# noqa: E501
"StableLmForCausalLM"
:
_HfExamplesInfo
(
"stabilityai/stablelm-3b-4e1t"
),
"StableLmForCausalLM"
:
_HfExamplesInfo
(
"stabilityai/stablelm-3b-4e1t"
),
...
@@ -292,7 +291,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
...
@@ -292,7 +291,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
# [Text-only]
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
,
v0_only
=
True
),
"BertModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
,
v0_only
=
True
),
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
,
v0_only
=
True
),
# noqa: E501
"Gemma2Model"
:
_HfExamplesInfo
(
"BAAI/bge-multilingual-gemma2"
,
v0_only
=
True
),
# noqa: E501
"GPT2ForSequenceClassification"
:
_HfExamplesInfo
(
"nie3e/sentiment-polish-gpt2-small"
),
# noqa: E501
"GritLM"
:
_HfExamplesInfo
(
"parasail-ai/GritLM-7B-vllm"
),
"GritLM"
:
_HfExamplesInfo
(
"parasail-ai/GritLM-7B-vllm"
),
"GteModel"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-embed-m-v2.0"
,
"GteModel"
:
_HfExamplesInfo
(
"Snowflake/snowflake-arctic-embed-m-v2.0"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
...
@@ -311,7 +309,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
...
@@ -311,7 +309,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Qwen2Model"
:
_HfExamplesInfo
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
"Qwen2Model"
:
_HfExamplesInfo
(
"ssmits/Qwen2-7B-Instruct-embed-base"
),
"Qwen2ForRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-RM-72B"
),
"Qwen2ForRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-RM-72B"
),
"Qwen2ForProcessRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-PRM-7B"
),
"Qwen2ForProcessRewardModel"
:
_HfExamplesInfo
(
"Qwen/Qwen2.5-Math-PRM-7B"
),
"Qwen2ForSequenceClassification"
:
_HfExamplesInfo
(
"jason9693/Qwen2.5-1.5B-apeach"
),
# noqa: E501
"RobertaModel"
:
_HfExamplesInfo
(
"sentence-transformers/stsb-roberta-base-v2"
,
v0_only
=
True
),
# noqa: E501
"RobertaModel"
:
_HfExamplesInfo
(
"sentence-transformers/stsb-roberta-base-v2"
,
v0_only
=
True
),
# noqa: E501
"RobertaForMaskedLM"
:
_HfExamplesInfo
(
"sentence-transformers/all-roberta-large-v1"
,
v0_only
=
True
),
# noqa: E501
"RobertaForMaskedLM"
:
_HfExamplesInfo
(
"sentence-transformers/all-roberta-large-v1"
,
v0_only
=
True
),
# noqa: E501
"XLMRobertaModel"
:
_HfExamplesInfo
(
"intfloat/multilingual-e5-small"
,
v0_only
=
True
),
# noqa: E501
"XLMRobertaModel"
:
_HfExamplesInfo
(
"intfloat/multilingual-e5-small"
,
v0_only
=
True
),
# noqa: E501
...
@@ -324,20 +321,29 @@ _EMBEDDING_EXAMPLE_MODELS = {
...
@@ -324,20 +321,29 @@ _EMBEDDING_EXAMPLE_MODELS = {
is_available_online
=
False
),
# noqa: E501
is_available_online
=
False
),
# noqa: E501
}
}
_CROSS_ENCODER_EXAMPLE_MODELS
=
{
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS
=
{
# [Text-only]
# [Decoder-only]
"GPT2ForSequenceClassification"
:
_HfExamplesInfo
(
"nie3e/sentiment-polish-gpt2-small"
),
# noqa: E501
# [Cross-encoder]
"BertForSequenceClassification"
:
_HfExamplesInfo
(
"cross-encoder/ms-marco-MiniLM-L-6-v2"
,
v0_only
=
True
),
# noqa: E501
"BertForSequenceClassification"
:
_HfExamplesInfo
(
"cross-encoder/ms-marco-MiniLM-L-6-v2"
,
v0_only
=
True
),
# noqa: E501
"GemmaForSequenceClassification"
:
_HfExamplesInfo
(
"BAAI/bge-reranker-v2-gemma"
,
# noqa: E501
v0_only
=
True
,
hf_overrides
=
{
"architectures"
:
[
"GemmaForSequenceClassification"
],
# noqa: E501
"classifier_from_token"
:
[
"Yes"
],
# noqa: E501
"method"
:
"no_post_processing"
}),
# noqa: E501
"LlamaForSequenceClassification"
:
_HfExamplesInfo
(
"Skywork/Skywork-Reward-V2-Llama-3.2-1B"
),
# noqa: E501
"ModernBertForSequenceClassification"
:
_HfExamplesInfo
(
"Alibaba-NLP/gte-reranker-modernbert-base"
,
v0_only
=
True
),
# noqa: E501
"ModernBertForSequenceClassification"
:
_HfExamplesInfo
(
"Alibaba-NLP/gte-reranker-modernbert-base"
,
v0_only
=
True
),
# noqa: E501
"RobertaForSequenceClassification"
:
_HfExamplesInfo
(
"cross-encoder/quora-roberta-base"
,
v0_only
=
True
),
# noqa: E501
"RobertaForSequenceClassification"
:
_HfExamplesInfo
(
"cross-encoder/quora-roberta-base"
,
v0_only
=
True
),
# noqa: E501
"XLMRobertaForSequenceClassification"
:
_HfExamplesInfo
(
"BAAI/bge-reranker-v2-m3"
,
v0_only
=
True
),
# noqa: E501
"XLMRobertaForSequenceClassification"
:
_HfExamplesInfo
(
"BAAI/bge-reranker-v2-m3"
,
v0_only
=
True
),
# noqa: E501
}
}
_AUTOMATIC_CONVERTED_MODELS
=
{
# Use as_seq_cls_model for automatic conversion
"GemmaForSequenceClassification"
:
_HfExamplesInfo
(
"BAAI/bge-reranker-v2-gemma"
,
# noqa: E501
v0_only
=
True
,
hf_overrides
=
{
"architectures"
:
[
"GemmaForSequenceClassification"
],
# noqa: E501
"classifier_from_token"
:
[
"Yes"
],
# noqa: E501
"method"
:
"no_post_processing"
}),
# noqa: E501
"LlamaForSequenceClassification"
:
_HfExamplesInfo
(
"Skywork/Skywork-Reward-V2-Llama-3.2-1B"
),
# noqa: E501
"Qwen2ForSequenceClassification"
:
_HfExamplesInfo
(
"jason9693/Qwen2.5-1.5B-apeach"
),
# noqa: E501
"Qwen3ForSequenceClassification"
:
_HfExamplesInfo
(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
),
# noqa: E501
}
_MULTIMODAL_EXAMPLE_MODELS
=
{
_MULTIMODAL_EXAMPLE_MODELS
=
{
# [Decoder-only]
# [Decoder-only]
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
),
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
),
...
@@ -449,6 +455,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -449,6 +455,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"JinaVLForRanking"
:
_HfExamplesInfo
(
"jinaai/jina-reranker-m0"
),
# noqa: E501
"JinaVLForRanking"
:
_HfExamplesInfo
(
"jinaai/jina-reranker-m0"
),
# noqa: E501
}
}
_SPECULATIVE_DECODING_EXAMPLE_MODELS
=
{
_SPECULATIVE_DECODING_EXAMPLE_MODELS
=
{
"EAGLEModel"
:
_HfExamplesInfo
(
"JackFram/llama-68m"
,
"EAGLEModel"
:
_HfExamplesInfo
(
"JackFram/llama-68m"
,
speculative_model
=
"abhigoyal/vllm-eagle-llama-68m-random"
),
# noqa: E501
speculative_model
=
"abhigoyal/vllm-eagle-llama-68m-random"
),
# noqa: E501
...
@@ -489,7 +496,7 @@ _TRANSFORMERS_MODELS = {
...
@@ -489,7 +496,7 @@ _TRANSFORMERS_MODELS = {
_EXAMPLE_MODELS
=
{
_EXAMPLE_MODELS
=
{
**
_TEXT_GENERATION_EXAMPLE_MODELS
,
**
_TEXT_GENERATION_EXAMPLE_MODELS
,
**
_EMBEDDING_EXAMPLE_MODELS
,
**
_EMBEDDING_EXAMPLE_MODELS
,
**
_
CROSS_ENCODER
_EXAMPLE_MODELS
,
**
_
SEQUENCE_CLASSIFICATION
_EXAMPLE_MODELS
,
**
_MULTIMODAL_EXAMPLE_MODELS
,
**
_MULTIMODAL_EXAMPLE_MODELS
,
**
_SPECULATIVE_DECODING_EXAMPLE_MODELS
,
**
_SPECULATIVE_DECODING_EXAMPLE_MODELS
,
**
_TRANSFORMERS_MODELS
,
**
_TRANSFORMERS_MODELS
,
...
@@ -522,3 +529,4 @@ class HfExampleModels:
...
@@ -522,3 +529,4 @@ class HfExampleModels:
HF_EXAMPLE_MODELS
=
HfExampleModels
(
_EXAMPLE_MODELS
)
HF_EXAMPLE_MODELS
=
HfExampleModels
(
_EXAMPLE_MODELS
)
AUTO_EXAMPLE_MODELS
=
HfExampleModels
(
_AUTOMATIC_CONVERTED_MODELS
)
tests/models/test_initialization.py
View file @
ca4eb82b
...
@@ -13,20 +13,21 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
...
@@ -13,20 +13,21 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from
vllm.v1.engine.core
import
EngineCore
as
V1EngineCore
from
vllm.v1.engine.core
import
EngineCore
as
V1EngineCore
from
..utils
import
create_new_process_for_each_test
from
..utils
import
create_new_process_for_each_test
from
.registry
import
HF_EXAMPLE_MODELS
from
.registry
import
AUTO_EXAMPLE_MODELS
,
HF_EXAMPLE_MODELS
,
HfExampleModels
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
def
test_can_initialize
(
model_arch
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
):
def
can_initialize
(
model_arch
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
,
"""The reason for using create_new_process_for_each_test is to avoid
EXAMPLE_MODELS
:
HfExampleModels
):
the WARNING:
"""The reason for using create_new_process_for_each_test is to avoid
"We must use the 'spawn' multiprocessing start method. Overriding
the WARNING:
"We must use the 'spawn' multiprocessing start method. Overriding
VLLM_WORKER_MULTIPROC_METHOD to 'spawn'."
VLLM_WORKER_MULTIPROC_METHOD to 'spawn'."
The spawn process causes the _initialize_kv_caches_v1 function below to
The spawn process causes the _initialize_kv_caches_v1 function below to
become ineffective.
become ineffective.
"""
"""
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
model_info
=
EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
...
@@ -127,3 +128,15 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
...
@@ -127,3 +128,15 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
load_format
=
"dummy"
,
load_format
=
"dummy"
,
hf_overrides
=
hf_overrides
,
hf_overrides
=
hf_overrides
,
)
)
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
HF_EXAMPLE_MODELS
.
get_supported_archs
())
def
test_can_initialize
(
model_arch
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
):
can_initialize
(
model_arch
,
monkeypatch
,
HF_EXAMPLE_MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
AUTO_EXAMPLE_MODELS
.
get_supported_archs
())
def
test_implicit_converted_models
(
model_arch
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
):
can_initialize
(
model_arch
,
monkeypatch
,
AUTO_EXAMPLE_MODELS
)
tests/models/test_transformers.py
View file @
ca4eb82b
...
@@ -138,3 +138,38 @@ def test_quantization(
...
@@ -138,3 +138,38 @@ def test_quantization(
name_0
=
"transformers"
,
name_0
=
"transformers"
,
name_1
=
"vllm"
,
name_1
=
"vllm"
,
)
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"jason9693/Qwen2.5-1.5B-apeach"
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_classify
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
monkeypatch
,
)
->
None
:
import
torch
from
transformers
import
AutoModelForSequenceClassification
with
vllm_runner
(
model
,
max_model_len
=
512
,
dtype
=
dtype
,
model_impl
=
"transformers"
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
classify
(
example_prompts
)
with
hf_runner
(
model
,
dtype
=
dtype
,
auto_cls
=
AutoModelForSequenceClassification
)
as
hf_model
:
hf_outputs
=
hf_model
.
classify
(
example_prompts
)
for
hf_output
,
vllm_output
in
zip
(
hf_outputs
,
vllm_outputs
):
hf_output
=
torch
.
tensor
(
hf_output
)
vllm_output
=
torch
.
tensor
(
vllm_output
)
assert
torch
.
allclose
(
hf_output
,
vllm_output
,
1e-3
if
dtype
==
"float"
else
1e-2
)
vllm/config.py
View file @
ca4eb82b
...
@@ -551,7 +551,7 @@ class ModelConfig:
...
@@ -551,7 +551,7 @@ class ModelConfig:
# For pooling models, self.task is used to indicate the
# For pooling models, self.task is used to indicate the
# user-selected task
# user-selected task
if
self
.
task
==
"score"
:
if
self
.
task
==
"score"
:
if
self
.
registry
.
is_cross_encoder_model
(
self
.
architectures
):
if
self
.
_is_classify_task
(
self
.
architectures
):
self
.
task
=
"classify"
self
.
task
=
"classify"
else
:
else
:
self
.
task
=
"embed"
self
.
task
=
"embed"
...
@@ -806,6 +806,12 @@ class ModelConfig:
...
@@ -806,6 +806,12 @@ class ModelConfig:
f
"one of
{
get_args
(
TokenizerMode
)
}
."
)
f
"one of
{
get_args
(
TokenizerMode
)
}
."
)
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
def
_is_classify_task
(
self
,
architectures
:
list
[
str
]):
for
arch
in
architectures
:
if
arch
.
endswith
(
"ForSequenceClassification"
):
return
True
return
self
.
registry
.
is_cross_encoder_model
(
architectures
)
def
_get_preferred_pooling_task
(
def
_get_preferred_pooling_task
(
self
,
self
,
architectures
:
list
[
str
],
architectures
:
list
[
str
],
...
@@ -813,14 +819,11 @@ class ModelConfig:
...
@@ -813,14 +819,11 @@ class ModelConfig:
model_id
=
self
.
model
model_id
=
self
.
model
if
get_pooling_config
(
model_id
,
self
.
revision
):
if
get_pooling_config
(
model_id
,
self
.
revision
):
return
"embed"
return
"embed"
if
self
.
registry
.
is_cross_encoder_model
(
architectures
):
return
"classify"
if
self
.
registry
.
is_transcription_model
(
architectures
):
if
self
.
registry
.
is_transcription_model
(
architectures
):
return
"transcription"
return
"transcription"
suffix_to_preferred_task
:
list
[
tuple
[
str
,
_ResolvedTask
]]
=
[
suffix_to_preferred_task
:
list
[
tuple
[
str
,
_ResolvedTask
]]
=
[
# Other models follow this pattern
# Other models follow this pattern
(
"ForSequenceClassification"
,
"classify"
),
(
"EmbeddingModel"
,
"embed"
),
(
"EmbeddingModel"
,
"embed"
),
(
"RewardModel"
,
"reward"
),
(
"RewardModel"
,
"reward"
),
]
]
...
@@ -878,11 +881,14 @@ class ModelConfig:
...
@@ -878,11 +881,14 @@ class ModelConfig:
self
,
self
,
task_option
:
TaskOption
,
task_option
:
TaskOption
,
)
->
dict
[
RunnerType
,
list
[
_ResolvedTask
]]:
)
->
dict
[
RunnerType
,
list
[
_ResolvedTask
]]:
return
{
if
self
.
_is_classify_task
(
self
.
architectures
):
"generate"
:
self
.
_get_supported_generation_tasks
(
task_option
),
return
{
"generate"
:
[],
"pooling"
:
[
"classify"
],
"draft"
:
[]}
"pooling"
:
self
.
_get_supported_pooling_tasks
(
task_option
),
else
:
"draft"
:
[
"draft"
]
return
{
}
"generate"
:
self
.
_get_supported_generation_tasks
(
task_option
),
"pooling"
:
self
.
_get_supported_pooling_tasks
(
task_option
),
"draft"
:
[
"draft"
]
}
def
_get_supported_runner_types
(
def
_get_supported_runner_types
(
self
,
self
,
...
@@ -925,12 +931,16 @@ class ModelConfig:
...
@@ -925,12 +931,16 @@ class ModelConfig:
f
"Available tasks for runner=
{
task_runner
!
r
}
: "
f
"Available tasks for runner=
{
task_runner
!
r
}
: "
f
"
{
supported_tasks
[
task_runner
]
}
"
)
f
"
{
supported_tasks
[
task_runner
]
}
"
)
if
"classify"
in
supported_tasks
.
get
(
"pooling"
,
[]):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return
"pooling"
suffix_to_preferred_runner
:
list
[
tuple
[
str
,
RunnerType
]]
=
[
suffix_to_preferred_runner
:
list
[
tuple
[
str
,
RunnerType
]]
=
[
(
"ForCausalLM"
,
"generate"
),
(
"ForCausalLM"
,
"generate"
),
(
"ForConditionalGeneration"
,
"generate"
),
(
"ForConditionalGeneration"
,
"generate"
),
(
"ChatModel"
,
"generate"
),
(
"ChatModel"
,
"generate"
),
(
"LMHeadModel"
,
"generate"
),
(
"LMHeadModel"
,
"generate"
),
(
"ForSequenceClassification"
,
"pooling"
),
(
"EmbeddingModel"
,
"pooling"
),
(
"EmbeddingModel"
,
"pooling"
),
(
"RewardModel"
,
"pooling"
),
(
"RewardModel"
,
"pooling"
),
]
]
...
@@ -940,10 +950,6 @@ class ModelConfig:
...
@@ -940,10 +950,6 @@ class ModelConfig:
if
arch
.
endswith
(
suffix
)
and
pref_runner
in
supported_runner_types
:
if
arch
.
endswith
(
suffix
)
and
pref_runner
in
supported_runner_types
:
return
pref_runner
return
pref_runner
if
"classify"
in
supported_tasks
.
get
(
"pooling"
,
[]):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return
"pooling"
if
"generate"
in
supported_runner_types
:
if
"generate"
in
supported_runner_types
:
return
"generate"
return
"generate"
if
"pooling"
in
supported_runner_types
:
if
"pooling"
in
supported_runner_types
:
...
@@ -1525,7 +1531,7 @@ class ModelConfig:
...
@@ -1525,7 +1531,7 @@ class ModelConfig:
@
property
@
property
def
is_matryoshka
(
self
)
->
bool
:
def
is_matryoshka
(
self
)
->
bool
:
return
(
has
attr
(
self
.
hf_config
,
"matryoshka_dimensions"
)
return
(
bool
(
get
attr
(
self
.
hf_config
,
"matryoshka_dimensions"
,
None
)
)
or
getattr
(
self
.
hf_config
,
"is_matryoshka"
,
False
))
or
getattr
(
self
.
hf_config
,
"is_matryoshka"
,
False
))
@
property
@
property
...
@@ -1539,13 +1545,11 @@ class ModelConfig:
...
@@ -1539,13 +1545,11 @@ class ModelConfig:
return
getattr
(
self
.
hf_config
,
"use_pad_token"
,
True
)
return
getattr
(
self
.
hf_config
,
"use_pad_token"
,
True
)
def
get_and_verify_max_len
(
self
,
max_model_len
:
int
):
def
get_and_verify_max_len
(
self
,
max_model_len
:
int
):
# For pooling models, the tokenizer's `model_max_length` is often a
# Consider max_model_len in tokenizer_config only when
# reliable source for the maximum sequence length. However, for
# pooling models use absolute position_embedding.
# generative models, this can be incorrect and unduly limit the
# context window (e.g., DeepSeek-R1). Therefore, we only consider
# tokenizer_config for pooling models.
tokenizer_config
=
None
tokenizer_config
=
None
if
self
.
runner_type
==
"pooling"
:
if
(
self
.
runner_type
==
"pooling"
and
getattr
(
self
.
hf_config
,
"position_embedding_type"
,
""
)
==
"absolute"
):
tokenizer_config
=
try_get_tokenizer_config
(
tokenizer_config
=
try_get_tokenizer_config
(
self
.
tokenizer
,
self
.
tokenizer
,
trust_remote_code
=
self
.
trust_remote_code
,
trust_remote_code
=
self
.
trust_remote_code
,
...
...
vllm/model_executor/model_loader/utils.py
View file @
ca4eb82b
...
@@ -22,7 +22,8 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -22,7 +22,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.adapters
import
(
as_embedding_model
,
from
vllm.model_executor.models.adapters
import
(
as_embedding_model
,
as_reward_model
)
as_reward_model
,
as_seq_cls_model
)
from
vllm.model_executor.models.interfaces
import
SupportsQuant
from
vllm.model_executor.models.interfaces
import
SupportsQuant
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -238,9 +239,29 @@ def get_model_architecture(
...
@@ -238,9 +239,29 @@ def get_model_architecture(
vllm_supported_archs
=
ModelRegistry
.
get_supported_archs
()
vllm_supported_archs
=
ModelRegistry
.
get_supported_archs
()
vllm_not_supported
=
not
any
(
arch
in
vllm_supported_archs
vllm_not_supported
=
not
any
(
arch
in
vllm_supported_archs
for
arch
in
architectures
)
for
arch
in
architectures
)
if
vllm_not_supported
:
# try automatic conversion in adapters.py
for
arch
in
architectures
:
if
not
arch
.
endswith
(
"ForSequenceClassification"
):
continue
assert
model_config
.
task
==
"classify"
causal_lm_arch
=
arch
.
replace
(
"ForSequenceClassification"
,
"ForCausalLM"
)
causal_lm_arch_vllm_supported
=
(
causal_lm_arch
in
vllm_supported_archs
)
if
not
causal_lm_arch_vllm_supported
:
continue
architectures
=
[
causal_lm_arch
]
vllm_not_supported
=
False
break
if
(
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
or
if
(
model_config
.
model_impl
==
ModelImpl
.
TRANSFORMERS
or
model_config
.
model_impl
!=
ModelImpl
.
VLLM
and
vllm_not_supported
):
model_config
.
model_impl
!=
ModelImpl
.
VLLM
and
vllm_not_supported
):
architectures
=
resolve_transformers_arch
(
model_config
,
architectures
)
architectures
=
resolve_transformers_arch
(
model_config
,
architectures
)
logger
.
debug_once
(
"Resolve transformers arch %s"
,
str
(
architectures
))
elif
(
model_config
.
quantization
is
not
None
elif
(
model_config
.
quantization
is
not
None
and
model_config
.
quantization
not
in
mixtral_supported
and
model_config
.
quantization
not
in
mixtral_supported
and
"MixtralForCausalLM"
in
architectures
):
and
"MixtralForCausalLM"
in
architectures
):
...
@@ -248,12 +269,13 @@ def get_model_architecture(
...
@@ -248,12 +269,13 @@ def get_model_architecture(
model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
model_cls
,
arch
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
if
model_config
.
task
==
"embed"
:
if
model_config
.
task
==
"embed"
:
logger
.
debug_once
(
"Automatic conversion using `as_embedding_model`."
)
model_cls
=
as_embedding_model
(
model_cls
)
model_cls
=
as_embedding_model
(
model_cls
)
elif
model_config
.
task
==
"classify"
:
elif
model_config
.
task
==
"classify"
:
# Cannot automatically run as_seq_cls_model,
logger
.
debug_once
(
"Automatic conversion using `as_seq_cls_model`."
)
# otherwise it will cause a circular reference on is_cross_encoder_model
model_cls
=
as_seq_cls_model
(
model_cls
)
pass
elif
model_config
.
task
==
"reward"
:
elif
model_config
.
task
==
"reward"
:
logger
.
debug_once
(
"Automatic conversion using `as_reward_model`."
)
model_cls
=
as_reward_model
(
model_cls
)
model_cls
=
as_reward_model
(
model_cls
)
return
model_cls
,
arch
return
model_cls
,
arch
...
...
vllm/model_executor/models/adapters.py
View file @
ca4eb82b
...
@@ -331,13 +331,13 @@ def load_weights_using_from_2_way_softmax(
...
@@ -331,13 +331,13 @@ def load_weights_using_from_2_way_softmax(
false_id
=
tokenizer
.
convert_tokens_to_ids
(
tokens
[
0
])
false_id
=
tokenizer
.
convert_tokens_to_ids
(
tokens
[
0
])
true_id
=
tokenizer
.
convert_tokens_to_ids
(
tokens
[
1
])
true_id
=
tokenizer
.
convert_tokens_to_ids
(
tokens
[
1
])
weight
=
model
.
lm_head
.
weight
.
data
[[
true_id
]].
to
(
score_
weight
=
model
.
lm_head
.
weight
.
data
[[
true_id
]].
to
(
torch
.
float32
)
-
model
.
lm_head
.
weight
.
data
[[
false_id
]].
to
(
torch
.
float32
)
-
model
.
lm_head
.
weight
.
data
[[
false_id
]].
to
(
torch
.
float32
)
torch
.
float32
)
param
=
model
.
score
.
weight
param
=
model
.
score
.
weight
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
weight
)
weight_loader
(
param
,
score_
weight
)
del
model
.
lm_head
del
model
.
lm_head
loaded_weights
.
add
(
"score.weight"
)
loaded_weights
.
add
(
"score.weight"
)
...
@@ -350,6 +350,8 @@ def load_weights_no_post_processing(model,
...
@@ -350,6 +350,8 @@ def load_weights_no_post_processing(model,
torch
.
Tensor
]]):
torch
.
Tensor
]]):
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
)
ParallelLMHead
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
)
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
model_config
=
model
.
vllm_config
.
model_config
model_config
=
model
.
vllm_config
.
model_config
...
@@ -357,8 +359,6 @@ def load_weights_no_post_processing(model,
...
@@ -357,8 +359,6 @@ def load_weights_no_post_processing(model,
tokens
=
cast
(
list
[
int
],
tokens
)
tokens
=
cast
(
list
[
int
],
tokens
)
assert
len
(
tokens
)
>
0
assert
len
(
tokens
)
>
0
device
=
model
.
score
.
weight
.
device
if
model
.
config
.
tie_word_embeddings
:
if
model
.
config
.
tie_word_embeddings
:
model
.
lm_head
=
model
.
model
.
embed_tokens
model
.
lm_head
=
model
.
model
.
embed_tokens
else
:
else
:
...
@@ -376,8 +376,11 @@ def load_weights_no_post_processing(model,
...
@@ -376,8 +376,11 @@ def load_weights_no_post_processing(model,
trust_remote_code
=
model_config
.
trust_remote_code
)
trust_remote_code
=
model_config
.
trust_remote_code
)
token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
t
)
for
t
in
tokens
]
token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
t
)
for
t
in
tokens
]
score_weight
=
model
.
lm_head
.
weight
.
data
[
token_ids
].
to
(
device
)
score_weight
=
model
.
lm_head
.
weight
.
data
[
token_ids
]
model
.
score
.
weight
.
data
.
copy_
(
score_weight
)
param
=
model
.
score
.
weight
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
score_weight
)
del
model
.
lm_head
del
model
.
lm_head
loaded_weights
.
add
(
"score.weight"
)
loaded_weights
.
add
(
"score.weight"
)
...
...
vllm/model_executor/models/gemma.py
View file @
ca4eb82b
...
@@ -43,7 +43,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -43,7 +43,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.adapters
import
as_seq_cls_model
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
...
@@ -426,6 +425,3 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -426,6 +425,3 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
self
.
config
.
tie_word_embeddings
else
None
),
if
self
.
config
.
tie_word_embeddings
else
None
),
)
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
GemmaForSequenceClassification
=
as_seq_cls_model
(
GemmaForCausalLM
)
vllm/model_executor/models/llama.py
View file @
ca4eb82b
...
@@ -49,7 +49,6 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -49,7 +49,6 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.adapters
import
as_seq_cls_model
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
...
@@ -646,6 +645,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -646,6 +645,3 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name
=
name
.
replace
(
item
,
mapping
[
item
])
name
=
name
.
replace
(
item
,
mapping
[
item
])
return
name
,
loaded_weight
return
name
,
loaded_weight
LlamaForSequenceClassification
=
as_seq_cls_model
(
LlamaForCausalLM
)
vllm/model_executor/models/qwen2.py
View file @
ca4eb82b
...
@@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.adapters
import
as_seq_cls_model
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
...
@@ -496,6 +495,3 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -496,6 +495,3 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
self
.
config
.
tie_word_embeddings
else
None
),
if
self
.
config
.
tie_word_embeddings
else
None
),
)
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
Qwen2ForSequenceClassification
=
as_seq_cls_model
(
Qwen2ForCausalLM
)
vllm/model_executor/models/qwen3.py
View file @
ca4eb82b
...
@@ -44,7 +44,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...
@@ -44,7 +44,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.adapters
import
as_seq_cls_model
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.qwen2
import
Qwen2MLP
as
Qwen3MLP
from
.qwen2
import
Qwen2MLP
as
Qwen3MLP
from
.qwen2
import
Qwen2Model
from
.qwen2
import
Qwen2Model
...
@@ -320,6 +319,3 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -320,6 +319,3 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
self
.
config
.
tie_word_embeddings
else
None
),
if
self
.
config
.
tie_word_embeddings
else
None
),
)
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
Qwen3ForSequenceClassification
=
as_seq_cls_model
(
Qwen3ForCausalLM
)
vllm/model_executor/models/registry.py
View file @
ca4eb82b
...
@@ -12,7 +12,7 @@ import sys
...
@@ -12,7 +12,7 @@ import sys
import
tempfile
import
tempfile
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
from
collections.abc
import
Set
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
asdict
,
dataclass
,
field
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
...
@@ -181,10 +181,6 @@ _CROSS_ENCODER_MODELS = {
...
@@ -181,10 +181,6 @@ _CROSS_ENCODER_MODELS = {
"ModernBertForSequenceClassification"
:
(
"modernbert"
,
"ModernBertForSequenceClassification"
:
(
"modernbert"
,
"ModernBertForSequenceClassification"
),
"ModernBertForSequenceClassification"
),
# [Auto-converted (see adapters.py)]
# [Auto-converted (see adapters.py)]
"GemmaForSequenceClassification"
:
(
"gemma"
,
"GemmaForSequenceClassification"
),
# noqa: E501
"Qwen2ForSequenceClassification"
:
(
"qwen2"
,
"Qwen2ForSequenceClassification"
),
# noqa: E501
"Qwen3ForSequenceClassification"
:
(
"qwen3"
,
"Qwen3ForSequenceClassification"
),
# noqa: E501
"LlamaForSequenceClassification"
:
(
"llama"
,
"LlamaForSequenceClassification"
),
# noqa: E501
"JinaVLForRanking"
:
(
"jina_vl"
,
"JinaVLForSequenceClassification"
),
# noqa: E501,
"JinaVLForRanking"
:
(
"jina_vl"
,
"JinaVLForSequenceClassification"
),
# noqa: E501,
}
}
...
@@ -462,10 +458,26 @@ class _ModelRegistry:
...
@@ -462,10 +458,26 @@ class _ModelRegistry:
return
_try_load_model_cls
(
model_arch
,
self
.
models
[
model_arch
])
return
_try_load_model_cls
(
model_arch
,
self
.
models
[
model_arch
])
def
_try_inspect_model_cls
(
self
,
model_arch
:
str
)
->
Optional
[
_ModelInfo
]:
def
_try_inspect_model_cls
(
self
,
model_arch
:
str
)
->
Optional
[
_ModelInfo
]:
if
model_arch
not
in
self
.
models
:
if
model_arch
in
self
.
models
:
return
None
return
_try_inspect_model_cls
(
model_arch
,
self
.
models
[
model_arch
])
if
model_arch
.
endswith
(
"ForSequenceClassification"
):
causal_lm_arch
=
model_arch
.
replace
(
"ForSequenceClassification"
,
"ForCausalLM"
)
if
causal_lm_arch
not
in
self
.
models
:
return
None
info
=
_try_inspect_model_cls
(
causal_lm_arch
,
self
.
models
[
causal_lm_arch
])
return
_try_inspect_model_cls
(
model_arch
,
self
.
models
[
model_arch
])
info
=
_ModelInfo
(
**
dict
(
asdict
(
info
),
**
{
"architecture"
:
model_arch
,
"supports_cross_encoding"
:
True
}))
return
info
return
None
def
_normalize_archs
(
def
_normalize_archs
(
self
,
self
,
...
@@ -480,6 +492,15 @@ class _ModelRegistry:
...
@@ -480,6 +492,15 @@ class _ModelRegistry:
normalized_arch
=
list
(
normalized_arch
=
list
(
filter
(
lambda
model
:
model
in
self
.
models
,
architectures
))
filter
(
lambda
model
:
model
in
self
.
models
,
architectures
))
# try automatic conversion in adapters.py
for
arch
in
architectures
:
if
not
arch
.
endswith
(
"ForSequenceClassification"
):
continue
causal_lm_arch
=
arch
.
replace
(
"ForSequenceClassification"
,
"ForCausalLM"
)
if
causal_lm_arch
in
self
.
models
:
normalized_arch
.
append
(
arch
)
# make sure Transformers backend is put at the last as a fallback
# make sure Transformers backend is put at the last as a fallback
if
len
(
normalized_arch
)
!=
len
(
architectures
):
if
len
(
normalized_arch
)
!=
len
(
architectures
):
normalized_arch
.
append
(
"TransformersForCausalLM"
)
normalized_arch
.
append
(
"TransformersForCausalLM"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment