Commit 1b6cb920 authored by wang.yuqi, committed by GitHub

[Deprecate] Deprecate pooling multi task support. (#37956)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 352b90c4
# Pooling Models
!!! note
    We currently support pooling models primarily for convenience. This is not guaranteed to provide any performance
    improvements over using Hugging Face Transformers or Sentence Transformers directly.
    We plan to optimize pooling models in vLLM. Please comment on <https://github.com/vllm-project/vllm/issues/21796> if you have any suggestions!
@@ -12,22 +13,38 @@ Natural Language Processing (NLP) can be primarily divided into the following tw
- Natural Language Understanding (NLU)
- Natural Language Generation (NLG)
The generative models supported by vLLM cover a variety of task types, such as the large language models (LLMs) we are
familiar with, multimodal models (VLMs) that handle multimodal inputs like images, videos, and audio, speech-to-text
transcription models, and real-time models that support streaming input. Their common feature is the ability to generate
text. Taking it a step further, vLLM-Omni supports the generation of multimodal content, including images, videos, and audio.
As the capabilities of generative models continue to improve, the boundaries of these models are also constantly expanding.
However, certain application scenarios still require specialized small language models to efficiently complete specific tasks.
These models typically have the following characteristics:
- They do not require content generation.
- They only need to perform very limited functions, without requiring strong generalization, creativity, or high intelligence.
- They demand extremely low latency and may operate on cost-constrained hardware.
- Text-only models typically have fewer than 1 billion parameters, while multimodal models generally have fewer than 10 billion parameters.
Although these models are relatively small in scale, they are still based on the Transformer architecture, similar or
even identical to the most advanced large language models today. Many recently released pooling models are also fine-tuned
from large language models, allowing them to benefit from the continuous improvements in large models. This architectural
similarity enables them to reuse much of vLLM’s infrastructure. If compatible, we would be happy to help them leverage
the latest features of vLLM as well.
### Sequence-wise Task and Token-wise Task
The key distinction between sequence-wise and token-wise tasks lies in their output granularity: a sequence-wise task
produces a single result for an entire input sequence, whereas a token-wise task yields a result for each individual token
within the sequence.
Many pooling models support both sequence-wise and token-wise tasks. When the default pooling task (e.g. a sequence-wise
task) is not what you want, you need to specify the desired one (e.g. a token-wise task) manually via
`PoolerConfig(task=<task>)` offline or `--pooler-config.task <task>` online.
Of course, we also have "plugin" tasks that allow users to customize input and output processors. For more information,
please refer to [IO Processor Plugins](../../design/io_processor_plugins.md).
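To make the granularity difference concrete, here is a minimal sketch in plain Python that does not involve vLLM. The helper names `token_wise` and `sequence_wise` are hypothetical, and mean pooling is an assumption chosen only for illustration; real models may pool differently (for example with the CLS or last token).

```python
# Illustrative only: toy per-token hidden states for a 3-token input,
# each token represented by a 2-dimensional vector.
hidden_states = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]


def token_wise(states: list[list[float]]) -> list[list[float]]:
    """Token-wise task: one result per token."""
    return [list(vec) for vec in states]


def sequence_wise(states: list[list[float]]) -> list[float]:
    """Sequence-wise task: a single result for the whole sequence.

    Mean pooling is an assumption made for this sketch.
    """
    dim = len(states[0])
    return [sum(vec[i] for vec in states) / len(states) for i in range(dim)]


per_token = token_wise(hidden_states)
per_sequence = sequence_wise(hidden_states)
print(len(per_token), len(per_token[0]))  # 3 2
print(per_sequence)  # [3.0, 4.0]
```

The token-wise result keeps one row per input token, while the sequence-wise result collapses them into a single vector.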
### Pooling Tasks
@@ -39,11 +56,13 @@ Of course, we also have "plugin" tasks that allow users to customize input and o
| `token_embed` | Token-wise | vector representations for each token |
!!! note
    Within classification tasks, there is a specialized subcategory: cross-encoder (aka reranker) models. These models
    are a subset of classification models that accept two prompts as input and output `num_labels` equal to 1.
### Score Types
Scoring models are designed to compute similarity scores between two input prompts. Three model types
(aka `score_type`) are supported: `cross-encoder`, `late-interaction`, and `bi-encoder`.
| Pooling Tasks | Granularity | Outputs | Score Types | Scoring Function |
|-----------------------|---------------|----------------------------------------------|--------------------|--------------------------|
@@ -250,11 +269,17 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
- `token_embed` is the same as `embed`, using normalization as the activation.
- `token_classify` is the same as `classify`, by default using softmax as the activation.
Pooling models now default support all pooling, you can use it without any settings. Pooling models now support token-wise task.
- To extract hidden states, prefer the `token_embed` task.
- For Named Entity Recognition (NER) and reward models, prefer the `token_classify` task.
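The two activations named in this section can be sketched in plain Python. This is a standalone illustration, not vLLM's implementation: the function names `l2_normalize` and `softmax` are hypothetical, and the choice of L2 normalization is inferred from the test in this commit that compares against `F.normalize(..., p=2, dim=-1)`.

```python
import math


def l2_normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit L2 norm (the activation described for `token_embed`)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def softmax(logits: list[float]) -> list[float]:
    """Turn logits into probabilities (the default described for `token_classify`)."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Applied token by token: each row is one token's output.
token_vectors = [[3.0, 4.0], [0.0, 2.0]]
token_logits = [[0.0, 0.0], [1.0, 3.0]]

embeddings = [l2_normalize(v) for v in token_vectors]
probabilities = [softmax(row) for row in token_logits]
print(embeddings[0])  # [0.6, 0.8]
print(probabilities[0])  # [0.5, 0.5]
```

In both cases the activation is applied independently to each token's row, which is what makes the task token-wise.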
### Score task
The `score` task is deprecated and will be removed in v0.20. Please use `classify` instead. A classification model
can be used as a scoring model, with its scoring API enabled, only when it outputs `num_labels` equal to 1.
### Pooling multitask support
Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task is not what you want,
you need to manually specify it via `PoolerConfig(task=<task>)` offline or `--pooler-config.task <task>` online.
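How the single pooling task is chosen can be sketched as a standalone function that mirrors the `get_pooling_task` method added to `ModelConfig` in this commit. The name `resolve_pooling_task` and its parameters are hypothetical and exist only for this illustration; the priority order is copied from the diff, while the real method reads `self.pooler_config` and `self.architectures` rather than taking arguments.

```python
from __future__ import annotations

# Priority order copied from `ModelConfig.get_pooling_task` in this commit.
PRIORITY = [
    "embed&token_classify",
    "embed",
    "classify",
    "token_embed",
    "token_classify",
    "plugin",
]


def resolve_pooling_task(
    explicit_task: str | None,
    supported_tasks: tuple[str, ...],
    architectures: tuple[str, ...] = (),
) -> str | None:
    """Standalone sketch of how the pooling task is selected."""
    # 1. An explicitly configured task wins, but it must be supported.
    if explicit_task is not None:
        if explicit_task in supported_tasks:
            return explicit_task
        raise RuntimeError(
            f"Unsupported task: {explicit_task!r} "
            f"Supported tasks: {supported_tasks}"
        )
    # 2. Token-classification architectures default to `token_classify`.
    if "token_classify" in supported_tasks and any(
        "ForTokenClassification" in arch for arch in architectures
    ):
        return "token_classify"
    # 3. Otherwise pick the first supported task in priority order.
    for task in PRIORITY:
        if task in supported_tasks:
            return task
    return None


default_task = resolve_pooling_task(None, ("token_embed", "embed"))
chosen_task = resolve_pooling_task("token_embed", ("token_embed", "embed"))
print(default_task)  # embed
print(chosen_task)  # token_embed
```

Without an explicit task, a model that supports both `embed` and `token_embed` resolves to `embed`, which is why token-wise use needs the task to be set by hand.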
@@ -13,6 +13,12 @@ The key distinction between (sequence) classification and token classification l
Many classification models support both (sequence) classification and token classification. For further details on (sequence) classification, please refer to [this page](classify.md).
!!! note
    Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task (classify) is not
    what you want, you need to manually specify it via `PoolerConfig(task="token_classify")` offline or
    `--pooler-config.task token_classify` online.
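For the online case, the request body can also name the task explicitly. The sketch below only builds the JSON payload with the standard library and sends nothing. The field names follow the pooling tests in this commit; `build_pooling_payload` is a hypothetical helper, and the model name and input text are simply the ones used in those tests. It assumes the server was started with `--pooler-config.task token_classify`.

```python
import json


def build_pooling_payload(model: str, text: str, task: str) -> dict:
    """Build a request body for the pooling endpoint (field names taken from the tests)."""
    return {
        "model": model,
        "task": task,
        "input": text,
        "encoding_format": "float",
    }


payload = build_pooling_payload(
    "jason9693/Qwen2.5-1.5B-apeach",
    "This product was excellent and exceeded my expectations",
    "token_classify",
)
print(json.dumps(payload, indent=2))
```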
## Typical Use Cases
### Named Entity Recognition (NER)
......
@@ -13,6 +13,12 @@ The difference between the (sequence) embedding task and the token embedding tas
Many embedding models support both (sequence) embedding and token embedding. For further details on (sequence) embedding, please refer to [this page](embed.md).
!!! note
    Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task (embed) is not
    what you want, you need to manually specify it via `PoolerConfig(task="token_embed")` offline or
    `--pooler-config.task token_embed` online.
## Typical Use Cases
### Multi-Vector Retrieval
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import weakref import weakref
import pytest import pytest
...@@ -67,8 +67,11 @@ def test_list_prompts(llm: LLM): ...@@ -67,8 +67,11 @@ def test_list_prompts(llm: LLM):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_token_classify(llm: LLM): def test_token_classify(llm: LLM, caplog_vllm):
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
outputs = llm.encode(prompt, pooling_task="token_classify", use_tqdm=False) outputs = llm.encode(prompt, pooling_task="token_classify", use_tqdm=False)
assert "deprecated" in caplog_vllm.text
assert len(outputs) == 1 assert len(outputs) == 1
assert isinstance(outputs[0], PoolingRequestOutput) assert isinstance(outputs[0], PoolingRequestOutput)
assert outputs[0].prompt_token_ids == prompt_token_ids assert outputs[0].prompt_token_ids == prompt_token_ids
...@@ -107,8 +110,8 @@ def test_score_api(llm: LLM): ...@@ -107,8 +110,8 @@ def test_score_api(llm: LLM):
llm.score("ping", "pong", use_tqdm=False) llm.score("ping", "pong", use_tqdm=False)
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"]) @pytest.mark.parametrize("task", ["embed", "token_embed"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask): def test_unsupported_tasks(llm: LLM, task: PoolingTask):
err_msg = f"Unsupported task: '{task}' Supported tasks.+" err_msg = "Embedding API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False) llm.encode(prompt, pooling_task=task, use_tqdm=False)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import weakref import weakref
import pytest import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm import LLM, PoolingParams from vllm import LLM, EmbeddingRequestOutput, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tasks import PoolingTask
MODEL_NAME = "intfloat/multilingual-e5-small" MODEL_NAME = "intfloat/multilingual-e5-small"
prompts = ["The chef prepared a delicious meal."] prompt = "The chef prepared a delicious meal."
prompt_token_ids = [0, 581, 21861, 133888, 10, 8, 150, 60744, 109911, 5, 2]
embedding_size = 384
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -44,16 +47,48 @@ def llm(): ...@@ -44,16 +47,48 @@ def llm():
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_token_embed(llm: LLM): def test_str_prompts(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False) outputs = llm.embed(prompt, use_tqdm=False)
assert len(outputs) == 1
assert isinstance(outputs[0], EmbeddingRequestOutput)
assert outputs[0].prompt_token_ids == prompt_token_ids
assert len(outputs[0].outputs.embedding) == embedding_size
@pytest.mark.skip_global_cleanup
def test_token_ids_prompts(llm: LLM):
outputs = llm.embed([prompt_token_ids], use_tqdm=False)
assert len(outputs) == 1
assert isinstance(outputs[0], EmbeddingRequestOutput)
assert outputs[0].prompt_token_ids == prompt_token_ids
assert len(outputs[0].outputs.embedding) == embedding_size
@pytest.mark.skip_global_cleanup
def test_list_prompts(llm: LLM):
outputs = llm.embed([prompt, prompt_token_ids], use_tqdm=False)
assert len(outputs) == 2
for i in range(len(outputs)):
assert isinstance(outputs[i], EmbeddingRequestOutput)
assert outputs[i].prompt_token_ids == prompt_token_ids
assert len(outputs[i].outputs.embedding) == embedding_size
@pytest.mark.skip_global_cleanup
def test_token_embed(llm: LLM, caplog_vllm):
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
outputs = llm.encode(prompt, pooling_task="token_embed", use_tqdm=False)
assert "deprecated" in caplog_vllm.text
multi_vector = outputs[0].outputs.data multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384) assert multi_vector.shape == (11, 384)
@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(normalize): def get_outputs(normalize):
outputs = llm.embed( outputs = llm.embed(
prompts, [prompt],
pooling_params=PoolingParams(use_activation=normalize), pooling_params=PoolingParams(use_activation=normalize),
use_tqdm=False, use_tqdm=False,
) )
...@@ -70,3 +105,10 @@ def test_pooling_params(llm: LLM): ...@@ -70,3 +105,10 @@ def test_pooling_params(llm: LLM):
assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), ( assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
"w_normal should be close to normal(wo_normal)." "w_normal should be close to normal(wo_normal)."
) )
@pytest.mark.parametrize("task", ["token_classify", "classify"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
err_msg = "Classification API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
...@@ -206,7 +206,12 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): ...@@ -206,7 +206,12 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str): async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={"model": model_name, "input": input_text, "encoding_format": "float"}, json={
"model": model_name,
"task": "token_classify",
"input": input_text,
"encoding_format": "float",
},
) )
poolings = PoolingResponse.model_validate(response.json()) poolings = PoolingResponse.model_validate(response.json())
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import weakref

import pytest

from vllm import LLM, PoolingRequestOutput
from vllm.config import PoolerConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.tasks import PoolingTask

MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
prompt = "The chef prepared a delicious meal."
prompt_token_ids = [785, 29706, 10030, 264, 17923, 15145, 13]
num_labels = 2


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(
        model=MODEL_NAME,
        pooler_config=PoolerConfig(task="token_classify"),
        max_num_batched_tokens=32768,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.75,
        enforce_eager=True,
        seed=0,
    )
    yield weakref.proxy(llm)
    del llm
    cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
def test_str_prompts(llm: LLM):
    outputs = llm.encode(prompt, pooling_task="token_classify", use_tqdm=False)
    assert len(outputs) == 1
    assert isinstance(outputs[0], PoolingRequestOutput)
    assert outputs[0].prompt_token_ids == prompt_token_ids
    assert outputs[0].outputs.data.shape == (len(prompt_token_ids), num_labels)


@pytest.mark.skip_global_cleanup
def test_token_ids_prompts(llm: LLM):
    outputs = llm.encode(
        [prompt_token_ids], pooling_task="token_classify", use_tqdm=False
    )
    assert len(outputs) == 1
    assert isinstance(outputs[0], PoolingRequestOutput)
    assert outputs[0].prompt_token_ids == prompt_token_ids
    assert outputs[0].outputs.data.shape == (len(prompt_token_ids), num_labels)


@pytest.mark.skip_global_cleanup
def test_score_api(llm: LLM):
    err_msg = "Score API is only enabled for num_labels == 1."
    with pytest.raises(ValueError, match=err_msg):
        llm.score("ping", "pong", use_tqdm=False)


@pytest.mark.parametrize("task", ["classify", "embed", "token_embed"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
    if task == "classify":
        # Requesting a task other than the configured one only logs a
        # deprecation warning for now.
        with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
            llm.encode(prompt, pooling_task=task, use_tqdm=False)
        assert "deprecated" in caplog_vllm.text
    else:
        err_msg = "Embedding API is not supported by this model.+"
        with pytest.raises(ValueError, match=err_msg):
            llm.encode(prompt, pooling_task=task, use_tqdm=False)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import requests
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
DTYPE = "float32" # Use float32 to avoid NaN issue
input_text = "This product was excellent and exceeded my expectations"
input_tokens = [1986, 1985, 572, 9073, 323, 33808, 847, 16665]
@pytest.fixture(scope="module")
def server():
args = [
"--enforce-eager",
"--max-model-len",
"512",
"--dtype",
DTYPE,
"--pooler-config.task",
"token_classify",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
task = "token_classify"
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": input_text,
"encoding_format": "float",
"task": task,
},
)
poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 8
assert len(poolings.data[0].data[0]) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": input_text,
"encoding_format": "float",
"task": task,
},
)
if task != "classify":
assert response.json()["error"]["type"] == "BadRequestError"
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import weakref
import pytest
from vllm import LLM, PoolingRequestOutput
from vllm.config import PoolerConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
from vllm.tasks import PoolingTask
MODEL_NAME = "intfloat/multilingual-e5-small"
prompt = "The chef prepared a delicious meal."
prompt_token_ids = [0, 581, 21861, 133888, 10, 8, 150, 60744, 109911, 5, 2]
embedding_size = 384
@pytest.fixture(scope="module")
def llm():
# ROCm: Use FLEX_ATTENTION backend as it's the only attention backend
# that supports encoder-only models on ROCm.
attention_config = None
if current_platform.is_rocm():
attention_config = {"backend": "FLEX_ATTENTION"}
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(
model=MODEL_NAME,
pooler_config=PoolerConfig(task="token_embed"),
max_num_batched_tokens=32768,
tensor_parallel_size=1,
gpu_memory_utilization=0.75,
enforce_eager=True,
seed=0,
attention_config=attention_config,
)
assert embedding_size == llm.model_config.embedding_size
yield weakref.proxy(llm)
del llm
cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_str_prompts(llm: LLM):
outputs = llm.encode(prompt, pooling_task="token_embed", use_tqdm=False)
assert len(outputs) == 1
assert isinstance(outputs[0], PoolingRequestOutput)
assert outputs[0].outputs.data.shape == (11, 384)
@pytest.mark.skip_global_cleanup
def test_token_ids_prompts(llm: LLM):
outputs = llm.encode([prompt_token_ids], pooling_task="token_embed", use_tqdm=False)
assert len(outputs) == 1
assert isinstance(outputs[0], PoolingRequestOutput)
assert outputs[0].outputs.data.shape == (11, 384)
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
if task == "embed":
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
assert "deprecated" in caplog_vllm.text
else:
err_msg = "Classification API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import requests
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
MODEL_NAME = "intfloat/multilingual-e5-small"
DTYPE = "bfloat16"
input_text = "The best thing about vLLM is that it supports many different models"
input_tokens = [
0,
581,
2965,
13580,
1672,
81,
23708,
594,
83,
450,
442,
8060,
7,
5941,
12921,
115774,
2,
]
@pytest.fixture(scope="module")
def server():
args = [
"--runner",
"pooling",
"--dtype",
DTYPE,
"--enforce-eager",
"--max-model-len",
"512",
"--pooler-config.task",
"token_embed",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
task = "token_embed"
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": input_text,
"encoding_format": "float",
"task": task,
},
)
poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == len(input_tokens)
assert len(poolings.data[0].data[0]) == 384
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": "test",
"encoding_format": "float",
"task": task,
},
)
if task != "embed":
assert response.json()["error"]["type"] == "BadRequestError"
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
...@@ -102,7 +102,7 @@ async def test_bge_m3_sparse_plugin_online( ...@@ -102,7 +102,7 @@ async def test_bge_m3_sparse_plugin_online(
"""Test BGE-M3 sparse plugin in online mode via API.""" """Test BGE-M3 sparse plugin in online mode via API."""
request_payload = { request_payload = {
"model": model_config["model_name"], "model": model_config["model_name"],
"task": "token_classify", "task": "plugin",
"data": {"input": model_config["test_input"], "return_tokens": return_tokens}, "data": {"input": model_config["test_input"], "return_tokens": return_tokens},
} }
...@@ -166,7 +166,7 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool): ...@@ -166,7 +166,7 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool):
default_torch_num_threads=1, default_torch_num_threads=1,
) as llm_runner: ) as llm_runner:
llm = llm_runner.get_llm() llm = llm_runner.get_llm()
pooler_output = llm.encode(prompt, pooling_task="token_classify") pooler_output = llm.encode(prompt, pooling_task="plugin")
outputs = pooler_output[0] outputs = pooler_output[0]
...@@ -213,7 +213,7 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner): ...@@ -213,7 +213,7 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner):
default_torch_num_threads=1, default_torch_num_threads=1,
) as llm_runner: ) as llm_runner:
llm = llm_runner.get_llm() llm = llm_runner.get_llm()
pooler_output = llm.encode(prompts, pooling_task="token_classify") pooler_output = llm.encode(prompts, pooling_task="plugin")
outputs = pooler_output[0] outputs = pooler_output[0]
......
...@@ -25,7 +25,7 @@ from vllm.config.scheduler import RunnerType ...@@ -25,7 +25,7 @@ from vllm.config.scheduler import RunnerType
from vllm.config.utils import config, getattr_iter from vllm.config.utils import config, getattr_iter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tasks import ScoreType from vllm.tasks import PoolingTask, ScoreType, SupportedTask
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
ConfigFormat, ConfigFormat,
get_config, get_config,
...@@ -1409,6 +1409,41 @@ class ModelConfig: # type: ignore[misc] ...@@ -1409,6 +1409,41 @@ class ModelConfig: # type: ignore[misc]
return diff_sampling_param return diff_sampling_param
    def get_pooling_task(
        self, supported_tasks: tuple[SupportedTask, ...]
    ) -> PoolingTask | None:
        if self.pooler_config is None:
            return None

        # An explicitly configured task takes precedence but must be supported.
        pooling_task = self.pooler_config.task
        if pooling_task is not None:
            if self.pooler_config.task in supported_tasks:
                return self.pooler_config.task
            else:
                raise RuntimeError(
                    f"Unsupported task: {pooling_task!r} "
                    f"Supported tasks: {supported_tasks}"
                )

        # Token-classification architectures default to `token_classify`.
        if "token_classify" in supported_tasks:
            for architecture in self.architectures:
                if "ForTokenClassification" in architecture:
                    return "token_classify"

        # Otherwise fall back to the first supported task in priority order.
        priority: list[PoolingTask] = [
            "embed&token_classify",
            "embed",
            "classify",
            "token_embed",
            "token_classify",
            "plugin",
        ]
        for task in priority:
            if task in supported_tasks:
                return task
        return None

@cached_property @cached_property
def is_encoder_decoder(self) -> bool: def is_encoder_decoder(self) -> bool:
"""Extract the HF encoder/decoder model flag.""" """Extract the HF encoder/decoder model flag."""
......
...@@ -5,6 +5,7 @@ from typing import Any, Literal, get_args ...@@ -5,6 +5,7 @@ from typing import Any, Literal, get_args
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tasks import PoolingTask
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -20,6 +21,11 @@ TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType) ...@@ -20,6 +21,11 @@ TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
class PoolerConfig: class PoolerConfig:
"""Controls the behavior of output pooling in pooling models.""" """Controls the behavior of output pooling in pooling models."""
task: PoolingTask | None = None
"""
The task used for pooling.
"""
pooling_type: SequencePoolingType | TokenPoolingType | None = None pooling_type: SequencePoolingType | TokenPoolingType | None = None
""" """
The pooling method used for pooling. The pooling method used for pooling.
......
...@@ -382,16 +382,19 @@ class LLM: ...@@ -382,16 +382,19 @@ class LLM:
self.llm_engine = LLMEngine.from_engine_args( self.llm_engine = LLMEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
) )
self.model_config = self.llm_engine.model_config
self.engine_class = type(self.llm_engine) self.engine_class = type(self.llm_engine)
self.request_counter = Counter() self.request_counter = Counter()
self.default_sampling_params: dict[str, Any] | None = None self.default_sampling_params: dict[str, Any] | None = None
supported_tasks = self.llm_engine.get_supported_tasks() supported_tasks = self.llm_engine.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
self.supported_tasks = supported_tasks self.supported_tasks = supported_tasks
self.pooling_task = self.model_config.get_pooling_task(supported_tasks)
if self.pooling_task is not None:
logger.info("Supported pooling task: %s", self.pooling_task)
self.model_config = self.llm_engine.model_config self.runner_type = self.model_config.runner_type
self.renderer = self.llm_engine.renderer self.renderer = self.llm_engine.renderer
self.chat_template = load_chat_template(chat_template) self.chat_template = load_chat_template(chat_template)
self.io_processor = self.llm_engine.io_processor self.io_processor = self.llm_engine.io_processor
...@@ -1072,31 +1075,7 @@ class LLM: ...@@ -1072,31 +1075,7 @@ class LLM:
pooled hidden states in the same order as the input prompts. pooled hidden states in the same order as the input prompts.
""" """
if pooling_task is None: self._verify_pooling_task(pooling_task)
raise ValueError(
"pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
model_config = self.model_config
runner_type = model_config.runner_type
if runner_type != "pooling":
raise ValueError(
"LLM.encode() is only supported for pooling models. "
"Try passing `--runner pooling` to use the model as a "
"pooling model."
)
if isinstance(prompts, dict) and "data" in prompts: if isinstance(prompts, dict) and "data" in prompts:
if self.io_processor is None: if self.io_processor is None:
...@@ -1206,6 +1185,65 @@ class LLM: ...@@ -1206,6 +1185,65 @@ class LLM:
) )
return outputs return outputs
def _verify_pooling_task(self, pooling_task: PoolingTask | None):
if self.runner_type != "pooling":
raise ValueError(
"LLM.encode() is only supported for pooling models. "
"Try passing `--runner pooling` to use the model as a "
"pooling model."
)
if pooling_task is None:
raise ValueError(
"pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
if (
pooling_task in ("embed", "token_embed")
and pooling_task not in self.supported_tasks
):
raise ValueError(
"Embedding API is not supported by this model. "
"Try converting the model using `--convert embed`."
)
if (
pooling_task in ("classify", "token_classify")
and pooling_task not in self.supported_tasks
):
raise ValueError(
"Classification API is not supported by this model. "
"Try converting the model using `--convert classify`."
)
# plugin task uses io_processor.parse_request to verify inputs
if pooling_task != "plugin" and pooling_task != self.pooling_task:
if pooling_task not in self.supported_tasks:
raise ValueError(
f"Unsupported task: {pooling_task!r} "
f"Supported tasks: {self.supported_tasks}"
)
else:
logger.warning_once(
"Pooling multitask support is deprecated and will "
"be removed in v0.20. When the default pooling task is "
"not what you want, you need to manually specify it "
'via PoolerConfig(task="%s"). ',
pooling_task,
)
def embed( def embed(
self, self,
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
...@@ -1239,11 +1277,6 @@ class LLM: ...@@ -1239,11 +1277,6 @@ class LLM:
A list of `EmbeddingRequestOutput` objects containing the A list of `EmbeddingRequestOutput` objects containing the
embedding vectors in the same order as the input prompts. embedding vectors in the same order as the input prompts.
""" """
if "embed" not in self.supported_tasks:
raise ValueError(
"Embedding API is not supported by this model. "
"Try converting the model using `--convert embed`."
)
items = self.encode( items = self.encode(
prompts, prompts,
...@@ -1289,11 +1322,6 @@ class LLM: ...@@ -1289,11 +1322,6 @@ class LLM:
A list of `ClassificationRequestOutput` objects containing the A list of `ClassificationRequestOutput` objects containing the
embedding vectors in the same order as the input prompts. embedding vectors in the same order as the input prompts.
""" """
if "classify" not in self.supported_tasks:
raise ValueError(
"Classification API is not supported by this model. "
"Try converting the model using `--convert classify`."
)
items = self.encode( items = self.encode(
prompts, prompts,
......
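The checks that this change gathers into `LLM._verify_pooling_task` can be read as one decision procedure. The following is a minimal standalone sketch of that logic, not the vLLM implementation. The function name `verify_pooling_task` and its parameters are hypothetical, tasks are plain strings, the error messages are shortened, and a boolean return value stands in for the `logger.warning_once` call.

```python
from typing import Optional, Tuple


def verify_pooling_task(
    pooling_task: Optional[str],
    *,
    runner_type: str,
    default_task: Optional[str],
    supported_tasks: Tuple[str, ...],
) -> bool:
    """Return True when the call relies on the deprecated multitask path.

    Hypothetical mirror of the checks shown in the diff above.
    """
    if runner_type != "pooling":
        raise ValueError("encode() is only supported for pooling models.")
    if pooling_task is None:
        raise ValueError("pooling_task is required.")

    # Task-family checks run first so the error can suggest a conversion.
    if pooling_task in ("embed", "token_embed") and pooling_task not in supported_tasks:
        raise ValueError("Embedding API is not supported by this model.")
    if (
        pooling_task in ("classify", "token_classify")
        and pooling_task not in supported_tasks
    ):
        raise ValueError("Classification API is not supported by this model.")

    # The "plugin" task is validated elsewhere (by the IO processor in the diff).
    if pooling_task == "plugin" or pooling_task == default_task:
        return False
    if pooling_task not in supported_tasks:
        raise ValueError(f"Unsupported task: {pooling_task!r}")
    # Supported, but not the model's default task: the deprecated case.
    return True
```

Under these assumptions, a model whose default task is `"embed"` and which also supports `"token_embed"` accepts both tasks, but only the second takes the deprecated path.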
...@@ -45,6 +45,12 @@ def register_pooling_api_routers( ...@@ -45,6 +45,12 @@ def register_pooling_api_routers(
supported_tasks: tuple["SupportedTask", ...], supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None, model_config: ModelConfig | None = None,
): ):
if model_config is None:
return
pooling_task = model_config.get_pooling_task(supported_tasks)
if pooling_task is not None:
from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
app.include_router(pooling_router) app.include_router(pooling_router)
...@@ -91,6 +97,7 @@ def init_pooling_state( ...@@ -91,6 +97,7 @@ def init_pooling_state(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
state.openai_serving_render, state.openai_serving_render,
supported_tasks=supported_tasks,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
......
...@@ -37,6 +37,7 @@ from vllm.inputs import ProcessorInputs ...@@ -37,6 +37,7 @@ from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs.preprocess import prompt_to_seq from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.tasks import SupportedTask
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
...@@ -49,6 +50,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -49,6 +50,7 @@ class OpenAIServingPooling(OpenAIServing):
engine_client: EngineClient, engine_client: EngineClient,
models: OpenAIServingModels, models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender, openai_serving_render: OpenAIServingRender,
supported_tasks: tuple[SupportedTask, ...],
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
...@@ -60,7 +62,8 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -60,7 +62,8 @@ class OpenAIServingPooling(OpenAIServing):
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
) )
self.supported_tasks = supported_tasks
self.pooling_task = self.model_config.get_pooling_task(supported_tasks)
self.openai_serving_render = openai_serving_render self.openai_serving_render = openai_serving_render
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
...@@ -86,9 +89,27 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -86,9 +89,27 @@ class OpenAIServingPooling(OpenAIServing):
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
if request.task is None:
request.task = self.pooling_task
if getattr(request, "dimensions", None) is not None: if getattr(request, "dimensions", None) is not None:
return self.create_error_response("dimensions is currently not supported") return self.create_error_response("dimensions is currently not supported")
# plugin task uses io_processor.parse_request to verify inputs
if request.task != "plugin" and request.task != self.pooling_task:
if request.task not in self.supported_tasks:
raise ValueError(
f"Unsupported task: {request.task!r} "
f"Supported tasks: {self.supported_tasks}"
)
else:
logger.warning_once(
"Pooling multitask support is deprecated and will be removed "
"in v0.20. When the default pooling task is not what you want, you "
'need to manually specify it via --pooler-config.task "%s". ',
request.task,
)
engine_prompts: Sequence[ProcessorInputs] engine_prompts: Sequence[ProcessorInputs]
if use_io_processor := isinstance(request, IOProcessorRequest): if use_io_processor := isinstance(request, IOProcessorRequest):
if self.io_processor is None: if self.io_processor is None:
......
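On the server side, the change first fills in the model's default task when the request omits one, and only then applies the deprecation check. The sketch below illustrates that ordering in isolation. `PoolingRequest` and `resolve_task` are hypothetical stand-ins rather than vLLM classes or functions, and the returned flag replaces the logged warning.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class PoolingRequest:
    """Hypothetical stand-in for the server's pooling request object."""

    task: Optional[str] = None


def resolve_task(
    request: PoolingRequest,
    default_task: Optional[str],
    supported_tasks: Tuple[str, ...],
) -> Tuple[Optional[str], bool]:
    """Return (task, deprecated) after applying the default task."""
    # A request without a task falls back to the model's default task.
    if request.task is None:
        request.task = default_task
    task = request.task

    # The default task and the "plugin" task need no further checks here.
    if task == "plugin" or task == default_task:
        return task, False
    if task not in supported_tasks:
        raise ValueError(
            f"Unsupported task: {task!r} Supported tasks: {supported_tasks}"
        )
    # A supported non-default task still works but is deprecated.
    return task, True
```

Because the default is applied before the comparison, a request that omits `task` never triggers the deprecation warning in this sketch. Only a request that explicitly names a supported, non-default task does.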