Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
84cf78ac
Unverified
Commit
84cf78ac
authored
Aug 12, 2025
by
wang.yuqi
Committed by
GitHub
Aug 11, 2025
Browse files
[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
16fb668b
Changes
31
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
393 additions
and
230 deletions
+393
-230
tests/entrypoints/llm/test_classify.py
tests/entrypoints/llm/test_classify.py
+6
-0
tests/entrypoints/openai/test_classification.py
tests/entrypoints/openai/test_classification.py
+15
-0
tests/models/language/pooling/mteb_utils.py
tests/models/language/pooling/mteb_utils.py
+10
-2
tests/models/language/pooling/test_auto_prefix_cache_support.py
...models/language/pooling/test_auto_prefix_cache_support.py
+93
-0
tests/models/language/pooling/test_baai.py
tests/models/language/pooling/test_baai.py
+61
-56
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
+4
-4
tests/models/language/pooling/test_cross_encoder.py
tests/models/language/pooling/test_cross_encoder.py
+7
-5
tests/models/language/pooling/test_gte.py
tests/models/language/pooling/test_gte.py
+44
-43
tests/models/language/pooling/test_intfloat.py
tests/models/language/pooling/test_intfloat.py
+22
-22
tests/models/language/pooling/test_jina.py
tests/models/language/pooling/test_jina.py
+8
-6
tests/models/language/pooling/test_mxbai_rerank.py
tests/models/language/pooling/test_mxbai_rerank.py
+8
-7
tests/models/language/pooling/test_nomic.py
tests/models/language/pooling/test_nomic.py
+14
-13
tests/models/language/pooling/test_qwen3_reranker.py
tests/models/language/pooling/test_qwen3_reranker.py
+8
-7
tests/models/language/pooling/test_snowflake_arctic_embed.py
tests/models/language/pooling/test_snowflake_arctic_embed.py
+34
-33
tests/models/utils.py
tests/models/utils.py
+18
-0
tests/test_config.py
tests/test_config.py
+14
-0
vllm/config/__init__.py
vllm/config/__init__.py
+8
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+4
-5
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+4
-0
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+11
-27
No files found.
tests/entrypoints/llm/test_classify.py
View file @
84cf78ac
...
...
@@ -65,3 +65,9 @@ def test_pooling_params(llm: LLM):
assert
torch
.
allclose
(
softmax
(
wo_activation
),
w_activation
,
atol
=
1e-2
),
"w_activation should be close to activation(wo_activation)."
def test_encode_api(llm: LLM):
    """Check that ``llm.encode`` raises when no pooling task is supplied.

    The call passes only the prompts and ``use_tqdm=False``; the resulting
    ``ValueError`` message must match the ``pooling_task must be one of``
    pattern.
    """
    # `prompts` is a module-level name defined elsewhere in this test file.
    with pytest.raises(ValueError, match="pooling_task must be one of.+"):
        llm.encode(prompts, use_tqdm=False)
tests/entrypoints/openai/test_classification.py
View file @
84cf78ac
...
...
@@ -211,3 +211,18 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
assert
torch
.
allclose
(
F
.
softmax
(
wo_activation
,
dim
=-
1
),
w_activation
,
atol
=
1e-2
),
"w_activation should be close to activation(wo_activation)."
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_pooling(server: RemoteOpenAIServer, model_name: str):
    """Verify that the ``pooling`` endpoint answers with a bad-request error.

    The test is synchronous: it issues a blocking ``requests.post`` and never
    awaits anything, so it deliberately carries no ``pytest.mark.asyncio``
    marker (pytest-asyncio warns when that marker is put on a non-coroutine
    test function).

    Args:
        server: Running server fixture; only ``url_for`` is used here.
        model_name: Model identifier sent in the request payload.
    """
    # pooling api uses ALL pooling, which does not support chunked prefill.
    payload = {
        "model": model_name,
        "input": "test",
        "encoding_format": "float"
    }
    response = requests.post(server.url_for("pooling"), json=payload)

    # The error body, not the HTTP status code, is what is checked.
    assert response.json()["error"]["type"] == "BadRequestError"
tests/models/language/pooling/mteb_utils.py
View file @
84cf78ac
...
...
@@ -177,9 +177,12 @@ def mteb_test_embed_models(hf_runner,
max_model_len
=
None
,
**
vllm_extra_kwargs
)
as
vllm_model
:
model_config
=
vllm_model
.
llm
.
llm_engine
.
model_config
if
model_info
.
architecture
:
assert
(
model_info
.
architecture
in
vllm_model
.
llm
.
llm_engine
.
model_config
.
architectures
)
assert
model_info
.
architecture
in
model_config
.
architectures
assert
(
model_config
.
_model_info
.
default_pooling_type
==
model_info
.
default_pooling_type
)
vllm_main_score
=
run_mteb_embed_task
(
VllmMtebEncoder
(
vllm_model
),
MTEB_EMBED_TASKS
)
...
...
@@ -286,7 +289,12 @@ def mteb_test_rerank_models(hf_runner,
**
vllm_extra_kwargs
)
as
vllm_model
:
model_config
=
vllm_model
.
llm
.
llm_engine
.
model_config
if
model_info
.
architecture
:
assert
(
model_info
.
architecture
in
model_config
.
architectures
)
assert
model_config
.
hf_config
.
num_labels
==
1
assert
(
model_config
.
_model_info
.
default_pooling_type
==
model_info
.
default_pooling_type
)
vllm_main_score
=
run_mteb_rerank
(
vllm_mteb_encoder
(
vllm_model
),
tasks
=
MTEB_RERANK_TASKS
,
...
...
tests/models/language/pooling/test_auto_prefix_cache_support.py
0 → 100644
View file @
84cf78ac
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
transformers
import
AutoModelForSequenceClassification
from
tests.models.language.pooling.embed_utils
import
(
run_embedding_correctness_test
)
@pytest.mark.parametrize("model", ["jason9693/Qwen2.5-1.5B-apeach"])
@pytest.mark.parametrize("dtype", ["half"])
def test_classify_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Compare vLLM and Hugging Face classification with prefix caching on.

    The prompt list is doubled before being classified by both runners.
    NOTE(review): presumably the second copy is meant to hit the prefix
    cache -- confirm against the runner implementation.
    """
    doubled_prompts = example_prompts * 2

    # Relative tolerance: tight for "float", looser for every other dtype.
    rtol = 1e-3 if dtype == "float" else 1e-2

    with vllm_runner(model,
                     max_model_len=512,
                     dtype=dtype,
                     enable_prefix_caching=True) as vllm_model:
        # Prefix caching must still be enabled after config resolution.
        assert vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
        vllm_outputs = vllm_model.classify(doubled_prompts)

    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=AutoModelForSequenceClassification) as hf_model:
        hf_outputs = hf_model.classify(doubled_prompts)

    for reference, candidate in zip(hf_outputs, vllm_outputs):
        expected = torch.tensor(reference)
        actual = torch.tensor(candidate)

        assert torch.allclose(expected, actual, rtol=rtol)
@pytest.mark.parametrize("model", ["Qwen/Qwen3-Embedding-0.6B"])
@pytest.mark.parametrize("dtype", ["half"])
def test_embed_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
):
    """Check embedding correctness against HF with prefix caching enabled.

    Every prompt is stringified and stripped, and the resulting list is
    repeated twice before being embedded.
    NOTE(review): ``dtype`` is parametrized but is not forwarded to either
    runner in this test.
    """
    stripped = [str(prompt).strip() for prompt in example_prompts]
    repeated_prompts = stripped + stripped

    with vllm_runner(
            model,
            runner="pooling",
            max_model_len=None,
            enable_prefix_caching=True,
    ) as vllm_model:
        # Prefix caching must still be enabled after config resolution.
        assert vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
        vllm_outputs = vllm_model.embed(repeated_prompts)

    with hf_runner(model, is_sentence_transformer=True) as hf_model:
        run_embedding_correctness_test(hf_model, repeated_prompts,
                                       vllm_outputs)
@pytest.mark.parametrize(
    "model",
    [
        "intfloat/e5-small",
        "Alibaba-NLP/gte-Qwen2-1.5B-instruct",  # is_causal == False
        "papluca/xlm-roberta-base-language-detection",
    ])
@pytest.mark.parametrize("dtype", ["half"])
def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str,
                           dtype: str) -> None:
    """Check that requested prefix caching ends up disabled for these models.

    Prefix caching is requested explicitly, and the engine's resolved cache
    config is then expected to report it as off. ``hf_runner`` and
    ``example_prompts`` are accepted but not used by this test.
    """
    runner_kwargs = dict(max_model_len=512,
                         dtype=dtype,
                         enable_prefix_caching=True)

    with vllm_runner(model, **runner_kwargs) as vllm_model:
        resolved_cache_config = vllm_model.llm.llm_engine.cache_config
        assert not resolved_cache_config.enable_prefix_caching
tests/models/language/pooling/test_baai.py
View file @
84cf78ac
...
...
@@ -2,57 +2,59 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
...utils
import
EmbedModelInfo
,
RerankModelInfo
from
...utils
import
(
CLSPoolingEmbedModelInfo
,
CLSPoolingRerankModelInfo
,
EmbedModelInfo
,
LASTPoolingEmbedModelInfo
,
RerankModelInfo
)
from
.embed_utils
import
correctness_test_embed_models
from
.mteb_utils
import
mteb_test_embed_models
,
mteb_test_rerank_models
MODELS
=
[
########## BertModel
EmbedModelInfo
(
"BAAI/bge-base-en"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-base-en"
,
architecture
=
"BertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"BAAI/bge-base-zh"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-base-zh"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-small-en"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-small-en"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-small-zh"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-small-zh"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-large-en"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-large-en"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-large-zh"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-large-zh"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-large-zh-noinstruct"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-large-zh-noinstruct"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-base-en-v1.5"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-base-en-v1.5"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-base-zh-v1.5"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-base-zh-v1.5"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-small-en-v1.5"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-small-en-v1.5"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-small-zh-v1.5"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-small-zh-v1.5"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-large-en-v1.5"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-large-en-v1.5"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"BAAI/bge-large-zh-v1.5"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-large-zh-v1.5"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
########## XLMRobertaModel
EmbedModelInfo
(
"BAAI/bge-m3"
,
CLSPooling
EmbedModelInfo
(
"BAAI/bge-m3"
,
architecture
=
"XLMRobertaModel"
,
enable_test
=
True
),
########## Qwen2Model
EmbedModelInfo
(
"BAAI/bge-code-v1"
,
LASTPooling
EmbedModelInfo
(
"BAAI/bge-code-v1"
,
architecture
=
"Qwen2Model"
,
dtype
=
"float32"
,
enable_test
=
True
),
...
...
@@ -60,13 +62,16 @@ MODELS = [
RERANK_MODELS
=
[
########## XLMRobertaForSequenceClassification
RerankModelInfo
(
"BAAI/bge-reranker-base"
,
CLSPoolingRerankModelInfo
(
"BAAI/bge-reranker-base"
,
architecture
=
"XLMRobertaForSequenceClassification"
,
enable_test
=
True
),
RerankModelInfo
(
"BAAI/bge-reranker-large"
,
CLSPoolingRerankModelInfo
(
"BAAI/bge-reranker-large"
,
architecture
=
"XLMRobertaForSequenceClassification"
,
enable_test
=
False
),
RerankModelInfo
(
"BAAI/bge-reranker-v2-m3"
,
CLSPoolingRerankModelInfo
(
"BAAI/bge-reranker-v2-m3"
,
architecture
=
"XLMRobertaForSequenceClassification"
,
enable_test
=
False
)
]
...
...
tests/models/language/pooling/test_bge_reranker_v2_gemma.py
View file @
84cf78ac
...
...
@@ -8,11 +8,11 @@ import torch
from
tests.conftest
import
HfRunner
from
.
mteb_
utils
import
(
RerankModelInfo
,
VllmMtebEncoder
,
mteb_test_rerank_models
)
from
.
..
utils
import
LASTPooling
RerankModelInfo
,
RerankModelInfo
from
.mteb_utils
import
VllmMtebEncoder
,
mteb_test_rerank_models
RERANK_MODELS
=
[
RerankModelInfo
(
"BAAI/bge-reranker-v2-gemma"
,
LASTPooling
RerankModelInfo
(
"BAAI/bge-reranker-v2-gemma"
,
architecture
=
"GemmaForSequenceClassification"
),
]
...
...
tests/models/language/pooling/test_cross_encoder.py
View file @
84cf78ac
...
...
@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
.mteb_utils
import
RerankModelInfo
,
mteb_test_rerank_models
from
...utils
import
(
CLSPoolingRerankModelInfo
,
LASTPoolingRerankModelInfo
,
RerankModelInfo
)
from
.mteb_utils
import
mteb_test_rerank_models
RERANK_MODELS
=
[
RerankModelInfo
(
"cross-encoder/ms-marco-TinyBERT-L-2-v2"
,
CLSPooling
RerankModelInfo
(
"cross-encoder/ms-marco-TinyBERT-L-2-v2"
,
architecture
=
"BertForSequenceClassification"
),
RerankModelInfo
(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
,
LASTPooling
RerankModelInfo
(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
,
architecture
=
"Qwen3ForSequenceClassification"
)
]
...
...
tests/models/language/pooling/test_gte.py
View file @
84cf78ac
...
...
@@ -4,54 +4,55 @@ from typing import Any
import
pytest
from
...utils
import
check_transformers_version
from
.embed_utils
import
EmbedModelInfo
,
correctness_test_embed_models
from
...utils
import
(
CLSPoolingEmbedModelInfo
,
EmbedModelInfo
,
LASTPoolingEmbedModelInfo
,
check_transformers_version
)
from
.embed_utils
import
correctness_test_embed_models
from
.mteb_utils
import
mteb_test_embed_models
MODELS
=
[
########## BertModel
EmbedModelInfo
(
"thenlper/gte-large"
,
CLSPooling
EmbedModelInfo
(
"thenlper/gte-large"
,
architecture
=
"BertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"thenlper/gte-base"
,
CLSPooling
EmbedModelInfo
(
"thenlper/gte-base"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"thenlper/gte-small"
,
CLSPooling
EmbedModelInfo
(
"thenlper/gte-small"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"thenlper/gte-large-zh"
,
CLSPooling
EmbedModelInfo
(
"thenlper/gte-large-zh"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"thenlper/gte-base-zh"
,
CLSPooling
EmbedModelInfo
(
"thenlper/gte-base-zh"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"thenlper/gte-small-zh"
,
CLSPooling
EmbedModelInfo
(
"thenlper/gte-small-zh"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
########### NewModel
EmbedModelInfo
(
"Alibaba-NLP/gte-multilingual-base"
,
CLSPooling
EmbedModelInfo
(
"Alibaba-NLP/gte-multilingual-base"
,
architecture
=
"GteNewModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Alibaba-NLP/gte-base-en-v1.5"
,
CLSPooling
EmbedModelInfo
(
"Alibaba-NLP/gte-base-en-v1.5"
,
architecture
=
"GteNewModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Alibaba-NLP/gte-large-en-v1.5"
,
CLSPooling
EmbedModelInfo
(
"Alibaba-NLP/gte-large-en-v1.5"
,
architecture
=
"GteNewModel"
,
enable_test
=
True
),
########### Qwen2ForCausalLM
EmbedModelInfo
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
,
LASTPooling
EmbedModelInfo
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
,
architecture
=
"Qwen2ForCausalLM"
,
enable_test
=
True
),
########## ModernBertModel
EmbedModelInfo
(
"Alibaba-NLP/gte-modernbert-base"
,
CLSPooling
EmbedModelInfo
(
"Alibaba-NLP/gte-modernbert-base"
,
architecture
=
"ModernBertModel"
,
enable_test
=
True
),
########## Qwen3ForCausalLM
EmbedModelInfo
(
"Qwen/Qwen3-Embedding-0.6B"
,
LASTPooling
EmbedModelInfo
(
"Qwen/Qwen3-Embedding-0.6B"
,
architecture
=
"Qwen3ForCausalLM"
,
dtype
=
"float32"
,
enable_test
=
True
),
EmbedModelInfo
(
"Qwen/Qwen3-Embedding-4B"
,
LASTPooling
EmbedModelInfo
(
"Qwen/Qwen3-Embedding-4B"
,
architecture
=
"Qwen3ForCausalLM"
,
dtype
=
"float32"
,
enable_test
=
False
),
...
...
tests/models/language/pooling/test_intfloat.py
View file @
84cf78ac
...
...
@@ -2,32 +2,32 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
...utils
import
EmbedModelInfo
from
...utils
import
CLSPoolingEmbedModelInfo
,
EmbedModelInfo
from
.embed_utils
import
correctness_test_embed_models
from
.mteb_utils
import
mteb_test_embed_models
MODELS
=
[
########## BertModel
EmbedModelInfo
(
"intfloat/e5-small"
,
CLSPooling
EmbedModelInfo
(
"intfloat/e5-small"
,
architecture
=
"BertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"intfloat/e5-base"
,
CLSPooling
EmbedModelInfo
(
"intfloat/e5-base"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"intfloat/e5-large"
,
CLSPooling
EmbedModelInfo
(
"intfloat/e5-large"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"intfloat/multilingual-e5-small"
,
CLSPooling
EmbedModelInfo
(
"intfloat/multilingual-e5-small"
,
architecture
=
"BertModel"
,
enable_test
=
False
),
########## XLMRobertaModel
EmbedModelInfo
(
"intfloat/multilingual-e5-base"
,
CLSPooling
EmbedModelInfo
(
"intfloat/multilingual-e5-base"
,
architecture
=
"XLMRobertaModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"intfloat/multilingual-e5-large"
,
CLSPooling
EmbedModelInfo
(
"intfloat/multilingual-e5-large"
,
architecture
=
"XLMRobertaModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"intfloat/multilingual-e5-large-instruct"
,
CLSPooling
EmbedModelInfo
(
"intfloat/multilingual-e5-large-instruct"
,
architecture
=
"XLMRobertaModel"
,
enable_test
=
False
),
]
...
...
tests/models/language/pooling/test_jina.py
View file @
84cf78ac
...
...
@@ -6,19 +6,21 @@ import pytest
from
vllm
import
PoolingParams
from
...utils
import
EmbedModelInfo
,
RerankModelInfo
from
...utils
import
(
CLSPoolingEmbedModelInfo
,
CLSPoolingRerankModelInfo
,
EmbedModelInfo
,
RerankModelInfo
)
from
.embed_utils
import
(
check_embeddings_close
,
correctness_test_embed_models
,
matryoshka_fy
)
from
.mteb_utils
import
mteb_test_embed_models
,
mteb_test_rerank_models
EMBEDDING_MODELS
=
[
EmbedModelInfo
(
"jinaai/jina-embeddings-v3"
,
CLSPooling
EmbedModelInfo
(
"jinaai/jina-embeddings-v3"
,
architecture
=
"XLMRobertaModel"
,
is_matryoshka
=
True
)
]
RERANK_MODELS
=
[
RerankModelInfo
(
"jinaai/jina-reranker-v2-base-multilingual"
,
CLSPoolingRerankModelInfo
(
"jinaai/jina-reranker-v2-base-multilingual"
,
architecture
=
"XLMRobertaForSequenceClassification"
)
]
...
...
tests/models/language/pooling/test_mxbai_rerank.py
View file @
84cf78ac
...
...
@@ -7,13 +7,14 @@ import torch
from
tests.conftest
import
HfRunner
from
.mteb_utils
import
RerankModelInfo
,
mteb_test_rerank_models
from
...utils
import
LASTPoolingRerankModelInfo
,
RerankModelInfo
from
.mteb_utils
import
mteb_test_rerank_models
RERANK_MODELS
=
[
RerankModelInfo
(
"mixedbread-ai/mxbai-rerank-base-v2"
,
LASTPooling
RerankModelInfo
(
"mixedbread-ai/mxbai-rerank-base-v2"
,
architecture
=
"Qwen2ForSequenceClassification"
,
enable_test
=
True
),
RerankModelInfo
(
"mixedbread-ai/mxbai-rerank-large-v2"
,
LASTPooling
RerankModelInfo
(
"mixedbread-ai/mxbai-rerank-large-v2"
,
architecture
=
"Qwen2ForSequenceClassification"
,
enable_test
=
False
)
]
...
...
tests/models/language/pooling/test_nomic.py
View file @
84cf78ac
...
...
@@ -3,20 +3,21 @@
import
pytest
from
.embed_utils
import
EmbedModelInfo
,
correctness_test_embed_models
from
...utils
import
CLSPoolingEmbedModelInfo
,
EmbedModelInfo
from
.embed_utils
import
correctness_test_embed_models
from
.mteb_utils
import
mteb_test_embed_models
MODELS
=
[
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v1"
,
CLSPooling
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v1"
,
architecture
=
"NomicBertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v1.5"
,
CLSPooling
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v1.5"
,
architecture
=
"NomicBertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"nomic-ai/CodeRankEmbed"
,
CLSPooling
EmbedModelInfo
(
"nomic-ai/CodeRankEmbed"
,
architecture
=
"NomicBertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v2-moe"
,
CLSPooling
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v2-moe"
,
architecture
=
"NomicBertModel"
,
enable_test
=
True
)
]
...
...
tests/models/language/pooling/test_qwen3_reranker.py
View file @
84cf78ac
...
...
@@ -8,13 +8,14 @@ import torch
from
tests.conftest
import
HfRunner
from
tests.utils
import
multi_gpu_test
from
.mteb_utils
import
RerankModelInfo
,
mteb_test_rerank_models
from
...utils
import
LASTPoolingRerankModelInfo
,
RerankModelInfo
from
.mteb_utils
import
mteb_test_rerank_models
RERANK_MODELS
=
[
RerankModelInfo
(
"Qwen/Qwen3-Reranker-0.6B"
,
LASTPooling
RerankModelInfo
(
"Qwen/Qwen3-Reranker-0.6B"
,
architecture
=
"Qwen3ForSequenceClassification"
,
enable_test
=
True
),
RerankModelInfo
(
"Qwen/Qwen3-Reranker-4B"
,
LASTPooling
RerankModelInfo
(
"Qwen/Qwen3-Reranker-4B"
,
architecture
=
"Qwen3ForSequenceClassification"
,
enable_test
=
False
)
]
...
...
tests/models/language/pooling/test_snowflake_arctic_embed.py
View file @
84cf78ac
...
...
@@ -3,39 +3,40 @@
import
pytest
from
.embed_utils
import
EmbedModelInfo
,
correctness_test_embed_models
from
...utils
import
CLSPoolingEmbedModelInfo
,
EmbedModelInfo
from
.embed_utils
import
correctness_test_embed_models
from
.mteb_utils
import
mteb_test_embed_models
MODELS
=
[
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-xs"
,
CLSPooling
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-xs"
,
is_matryoshka
=
False
,
architecture
=
"BertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-s"
,
CLSPooling
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-s"
,
is_matryoshka
=
False
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m"
,
CLSPooling
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m"
,
is_matryoshka
=
False
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m-long"
,
CLSPooling
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m-long"
,
is_matryoshka
=
False
,
architecture
=
"NomicBertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-l"
,
CLSPooling
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-l"
,
is_matryoshka
=
False
,
architecture
=
"BertModel"
,
enable_test
=
False
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m-v1.5"
,
CLSPooling
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m-v1.5"
,
is_matryoshka
=
True
,
architecture
=
"BertModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-l-v2.0"
,
CLSPooling
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-l-v2.0"
,
is_matryoshka
=
True
,
architecture
=
"XLMRobertaModel"
,
enable_test
=
True
),
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m-v2.0"
,
CLSPooling
EmbedModelInfo
(
"Snowflake/snowflake-arctic-embed-m-v2.0"
,
is_matryoshka
=
True
,
architecture
=
"GteModel"
,
enable_test
=
True
),
...
...
tests/models/utils.py
View file @
84cf78ac
...
...
@@ -345,16 +345,34 @@ class EmbedModelInfo(NamedTuple):
matryoshka_dimensions
:
Optional
[
list
[
int
]]
=
None
architecture
:
str
=
""
dtype
:
str
=
"auto"
default_pooling_type
:
str
=
""
enable_test
:
bool
=
True
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
    """``EmbedModelInfo`` variant whose ``default_pooling_type`` is "CLS"."""

    # Declared as a class-level attribute on the subclass.
    # NOTE(review): on a NamedTuple subclass this presumably shadows the
    # inherited field accessor rather than changing the constructor default
    # -- confirm if tuple-style access to this field is ever needed.
    default_pooling_type: str = "CLS"
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
    """``EmbedModelInfo`` variant whose ``default_pooling_type`` is "LAST"."""

    # Declared as a class-level attribute on the subclass.
    # NOTE(review): on a NamedTuple subclass this presumably shadows the
    # inherited field accessor rather than changing the constructor default
    # -- confirm if tuple-style access to this field is ever needed.
    default_pooling_type: str = "LAST"
class RerankModelInfo(NamedTuple):
    """Immutable description of a rerank model used by the pooling tests."""

    # Model identifier; the only field without a default.
    name: str
    # Expected architecture name; empty string when unspecified.
    architecture: str = ""
    # dtype string handed to the runners; defaults to "auto".
    dtype: str = "auto"
    # Expected default pooling type; empty string when unspecified.
    default_pooling_type: str = ""
    # Flag carried alongside the model entry; defaults to True.
    enable_test: bool = True
class CLSPoolingRerankModelInfo(RerankModelInfo):
    """``RerankModelInfo`` variant whose ``default_pooling_type`` is "CLS"."""

    # Declared as a class-level attribute on the subclass.
    # NOTE(review): on a NamedTuple subclass this presumably shadows the
    # inherited field accessor rather than changing the constructor default
    # -- confirm if tuple-style access to this field is ever needed.
    default_pooling_type: str = "CLS"
class LASTPoolingRerankModelInfo(RerankModelInfo):
    """``RerankModelInfo`` variant whose ``default_pooling_type`` is "LAST"."""

    # Declared as a class-level attribute on the subclass.
    # NOTE(review): on a NamedTuple subclass this presumably shadows the
    # inherited field accessor rather than changing the constructor default
    # -- confirm if tuple-style access to this field is ever needed.
    default_pooling_type: str = "LAST"
def
dummy_hf_overrides
(
hf_config
:
PretrainedConfig
,
*
,
...
...
tests/test_config.py
View file @
84cf78ac
...
...
@@ -227,6 +227,20 @@ def test_get_pooling_config_from_args():
assert
asdict
(
pooling_config
)
==
asdict
(
override_pooler_config
)
@pytest.mark.parametrize(
    ("model_id", "default_pooling_type", "pooling_type"),
    [
        # LLM
        ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"),
        # BertModel
        ("intfloat/e5-small", "CLS", "MEAN"),
        # reward
        ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"),
        # step reward
        ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"),
    ])
def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
    """Check the model's default pooling type and the resolved pooler type.

    For each model id, the ``default_pooling_type`` reported by the model
    info and the ``pooling_type`` on the resolved pooler config must equal
    the expected values; the two can differ (see the ``intfloat/e5-small``
    case).
    """
    config = ModelConfig(model_id)

    reported_default = config._model_info.default_pooling_type
    assert reported_default == default_pooling_type

    resolved_type = config.pooler_config.pooling_type
    assert resolved_type == pooling_type
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Xformers backend is not supported on ROCm."
)
def
test_get_bert_tokenization_sentence_transformer_config
():
...
...
vllm/config/__init__.py
View file @
84cf78ac
...
...
@@ -871,6 +871,10 @@ class ModelConfig:
if
getattr
(
pooler_config
,
k
)
is
None
:
setattr
(
pooler_config
,
k
,
v
)
default_pooling_type
=
self
.
_model_info
.
default_pooling_type
if
pooler_config
.
pooling_type
is
None
:
pooler_config
.
pooling_type
=
default_pooling_type
return
pooler_config
return
None
...
...
@@ -3844,6 +3848,10 @@ class VllmConfig:
disable_chunked_prefill_reasons
.
append
(
"Only
\"
last
\"
pooling supports chunked "
"prefill and prefix caching; disabling both."
)
elif
not
getattr
(
self
.
model_config
.
hf_config
,
"is_causal"
,
True
):
disable_chunked_prefill_reasons
.
append
(
"Only models using causal attention supports chunked "
"prefill and prefix caching; disabling both."
)
if
disable_chunked_prefill_reasons
:
for
reason
in
disable_chunked_prefill_reasons
:
...
...
vllm/engine/arg_utils.py
View file @
84cf78ac
...
...
@@ -1600,11 +1600,10 @@ class EngineArgs:
else
:
pooling_type
=
model_config
.
pooler_config
.
pooling_type
# TODO: when encoder models are supported we'll have to
# check for causal attention here.
incremental_prefill_supported
=
(
pooling_type
is
not
None
and
pooling_type
.
lower
()
==
"last"
)
is_causal
=
getattr
(
model_config
.
hf_config
,
"is_causal"
,
True
)
incremental_prefill_supported
=
(
pooling_type
is
not
None
and
pooling_type
.
lower
()
==
"last"
and
is_causal
)
action
=
"Enabling"
if
\
incremental_prefill_supported
else
"Disabling"
...
...
vllm/entrypoints/llm.py
View file @
84cf78ac
...
...
@@ -1100,6 +1100,10 @@ class LLM:
"Try passing `--runner pooling` to use the model as a "
"pooling model."
)
if
pooling_task
not
in
self
.
supported_tasks
:
raise
ValueError
(
f
"pooling_task must be one of
{
self
.
supported_tasks
}
."
)
if
prompt_token_ids
is
not
None
:
parsed_prompts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
list
[
str
]]],
prompts
),
...
...
vllm/model_executor/layers/pooler.py
View file @
84cf78ac
...
...
@@ -44,15 +44,14 @@ class ResolvedPoolingConfig:
task
:
PoolingTask
@
classmethod
def
from_config
_with_defaults
(
def
from_config
(
cls
,
task
:
PoolingTask
,
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
)
->
"ResolvedPoolingConfig"
:
assert
pooler_config
.
pooling_type
is
not
None
return
cls
(
task
=
task
,
pooling_type
=
PoolingType
[
pooler_config
.
pooling_type
]
if
pooler_config
.
pooling_type
is
not
None
else
pooling_type
)
pooling_type
=
PoolingType
[
pooler_config
.
pooling_type
])
@
dataclass
(
frozen
=
True
)
...
...
@@ -68,32 +67,20 @@ class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM."""
@
staticmethod
def
for_encode
(
pooler_config
:
PoolerConfig
,
*
,
default_pooling_type
:
PoolingType
=
PoolingType
.
ALL
,
):
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
task
=
"encode"
,
pooler_config
=
pooler_config
,
pooling_type
=
default_pooling_type
,
)
if
resolved_config
.
pooling_type
==
PoolingType
.
STEP
:
def
for_encode
(
pooler_config
:
PoolerConfig
):
if
pooler_config
.
pooling_type
==
"STEP"
:
return
StepPooler
()
resolved_config
=
ResolvedPoolingConfig
(
task
=
"encode"
,
pooling_type
=
PoolingType
.
ALL
)
return
SimplePooler
.
from_config
(
resolved_config
)
@
staticmethod
def
for_embed
(
pooler_config
:
PoolerConfig
,
*
,
default_pooling_type
:
PoolingType
=
PoolingType
.
LAST
,
):
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
def
for_embed
(
pooler_config
:
PoolerConfig
):
resolved_config
=
ResolvedPoolingConfig
.
from_config
(
task
=
"embed"
,
pooler_config
=
pooler_config
,
pooling_type
=
default_pooling_type
,
)
return
SimplePooler
.
from_config
(
resolved_config
)
...
...
@@ -102,13 +89,10 @@ class Pooler(nn.Module, ABC):
def
for_classify
(
pooler_config
:
PoolerConfig
,
classifier
:
Optional
[
ClassifierFn
],
*
,
default_pooling_type
:
PoolingType
=
PoolingType
.
LAST
,
):
resolved_config
=
ResolvedPoolingConfig
.
from_config
_with_defaults
(
resolved_config
=
ResolvedPoolingConfig
.
from_config
(
task
=
"classify"
,
pooler_config
=
pooler_config
,
pooling_type
=
default_pooling_type
,
)
pooling
=
PoolingMethod
.
from_pooling_type
(
resolved_config
.
pooling_type
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment