Commit 5eb36575 authored by khluu's avatar khluu
Browse files

Revert "[Frontend] Remove frontend pooling multi task support. (#37861)"

This reverts commit d2e2e856.
parent 4d51588e
......@@ -292,10 +292,10 @@ Pooling models now support token-wise task.
### Score task
`score` task have has been removed in v0.21, use `classify` instead. Only when a classification model outputs num_labels
equal to 1 can it be used as a scoring model and have its scoring API enabled.
`score` task is deprecated and will be removed in v0.20. Please use `classify` instead. Only when a
classification model outputs num_labels equal to 1 can it be used as a scoring model and have its scoring API enabled.
### Pooling multitask support
Pooling multitask support has been removed in v0.21. When the default pooling task is not what you want,
Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task is not what you want,
you need to manually specify it via `PoolerConfig(task=<task>)` offline or `--pooler-config.task <task>` online.
......@@ -4,74 +4,68 @@
import torch
from vllm import LLM
from vllm.config import PoolerConfig
from vllm.inputs import TextPrompt
from vllm.multimodal.utils import fetch_image
# Initialize model
model = LLM(
model="jinaai/jina-embeddings-v4-vllm-text-matching",
runner="pooling",
max_model_len=1024,
gpu_memory_utilization=0.8,
)
def main():
# Initialize model
model = LLM(
model="jinaai/jina-embeddings-v4-vllm-text-matching",
pooler_config=PoolerConfig(task="token_embed"),
runner="pooling",
max_model_len=1024,
gpu_memory_utilization=0.8,
)
# Create text prompts
text1 = "Ein wunderschöner Sonnenuntergang am Strand"
text1_prompt = TextPrompt(prompt=f"Query: {text1}")
# Create text prompts
text1 = "Ein wunderschöner Sonnenuntergang am Strand"
text1_prompt = TextPrompt(prompt=f"Query: {text1}")
text2 = "浜辺に沈む美しい夕日"
text2_prompt = TextPrompt(prompt=f"Query: {text2}")
text2 = "浜辺に沈む美しい夕日"
text2_prompt = TextPrompt(prompt=f"Query: {text2}")
# Create image prompt
image = fetch_image(
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501
)
image_prompt = TextPrompt(
prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", # noqa: E501
multi_modal_data={"image": image},
)
# Create image prompt
image = fetch_image(
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501
)
image_prompt = TextPrompt(
prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", # noqa: E501
multi_modal_data={"image": image},
)
# Encode all prompts
prompts = [text1_prompt, text2_prompt, image_prompt]
outputs = model.encode(prompts, pooling_task="token_embed")
# Encode all prompts
prompts = [text1_prompt, text2_prompt, image_prompt]
outputs = model.encode(prompts, pooling_task="token_embed")
def get_embeddings(outputs):
VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653
def get_embeddings(outputs):
VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653
embeddings = []
for output in outputs:
if VISION_START_TOKEN_ID in output.prompt_token_ids:
# Gather only vision tokens
img_start_pos = torch.where(
torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
)[0][0]
img_end_pos = torch.where(
torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
)[0][0]
embeddings_tensor = output.outputs.data.detach().clone()[
img_start_pos : img_end_pos + 1
]
else:
# Use all tokens for text-only prompts
embeddings_tensor = output.outputs.data.detach().clone()
embeddings = []
for output in outputs:
if VISION_START_TOKEN_ID in output.prompt_token_ids:
# Gather only vision tokens
img_start_pos = torch.where(
torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
)[0][0]
img_end_pos = torch.where(
torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
)[0][0]
embeddings_tensor = output.outputs.data.detach().clone()[
img_start_pos : img_end_pos + 1
]
else:
# Use all tokens for text-only prompts
embeddings_tensor = output.outputs.data.detach().clone()
# Pool and normalize embeddings
pooled_output = (
embeddings_tensor.sum(dim=0, dtype=torch.float32)
/ embeddings_tensor.shape[0]
)
embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
return embeddings
# Pool and normalize embeddings
pooled_output = (
embeddings_tensor.sum(dim=0, dtype=torch.float32)
/ embeddings_tensor.shape[0]
)
embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
return embeddings
embeddings = get_embeddings(outputs)
for embedding in embeddings:
print(embedding.shape)
embeddings = get_embeddings(outputs)
if __name__ == "__main__":
main()
for embedding in embeddings:
print(embedding.shape)
......@@ -4,7 +4,6 @@
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.config import PoolerConfig
from vllm.utils.argparse_utils import FlexibleArgumentParser
......@@ -14,7 +13,6 @@ def parse_args():
# Set example specific arguments
parser.set_defaults(
model="BAAI/bge-m3",
pooler_config=PoolerConfig(task="token_embed"),
runner="pooling",
enforce_eager=True,
)
......@@ -34,6 +32,15 @@ def main(args: Namespace):
# You should pass runner="pooling" for embedding models
llm = LLM(**vars(args))
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = llm.embed(prompts)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
print(len(embeds))
# Generate embedding for each token. The output is a list of PoolingRequestOutput.
outputs = llm.encode(prompts, pooling_task="token_embed")
......@@ -43,20 +50,6 @@ def main(args: Namespace):
multi_vector = output.outputs.data
print(multi_vector.shape)
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
]
# Generate scores.
outputs = llm.score(query, documents)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for document, output in zip(documents, outputs):
score = output.outputs.score
print(f"Pair: {[query, document]!r} \nScore: {score}")
print("-" * 60)
if __name__ == "__main__":
args = parse_args()
......
......@@ -7,11 +7,10 @@ Example online usage of Pooling API for multi vector retrieval.
Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.
vllm serve BAAI/bge-m3 --pooler-config.task token_embed
vllm serve BAAI/bge-m3
"""
import argparse
import pprint
import requests
import torch
......@@ -33,8 +32,7 @@ def parse_args():
def main(args):
pooling_url = f"http://{args.host}:{args.port}/pooling"
score_url = f"http://{args.host}:{args.port}/score"
api_url = f"http://{args.host}:{args.port}/pooling"
model_name = args.model
prompts = [
......@@ -45,23 +43,11 @@ def main(args):
]
prompt = {"model": model_name, "input": prompts}
pooling_response = post_http_request(prompt=prompt, api_url=pooling_url)
pooling_response = post_http_request(prompt=prompt, api_url=api_url)
for output in pooling_response.json()["data"]:
multi_vector = torch.tensor(output["data"])
print(multi_vector.shape)
queries = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
]
prompt = {"model": model_name, "queries": queries, "documents": documents}
score_response = post_http_request(prompt=prompt, api_url=score_url)
print("\nPrompt when queries is string and documents is a list:")
pprint.pprint(prompt)
print("\nScore Response:")
pprint.pprint(score_response.json())
if __name__ == "__main__":
args = parse_args()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import weakref
import pytest
import torch
from tests.models.utils import softmax
from vllm import LLM, ClassificationRequestOutput, PoolingParams
from vllm import LLM, ClassificationRequestOutput, PoolingParams, PoolingRequestOutput
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.tasks import PoolingTask
......@@ -65,6 +66,18 @@ def test_list_prompts(llm: LLM):
assert len(outputs[i].outputs.probs) == num_labels
@pytest.mark.skip_global_cleanup
def test_token_classify(llm: LLM, caplog_vllm):
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
outputs = llm.encode(prompt, pooling_task="token_classify", use_tqdm=False)
assert "deprecated" in caplog_vllm.text
assert len(outputs) == 1
assert isinstance(outputs[0], PoolingRequestOutput)
assert outputs[0].prompt_token_ids == prompt_token_ids
assert outputs[0].outputs.data.shape == (len(prompt_token_ids), num_labels)
@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):
def get_outputs(use_activation):
......@@ -97,12 +110,10 @@ def test_score_api(llm: LLM):
llm.score("ping", "pong", use_tqdm=False)
@pytest.mark.parametrize("task", ["embed", "token_embed", "token_classify", "plugin"])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
elif task == "token_classify":
err_msg = "Try switching the model's pooling_task via.+"
else:
err_msg = "Embedding API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
......
......@@ -436,7 +436,26 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "token_embed", "token_classify", "plugin"])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
task = "token_classify"
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": input_text,
"encoding_format": "float",
"task": task,
},
)
poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 8
assert len(poolings.data[0].data[0]) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
......@@ -450,11 +469,8 @@ async def test_pooling_not_supported(
},
)
assert response.json()["error"]["type"] == "BadRequestError"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
elif task == "token_classify":
err_msg = "Try switching the model's pooling_task via"
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import weakref
import pytest
......@@ -37,11 +38,11 @@ def llm():
seed=0,
attention_config=attention_config,
)
assert embedding_size == llm.model_config.embedding_size
yield weakref.proxy(llm)
del llm
cleanup_dist_env_and_memory()
......@@ -73,6 +74,16 @@ def test_list_prompts(llm: LLM):
assert len(outputs[i].outputs.embedding) == embedding_size
@pytest.mark.skip_global_cleanup
def test_token_embed(llm: LLM, caplog_vllm):
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
outputs = llm.encode(prompt, pooling_task="token_embed", use_tqdm=False)
assert "deprecated" in caplog_vllm.text
multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384)
@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):
def get_outputs(normalize):
......@@ -96,14 +107,10 @@ def test_pooling_params(llm: LLM):
)
@pytest.mark.parametrize(
"task", ["token_classify", "classify", "token_embed", "plugin"]
)
@pytest.mark.parametrize("task", ["token_classify", "classify", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
elif task == "token_embed":
err_msg = "Try switching the model's pooling_task via.+"
else:
err_msg = "Classification API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
......
......@@ -732,9 +732,28 @@ async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize(
"task", ["classify", "token_classify", "token_embed", "plugin"]
)
async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
task = "token_embed"
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": input_text,
"encoding_format": "float",
"task": task,
},
)
poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == len(input_tokens)
assert len(poolings.data[0].data[0]) == 384
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
......@@ -750,8 +769,6 @@ async def test_pooling_not_supported(
assert response.json()["error"]["type"] == "BadRequestError"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
elif task == "token_embed":
err_msg = "Try switching the model's pooling_task via"
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
......@@ -452,6 +452,25 @@ async def test_pooling_classify(server: RemoteOpenAIServer):
assert len(poolings.data[0].data) == 1
@pytest.mark.asyncio
async def test_pooling_token_classify(server: RemoteOpenAIServer):
response = requests.post(
server.url_for("pooling"),
json={
"model": MODEL_NAME,
"task": "token_classify",
"input": input_text,
"encoding_format": "float",
},
)
poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == len(input_tokens)
assert len(poolings.data[0].data[0]) == 1
@pytest.mark.asyncio
async def test_rerank_max_tokens_per_doc(
server: RemoteOpenAIServer,
......@@ -525,7 +544,7 @@ async def test_rerank_max_tokens_per_doc_validation(
@pytest.mark.asyncio
@pytest.mark.parametrize("task", ["embed", "token_embed", "token_classify", "plugin"])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
response = requests.post(
server.url_for("pooling"),
......@@ -539,8 +558,6 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
assert response.json()["error"]["type"] == "BadRequestError"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
elif task == "token_classify":
err_msg = "Try switching the model's pooling_task via"
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import weakref
import pytest
......@@ -59,19 +60,22 @@ def test_token_ids_prompts(llm: LLM):
@pytest.mark.skip_global_cleanup
def test_score_api(llm: LLM):
err_msg = "This model does not support the Scoring API."
err_msg = "Scoring API is only enabled for num_labels == 1."
with pytest.raises(ValueError, match=err_msg):
llm.score("ping", "pong", use_tqdm=False)
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
elif task == "classify":
err_msg = "Try switching the model's pooling_task via.+"
if task == "classify":
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
assert "deprecated" in caplog_vllm.text
else:
err_msg = "Embedding API is not supported by this model.+"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Embedding API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
......@@ -50,7 +50,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
......@@ -63,12 +63,9 @@ async def test_pooling_not_supported(
"task": task,
},
)
assert response.json()["error"]["type"] == "BadRequestError"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
elif task == "classify":
err_msg = "Try switching the model's pooling_task via"
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import weakref
import pytest
......@@ -63,12 +64,15 @@ def test_token_ids_prompts(llm: LLM):
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
elif task == "embed":
err_msg = "Try switching the model's pooling_task via.+"
if task == "embed":
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
assert "deprecated" in caplog_vllm.text
else:
err_msg = "Classification API is not supported by this model.+"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
else:
err_msg = "Classification API is not supported by this model.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompt, pooling_task=task, use_tqdm=False)
......@@ -73,7 +73,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
async def test_pooling_not_supported(
server: RemoteOpenAIServer, model_name: str, task: str
):
......@@ -86,12 +86,9 @@ async def test_pooling_not_supported(
"task": task,
},
)
assert response.json()["error"]["type"] == "BadRequestError"
if task == "plugin":
err_msg = "No IOProcessor plugin installed."
elif task == "embed":
err_msg = "Try switching the model's pooling_task via"
else:
err_msg = f"Unsupported task: {task!r}"
assert response.json()["error"]["message"].startswith(err_msg)
......@@ -6,7 +6,6 @@ from transformers import AutoModel
from tests.models.utils import check_embeddings_close
from vllm import TokensPrompt
from vllm.config import PoolerConfig
@pytest.mark.parametrize(
......@@ -22,7 +21,6 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
with vllm_runner(
model,
runner="pooling",
pooler_config=PoolerConfig(task="token_embed"),
max_model_len=128,
max_num_batched_tokens=chunk_size,
enforce_eager=True,
......
......@@ -3,6 +3,7 @@
import httpx
import openai
import pytest
import pytest_asyncio
import torch
from ....utils import RemoteOpenAIServer
......@@ -24,42 +25,29 @@ sentences_2 = [
similarity_reference = [[0.6259, 0.3474], [0.3309, 0.6734]]
lexical_score_reference = [0.19554901123046875, 0.0]
colbert_score_reference = [0.7797, 0.4620]
SUPPORTED_TASKS = ["embed", "token_embed", "token_classify"]
@pytest.fixture(scope="module", params=SUPPORTED_TASKS)
def pooling_task(request):
yield request.param
@pytest.fixture(scope="module")
def server(pooling_task):
def server():
args = [
"--max-model-len",
str(MAX_MODEL_LEN),
"--hf-overrides",
'{"architectures": ["BgeM3EmbeddingModel"]}',
"--pooler-config.task",
pooling_task,
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
async def test_bge_m3_api_server_embedding(server, pooling_task):
client = server.get_async_client()
if pooling_task != "embed":
with pytest.raises(openai.InternalServerError):
await run_client_embeddings(
client,
MODEL_NAME,
sentences_1,
)
return
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_bge_m3_api_server_embedding(client: openai.AsyncOpenAI):
embeddings_list_1 = await run_client_embeddings(
client,
MODEL_NAME,
......@@ -129,14 +117,7 @@ def compute_lexical_matching_score(
@pytest.mark.asyncio
async def test_bge_m3_api_server_sparse_embedding(server, pooling_task):
client = server.get_async_client()
if pooling_task != "token_classify":
with pytest.raises(openai.BadRequestError):
await sparse_embeddings(client, sentences_1)
return
async def test_bge_m3_api_server_sparse_embedding(client: openai.AsyncOpenAI):
embeddings_1 = await sparse_embeddings(client, sentences_1)
embeddings_2 = await sparse_embeddings(client, sentences_2)
......@@ -156,11 +137,9 @@ async def test_bge_m3_api_server_sparse_embedding(server, pooling_task):
@pytest.mark.asyncio
async def test_bge_m3_api_server_sparse_embedding_corner_case(server, pooling_task):
if pooling_task != "token_classify":
return
client = server.get_async_client()
async def test_bge_m3_api_server_sparse_embedding_corner_case(
client: openai.AsyncOpenAI,
):
embeddings = await sparse_embeddings(client, ["Hi"])
assert len(embeddings) == 1
assert 2673 in embeddings[0]
......@@ -176,18 +155,7 @@ def colbert_score(q_reps: torch.Tensor, p_reps: torch.Tensor) -> torch.Tensor:
@pytest.mark.asyncio
async def test_bge_m3_api_server_multi_vector(server, pooling_task):
client = server.get_async_client()
if pooling_task != "token_embed":
with pytest.raises(openai.BadRequestError):
await client.post(
"../pooling",
body={"model": MODEL_NAME, "input": sentences_1, "task": "token_embed"},
cast_to=httpx.Response,
)
return
async def test_bge_m3_api_server_multi_vector(client: openai.AsyncOpenAI):
result_1 = await client.post(
"../pooling",
body={"model": MODEL_NAME, "input": sentences_1, "task": "token_embed"},
......
......@@ -4,7 +4,6 @@ import pytest
import torch
from vllm import TokensPrompt
from vllm.config import PoolerConfig
@pytest.mark.parametrize(
......@@ -21,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
max_model_len=128,
enforce_eager=True,
runner="pooling",
pooler_config=PoolerConfig(task="token_embed"),
enable_prefix_caching=True,
) as vllm_model:
pooling_outputs = vllm_model.llm.encode(
......@@ -46,3 +44,14 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
assert len(output.prompt_token_ids) == n
assert len(output.outputs.data) == n
assert output.num_cached_tokens == 0
# skip_reading_prefix_cache can still write to cache
# to accelerate following requests
pooling_outputs = vllm_model.llm.encode(
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
pooling_task="embed",
)
for n, output in zip(n_prompt_tokens, pooling_outputs):
assert len(output.prompt_token_ids) == n
assert output.num_cached_tokens > 0
......@@ -5,7 +5,6 @@ import torch
from transformers import AutoModel
from tests.models.utils import check_embeddings_close
from vllm.config import PoolerConfig
@pytest.mark.parametrize(
......@@ -18,7 +17,6 @@ def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str, dtype
with vllm_runner(
model,
runner="pooling",
pooler_config=PoolerConfig(task="token_embed"),
max_model_len=None,
) as vllm_model:
vllm_outputs = vllm_model.token_embed(example_prompts)
......
......@@ -146,7 +146,7 @@ def test_multi_vector_retrieval_models_using_normalize(
model,
max_model_len=512,
dtype=dtype,
pooler_config=PoolerConfig(use_activation=False, task="token_embed"),
pooler_config=PoolerConfig(use_activation=False),
) as vllm_model:
wo_normalize = vllm_model.token_embed(example_prompts)
......@@ -154,7 +154,7 @@ def test_multi_vector_retrieval_models_using_normalize(
model,
max_model_len=512,
dtype=dtype,
pooler_config=PoolerConfig(use_activation=True, task="token_embed"),
pooler_config=PoolerConfig(use_activation=True),
) as vllm_model:
w_normalize = vllm_model.token_embed(example_prompts)
......
......@@ -79,7 +79,7 @@ from vllm.renderers.inputs.preprocess import (
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import SCORE_TYPE_MAP, PoolingTask
from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike
from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter
......@@ -1204,9 +1204,12 @@ class LLM:
f"Supported tasks: {self.supported_tasks}"
)
else:
raise ValueError(
f"Try switching the model's pooling_task "
f'via `PoolerConfig(task="{pooling_task}")`'
logger.warning_once(
"Pooling multitask support is deprecated and will "
"be removed in v0.20. When the default pooling task is "
"not what you want, you need to manually specify it "
'via PoolerConfig(task="%s"). ',
pooling_task,
)
if pooling_task == "plugin" and "plugin" not in self.pooling_io_processors:
......@@ -1409,7 +1412,7 @@ class LLM:
"pooling model."
)
score_type: str | None = SCORE_TYPE_MAP.get(self.pooling_task, None) # type: ignore[arg-type]
score_type = self.model_config.score_type
if (
score_type == "cross-encoder"
and getattr(self.model_config.hf_config, "num_labels", 0) != 1
......
......@@ -15,7 +15,10 @@ from starlette.datastructures import Headers
from vllm import PoolingParams, PoolingRequestOutput, envs
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.chat_utils import (
ChatTemplateConfig,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
......@@ -45,7 +48,9 @@ class PoolingServingBase(ABC):
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template_config: ChatTemplateConfig,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
):
......@@ -58,7 +63,11 @@ class PoolingServingBase(ABC):
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.log_error_stack = log_error_stack
self.chat_template_config = chat_template_config
self.chat_template_config = ChatTemplateConfig(
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
trust_request_chat_template=trust_request_chat_template,
)
# Shared thread pool executor for preprocessing and postprocessing.
self._executor: Executor = models.renderer._executor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment