Unverified Commit 84cf78ac authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)


Signed-off-by: default avatarwang.yuqi <noooop@126.com>
parent 16fb668b
......@@ -65,3 +65,9 @@ def test_pooling_params(llm: LLM):
assert torch.allclose(
softmax(wo_activation), w_activation, atol=1e-2
), "w_activation should be close to activation(wo_activation)."
def test_encode_api(llm: LLM):
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, use_tqdm=False)
......@@ -211,3 +211,18 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
assert torch.allclose(
F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2
), "w_activation should be close to activation(wo_activation)."
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_pooling(server: RemoteOpenAIServer, model_name: str):
# pooling api uses ALL pooling, which does not support chunked prefill.
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": "test",
"encoding_format": "float"
},
)
assert response.json()["error"]["type"] == "BadRequestError"
......@@ -177,9 +177,12 @@ def mteb_test_embed_models(hf_runner,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
if model_info.architecture:
assert (model_info.architecture
in vllm_model.llm.llm_engine.model_config.architectures)
assert model_info.architecture in model_config.architectures
assert (model_config._model_info.default_pooling_type ==
model_info.default_pooling_type)
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS)
......@@ -286,7 +289,12 @@ def mteb_test_rerank_models(hf_runner,
**vllm_extra_kwargs) as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
if model_info.architecture:
assert (model_info.architecture in model_config.architectures)
assert model_config.hf_config.num_labels == 1
assert (model_config._model_info.default_pooling_type ==
model_info.default_pooling_type)
vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
tasks=MTEB_RERANK_TASKS,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModelForSequenceClassification
from tests.models.language.pooling.embed_utils import (
run_embedding_correctness_test)
@pytest.mark.parametrize(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
example_prompts = example_prompts * 2
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
enable_prefix_caching=True) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config
assert cache_config.enable_prefix_caching
vllm_outputs = vllm_model.classify(example_prompts)
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)
assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)
@pytest.mark.parametrize(
"model",
["Qwen/Qwen3-Embedding-0.6B"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_embed_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
):
example_prompts = [str(s).strip() for s in example_prompts] * 2
with vllm_runner(
model,
runner="pooling",
max_model_len=None,
enable_prefix_caching=True,
) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config
assert cache_config.enable_prefix_caching
vllm_outputs = vllm_model.embed(example_prompts)
with hf_runner(
model,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
@pytest.mark.parametrize(
"model",
[
"intfloat/e5-small",
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", # is_causal == False
"papluca/xlm-roberta-base-language-detection",
])
@pytest.mark.parametrize("dtype", ["half"])
def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str) -> None:
with vllm_runner(model,
max_model_len=512,
dtype=dtype,
enable_prefix_caching=True) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config
assert not cache_config.enable_prefix_caching
......@@ -2,57 +2,59 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from ...utils import EmbedModelInfo, RerankModelInfo
from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
EmbedModelInfo, LASTPoolingEmbedModelInfo,
RerankModelInfo)
from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
MODELS = [
########## BertModel
EmbedModelInfo("BAAI/bge-base-en",
CLSPoolingEmbedModelInfo("BAAI/bge-base-en",
architecture="BertModel",
enable_test=True),
EmbedModelInfo("BAAI/bge-base-zh",
CLSPoolingEmbedModelInfo("BAAI/bge-base-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-en",
CLSPoolingEmbedModelInfo("BAAI/bge-small-en",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-zh",
CLSPoolingEmbedModelInfo("BAAI/bge-small-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-en",
CLSPoolingEmbedModelInfo("BAAI/bge-large-en",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-zh",
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-zh-noinstruct",
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-base-en-v1.5",
CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-base-zh-v1.5",
CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-en-v1.5",
CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-zh-v1.5",
CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-en-v1.5",
CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-zh-v1.5",
CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5",
architecture="BertModel",
enable_test=False),
########## XLMRobertaModel
EmbedModelInfo("BAAI/bge-m3",
CLSPoolingEmbedModelInfo("BAAI/bge-m3",
architecture="XLMRobertaModel",
enable_test=True),
########## Qwen2Model
EmbedModelInfo("BAAI/bge-code-v1",
LASTPoolingEmbedModelInfo("BAAI/bge-code-v1",
architecture="Qwen2Model",
dtype="float32",
enable_test=True),
......@@ -60,13 +62,16 @@ MODELS = [
RERANK_MODELS = [
########## XLMRobertaForSequenceClassification
RerankModelInfo("BAAI/bge-reranker-base",
CLSPoolingRerankModelInfo(
"BAAI/bge-reranker-base",
architecture="XLMRobertaForSequenceClassification",
enable_test=True),
RerankModelInfo("BAAI/bge-reranker-large",
CLSPoolingRerankModelInfo(
"BAAI/bge-reranker-large",
architecture="XLMRobertaForSequenceClassification",
enable_test=False),
RerankModelInfo("BAAI/bge-reranker-v2-m3",
CLSPoolingRerankModelInfo(
"BAAI/bge-reranker-v2-m3",
architecture="XLMRobertaForSequenceClassification",
enable_test=False)
]
......
......@@ -8,11 +8,11 @@ import torch
from tests.conftest import HfRunner
from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
mteb_test_rerank_models)
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
RERANK_MODELS = [
RerankModelInfo("BAAI/bge-reranker-v2-gemma",
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification"),
]
......
......@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo,
RerankModelInfo)
from .mteb_utils import mteb_test_rerank_models
RERANK_MODELS = [
RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
architecture="BertForSequenceClassification"),
RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
architecture="Qwen3ForSequenceClassification")
]
......
......@@ -4,54 +4,55 @@ from typing import Any
import pytest
from ...utils import check_transformers_version
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
LASTPoolingEmbedModelInfo, check_transformers_version)
from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [
########## BertModel
EmbedModelInfo("thenlper/gte-large",
CLSPoolingEmbedModelInfo("thenlper/gte-large",
architecture="BertModel",
enable_test=True),
EmbedModelInfo("thenlper/gte-base",
CLSPoolingEmbedModelInfo("thenlper/gte-base",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("thenlper/gte-small",
CLSPoolingEmbedModelInfo("thenlper/gte-small",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("thenlper/gte-large-zh",
CLSPoolingEmbedModelInfo("thenlper/gte-large-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("thenlper/gte-base-zh",
CLSPoolingEmbedModelInfo("thenlper/gte-base-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("thenlper/gte-small-zh",
CLSPoolingEmbedModelInfo("thenlper/gte-small-zh",
architecture="BertModel",
enable_test=False),
########### NewModel
EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel",
enable_test=True),
EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel",
enable_test=True),
########### Qwen2ForCausalLM
EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
architecture="Qwen2ForCausalLM",
enable_test=True),
########## ModernBertModel
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
architecture="ModernBertModel",
enable_test=True),
########## Qwen3ForCausalLM
EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
architecture="Qwen3ForCausalLM",
dtype="float32",
enable_test=True),
EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-4B",
architecture="Qwen3ForCausalLM",
dtype="float32",
enable_test=False),
......
......@@ -2,32 +2,32 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from ...utils import EmbedModelInfo
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [
########## BertModel
EmbedModelInfo("intfloat/e5-small",
CLSPoolingEmbedModelInfo("intfloat/e5-small",
architecture="BertModel",
enable_test=True),
EmbedModelInfo("intfloat/e5-base",
CLSPoolingEmbedModelInfo("intfloat/e5-base",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("intfloat/e5-large",
CLSPoolingEmbedModelInfo("intfloat/e5-large",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("intfloat/multilingual-e5-small",
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-small",
architecture="BertModel",
enable_test=False),
########## XLMRobertaModel
EmbedModelInfo("intfloat/multilingual-e5-base",
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-base",
architecture="XLMRobertaModel",
enable_test=True),
EmbedModelInfo("intfloat/multilingual-e5-large",
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large",
architecture="XLMRobertaModel",
enable_test=False),
EmbedModelInfo("intfloat/multilingual-e5-large-instruct",
CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large-instruct",
architecture="XLMRobertaModel",
enable_test=False),
]
......
......@@ -6,19 +6,21 @@ import pytest
from vllm import PoolingParams
from ...utils import EmbedModelInfo, RerankModelInfo
from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
EmbedModelInfo, RerankModelInfo)
from .embed_utils import (check_embeddings_close,
correctness_test_embed_models, matryoshka_fy)
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
EMBEDDING_MODELS = [
EmbedModelInfo("jinaai/jina-embeddings-v3",
CLSPoolingEmbedModelInfo("jinaai/jina-embeddings-v3",
architecture="XLMRobertaModel",
is_matryoshka=True)
]
RERANK_MODELS = [
RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual",
CLSPoolingRerankModelInfo(
"jinaai/jina-reranker-v2-base-multilingual",
architecture="XLMRobertaForSequenceClassification")
]
......
......@@ -7,13 +7,14 @@ import torch
from tests.conftest import HfRunner
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models
RERANK_MODELS = [
RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification",
enable_test=True),
RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
architecture="Qwen2ForSequenceClassification",
enable_test=False)
]
......
......@@ -3,20 +3,21 @@
import pytest
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [
EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1",
architecture="NomicBertModel",
enable_test=True),
EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
architecture="NomicBertModel",
enable_test=False),
EmbedModelInfo("nomic-ai/CodeRankEmbed",
CLSPoolingEmbedModelInfo("nomic-ai/CodeRankEmbed",
architecture="NomicBertModel",
enable_test=False),
EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
architecture="NomicBertModel",
enable_test=True)
]
......
......@@ -8,13 +8,14 @@ import torch
from tests.conftest import HfRunner
from tests.utils import multi_gpu_test
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models
RERANK_MODELS = [
RerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
architecture="Qwen3ForSequenceClassification",
enable_test=True),
RerankModelInfo("Qwen/Qwen3-Reranker-4B",
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
architecture="Qwen3ForSequenceClassification",
enable_test=False)
]
......
......@@ -3,39 +3,40 @@
import pytest
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
from .embed_utils import correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
is_matryoshka=False,
architecture="BertModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
is_matryoshka=False,
architecture="BertModel",
enable_test=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
is_matryoshka=False,
architecture="BertModel",
enable_test=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
is_matryoshka=False,
architecture="NomicBertModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
is_matryoshka=False,
architecture="BertModel",
enable_test=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
is_matryoshka=True,
architecture="BertModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
is_matryoshka=True,
architecture="XLMRobertaModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
is_matryoshka=True,
architecture="GteModel",
enable_test=True),
......
......@@ -345,16 +345,34 @@ class EmbedModelInfo(NamedTuple):
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = ""
dtype: str = "auto"
default_pooling_type: str = ""
enable_test: bool = True
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "CLS"
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "LAST"
class RerankModelInfo(NamedTuple):
name: str
architecture: str = ""
dtype: str = "auto"
default_pooling_type: str = ""
enable_test: bool = True
class CLSPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "CLS"
class LASTPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "LAST"
def dummy_hf_overrides(
hf_config: PretrainedConfig,
*,
......
......@@ -227,6 +227,20 @@ def test_get_pooling_config_from_args():
assert asdict(pooling_config) == asdict(override_pooler_config)
@pytest.mark.parametrize(
("model_id", "default_pooling_type", "pooling_type"),
[
("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
("intfloat/e5-small", "CLS", "MEAN"), # BertModel
("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP") # step reward
])
def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id)
assert model_config._model_info.default_pooling_type == default_pooling_type
assert model_config.pooler_config.pooling_type == pooling_type
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
......
......@@ -871,6 +871,10 @@ class ModelConfig:
if getattr(pooler_config, k) is None:
setattr(pooler_config, k, v)
default_pooling_type = self._model_info.default_pooling_type
if pooler_config.pooling_type is None:
pooler_config.pooling_type = default_pooling_type
return pooler_config
return None
......@@ -3844,6 +3848,10 @@ class VllmConfig:
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
elif not getattr(self.model_config.hf_config, "is_causal", True):
disable_chunked_prefill_reasons.append(
"Only models using causal attention supports chunked "
"prefill and prefix caching; disabling both.")
if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:
......
......@@ -1600,11 +1600,10 @@ class EngineArgs:
else:
pooling_type = model_config.pooler_config.pooling_type
# TODO: when encoder models are supported we'll have to
# check for causal attention here.
incremental_prefill_supported = (pooling_type is not None and
pooling_type.lower() == "last")
is_causal = getattr(model_config.hf_config, "is_causal", True)
incremental_prefill_supported = (pooling_type is not None
and pooling_type.lower() == "last"
and is_causal)
action = "Enabling" if \
incremental_prefill_supported else "Disabling"
......
......@@ -1100,6 +1100,10 @@ class LLM:
"Try passing `--runner pooling` to use the model as a "
"pooling model.")
if pooling_task not in self.supported_tasks:
raise ValueError(
f"pooling_task must be one of {self.supported_tasks}.")
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, list[str]]], prompts),
......
......@@ -44,15 +44,14 @@ class ResolvedPoolingConfig:
task: PoolingTask
@classmethod
def from_config_with_defaults(
def from_config(
cls,
task: PoolingTask,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
) -> "ResolvedPoolingConfig":
assert pooler_config.pooling_type is not None
return cls(task=task,
pooling_type=PoolingType[pooler_config.pooling_type]
if pooler_config.pooling_type is not None else pooling_type)
pooling_type=PoolingType[pooler_config.pooling_type])
@dataclass(frozen=True)
......@@ -68,32 +67,20 @@ class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM."""
@staticmethod
def for_encode(
pooler_config: PoolerConfig,
*,
default_pooling_type: PoolingType = PoolingType.ALL,
):
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
task="encode",
pooler_config=pooler_config,
pooling_type=default_pooling_type,
)
if resolved_config.pooling_type == PoolingType.STEP:
def for_encode(pooler_config: PoolerConfig):
if pooler_config.pooling_type == "STEP":
return StepPooler()
resolved_config = ResolvedPoolingConfig(task="encode",
pooling_type=PoolingType.ALL)
return SimplePooler.from_config(resolved_config)
@staticmethod
def for_embed(
pooler_config: PoolerConfig,
*,
default_pooling_type: PoolingType = PoolingType.LAST,
):
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
def for_embed(pooler_config: PoolerConfig):
resolved_config = ResolvedPoolingConfig.from_config(
task="embed",
pooler_config=pooler_config,
pooling_type=default_pooling_type,
)
return SimplePooler.from_config(resolved_config)
......@@ -102,13 +89,10 @@ class Pooler(nn.Module, ABC):
def for_classify(
pooler_config: PoolerConfig,
classifier: Optional[ClassifierFn],
*,
default_pooling_type: PoolingType = PoolingType.LAST,
):
resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
resolved_config = ResolvedPoolingConfig.from_config(
task="classify",
pooler_config=pooler_config,
pooling_type=default_pooling_type,
)
pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment