Unverified Commit 4ae77dfd authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Frontend][1/n] Make pooling entrypoints request schema consensus | CompletionRequest (#32395)


Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
parent 73f635a7
...@@ -559,7 +559,7 @@ Our Classification API directly supports Hugging Face sequence-classification mo ...@@ -559,7 +559,7 @@ Our Classification API directly supports Hugging Face sequence-classification mo
We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
Code example: [examples/pooling/classify/openai_classification_client.py](../../examples/pooling/classify/openai_classification_client.py) Code example: [examples/pooling/classify/classification_online.py](../../examples/pooling/classify/classification_online.py)
#### Example Requests #### Example Requests
......
...@@ -11,27 +11,26 @@ import pprint ...@@ -11,27 +11,26 @@ import pprint
import requests import requests
headers = {"accept": "application/json", "Content-Type": "application/json"}
def post_http_request(payload: dict, api_url: str) -> requests.Response:
headers = {"User-Agent": "Test Client"}
response = requests.post(api_url, headers=headers, json=payload)
return response
def parse_args(): def parse_args():
parse = argparse.ArgumentParser() parse = argparse.ArgumentParser()
parse.add_argument("--host", type=str, default="localhost") parse.add_argument("--host", type=str, default="localhost")
parse.add_argument("--port", type=int, default=8000) parse.add_argument("--port", type=int, default=8000)
parse.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
return parse.parse_args() return parse.parse_args()
def main(args): def main(args):
host = args.host base_url = f"http://{args.host}:{args.port}"
port = args.port models_url = base_url + "/v1/models"
model_name = args.model classify_url = base_url + "/classify"
tokenize_url = base_url + "/tokenize"
response = requests.get(models_url, headers=headers)
model = response.json()["data"][0]["id"]
api_url = f"http://{host}:{port}/classify" # /classify can accept str as input
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
...@@ -40,12 +39,27 @@ def main(args): ...@@ -40,12 +39,27 @@ def main(args):
] ]
payload = { payload = {
"model": model_name, "model": model,
"input": prompts, "input": prompts,
} }
response = requests.post(classify_url, headers=headers, json=payload)
pprint.pprint(response.json())
# /classify can accept token ids as input
token_ids = []
for prompt in prompts:
response = requests.post(
tokenize_url,
json={"model": model, "prompt": prompt},
)
token_ids.append(response.json()["tokens"])
classify_response = post_http_request(payload=payload, api_url=api_url) payload = {
pprint.pprint(classify_response.json()) "model": model,
"input": token_ids,
}
response = requests.post(classify_url, headers=headers, json=payload)
pprint.pprint(response.json())
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -16,7 +16,7 @@ from typing import Literal, NamedTuple, TypeAlias, TypedDict, get_args ...@@ -16,7 +16,7 @@ from typing import Literal, NamedTuple, TypeAlias, TypedDict, get_args
from PIL.Image import Image from PIL.Image import Image
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.entrypoints.score_utils import ScoreMultiModalParam from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
......
...@@ -15,7 +15,7 @@ from pathlib import Path ...@@ -15,7 +15,7 @@ from pathlib import Path
from typing import NamedTuple from typing import NamedTuple
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.entrypoints.score_utils import ScoreMultiModalParam from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
TEMPLATE_HOME = Path(__file__).parent / "template" TEMPLATE_HOME = Path(__file__).parent / "template"
......
...@@ -12,6 +12,8 @@ from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse ...@@ -12,6 +12,8 @@ from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
DTYPE = "float32" # Use float32 to avoid NaN issue DTYPE = "float32" # Use float32 to avoid NaN issue
input_text = "This product was excellent and exceeded my expectations"
input_tokens = [1986, 1985, 572, 9073, 323, 33808, 847, 16665]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -29,9 +31,23 @@ def server(): ...@@ -29,9 +31,23 @@ def server():
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_single_input_classification(server: RemoteOpenAIServer, model_name: str): def test_basic(server: RemoteOpenAIServer, model_name: str):
input_text = "This product was excellent and exceeded my expectations" # test /v1/models
response = requests.get(server.url_for("/v1/models"))
served_model = response.json()["data"][0]["id"]
assert served_model == MODEL_NAME
# test /tokenize
response = requests.post(
server.url_for("/tokenize"),
json={"model": model_name, "prompt": input_text},
)
assert response.json()["tokens"] == input_tokens
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_completion_request(server: RemoteOpenAIServer, model_name: str):
# test input: str
classification_response = requests.post( classification_response = requests.post(
server.url_for("classify"), server.url_for("classify"),
json={"model": model_name, "input": input_text}, json={"model": model_name, "input": input_text},
...@@ -46,35 +62,34 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str ...@@ -46,35 +62,34 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str
assert hasattr(output.data[0], "label") assert hasattr(output.data[0], "label")
assert hasattr(output.data[0], "probs") assert hasattr(output.data[0], "probs")
# test input: list[int]
@pytest.mark.parametrize("model_name", [MODEL_NAME]) classification_response = requests.post(
def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: str):
response = requests.post(
server.url_for("classify"), server.url_for("classify"),
json={"model": model_name, "input": "hello", "add_special_tokens": False}, json={"model": model_name, "input": input_tokens},
) )
response.raise_for_status()
ClassificationResponse.model_validate(response.json()) classification_response.raise_for_status()
output = ClassificationResponse.model_validate(classification_response.json())
assert output.object == "list"
assert output.model == MODEL_NAME
assert len(output.data) == 1
assert hasattr(output.data[0], "label")
assert hasattr(output.data[0], "probs")
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str): def test_completion_request_batched(server: RemoteOpenAIServer, model_name: str):
input_texts = [ N = 10
"The product arrived on time and works perfectly",
"I'm very satisfied with my purchase, would buy again",
"The customer service was helpful and resolved my issue quickly",
"This product broke after one week, terrible quality",
"I'm very disappointed with this purchase, complete waste of money",
"The customer service was rude and unhelpful",
]
# test input: list[str]
classification_response = requests.post( classification_response = requests.post(
server.url_for("classify"), server.url_for("classify"),
json={"model": model_name, "input": input_texts}, json={"model": model_name, "input": [input_text] * N},
) )
output = ClassificationResponse.model_validate(classification_response.json()) output = ClassificationResponse.model_validate(classification_response.json())
assert len(output.data) == len(input_texts) assert len(output.data) == N
for i, item in enumerate(output.data): for i, item in enumerate(output.data):
assert item.index == i assert item.index == i
assert hasattr(item, "label") assert hasattr(item, "label")
...@@ -82,6 +97,44 @@ def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: ...@@ -82,6 +97,44 @@ def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name:
assert len(item.probs) == item.num_classes assert len(item.probs) == item.num_classes
assert item.label in ["Default", "Spoiled"] assert item.label in ["Default", "Spoiled"]
# test input: list[list[int]]
classification_response = requests.post(
server.url_for("classify"),
json={"model": model_name, "input": [input_tokens] * N},
)
output = ClassificationResponse.model_validate(classification_response.json())
assert len(output.data) == N
for i, item in enumerate(output.data):
assert item.index == i
assert hasattr(item, "label")
assert hasattr(item, "probs")
assert len(item.probs) == item.num_classes
assert item.label in ["Default", "Spoiled"]
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
classification_response = requests.post(
server.url_for("classify"),
json={"model": model_name, "input": ""},
)
error = classification_response.json()
assert classification_response.status_code == 400
assert "error" in error
classification_response = requests.post(
server.url_for("classify"),
json={"model": model_name, "input": []},
)
classification_response.raise_for_status()
output = ClassificationResponse.model_validate(classification_response.json())
assert output.object == "list"
assert isinstance(output.data, list)
assert len(output.data) == 0
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str):
...@@ -101,11 +154,7 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): ...@@ -101,11 +154,7 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str):
assert output.usage.prompt_tokens == 5 assert output.usage.prompt_tokens == 5
assert output.usage.total_tokens == 5 assert output.usage.total_tokens == 5
# invalid_truncate_prompt_tokens
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_invalid_truncate_prompt_tokens_error(
server: RemoteOpenAIServer, model_name: str
):
classification_response = requests.post( classification_response = requests.post(
server.url_for("classify"), server.url_for("classify"),
json={"model": model_name, "input": "test", "truncate_prompt_tokens": 513}, json={"model": model_name, "input": "test", "truncate_prompt_tokens": 513},
...@@ -117,36 +166,28 @@ def test_invalid_truncate_prompt_tokens_error( ...@@ -117,36 +166,28 @@ def test_invalid_truncate_prompt_tokens_error(
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): def test_add_special_tokens(server: RemoteOpenAIServer, model_name: str):
classification_response = requests.post( # FIXME: The add_special_tokens parameter doesn't seem to be working.
response = requests.post(
server.url_for("classify"), server.url_for("classify"),
json={"model": model_name, "input": ""}, json={"model": model_name, "input": input_text, "add_special_tokens": False},
) )
response.raise_for_status()
ClassificationResponse.model_validate(response.json())
error = classification_response.json() response = requests.post(
assert classification_response.status_code == 400
assert "error" in error
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_batch_classification_empty_list(server: RemoteOpenAIServer, model_name: str):
classification_response = requests.post(
server.url_for("classify"), server.url_for("classify"),
json={"model": model_name, "input": []}, json={"model": model_name, "input": input_text, "add_special_tokens": True},
) )
classification_response.raise_for_status() response.raise_for_status()
output = ClassificationResponse.model_validate(classification_response.json()) ClassificationResponse.model_validate(response.json())
assert output.object == "list"
assert isinstance(output.data, list)
assert len(output.data) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer): async def test_invocations(server: RemoteOpenAIServer):
request_args = { request_args = {
"model": MODEL_NAME, "model": MODEL_NAME,
"input": "This product was excellent and exceeded my expectations", "input": input_text,
} }
classification_response = requests.post( classification_response = requests.post(
...@@ -175,8 +216,6 @@ async def test_invocations(server: RemoteOpenAIServer): ...@@ -175,8 +216,6 @@ async def test_invocations(server: RemoteOpenAIServer):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_use_activation(server: RemoteOpenAIServer, model_name: str): async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
input_text = ["This product was excellent and exceeded my expectations"]
async def get_outputs(use_activation): async def get_outputs(use_activation):
response = requests.post( response = requests.post(
server.url_for("classify"), server.url_for("classify"),
...@@ -237,7 +276,6 @@ async def test_rerank(server: RemoteOpenAIServer, model_name: str): ...@@ -237,7 +276,6 @@ async def test_rerank(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
input_text = "This product was excellent and exceeded my expectations"
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
...@@ -256,7 +294,6 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): ...@@ -256,7 +294,6 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str): async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
task = "token_classify" task = "token_classify"
input_text = ["This product was excellent and exceeded my expectations"]
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
...@@ -282,7 +319,7 @@ async def test_pooling_not_supported( ...@@ -282,7 +319,7 @@ async def test_pooling_not_supported(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
"model": model_name, "model": model_name,
"input": "test", "input": input_text,
"encoding_format": "float", "encoding_format": "float",
"task": task, "task": task,
}, },
......
...@@ -31,7 +31,26 @@ from vllm.utils.serial_utils import ( ...@@ -31,7 +31,26 @@ from vllm.utils.serial_utils import (
MODEL_NAME = "intfloat/multilingual-e5-small" MODEL_NAME = "intfloat/multilingual-e5-small"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
DTYPE = "bfloat16" DTYPE = "bfloat16"
input_text = "The best thing about vLLM is that it supports many different models"
input_tokens = [
0,
581,
2965,
13580,
1672,
81,
23708,
594,
83,
450,
442,
8060,
7,
5941,
12921,
115774,
2,
]
if current_platform.is_rocm(): if current_platform.is_rocm():
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
...@@ -79,15 +98,36 @@ def hf_model(hf_runner): ...@@ -79,15 +98,36 @@ def hf_model(hf_runner):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): async def test_basic(
input_texts = [ server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
"The chef prepared a delicious meal.", ):
] # test /v1/models
response = requests.get(server.url_for("/v1/models"))
model = response.json()["data"][0]["id"]
assert model == MODEL_NAME
# test single embedding models = await client.models.list()
models = models.data
served_model = models[0]
assert served_model.id == MODEL_NAME
# test /tokenize
response = requests.post(
server.url_for("/tokenize"),
json={"model": model_name, "prompt": input_text},
)
assert response.json()["tokens"] == input_tokens
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completion_request(
client: openai.AsyncOpenAI, model_name: str, hf_model
):
# test input: str
embedding_response = await client.embeddings.create( embedding_response = await client.embeddings.create(
model=model_name, model=model_name,
input=input_texts, input=input_text,
encoding_format="float", encoding_format="float",
) )
embeddings = EmbeddingResponse.model_validate( embeddings = EmbeddingResponse.model_validate(
...@@ -98,14 +138,13 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name ...@@ -98,14 +138,13 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name
assert len(embeddings.data) == 1 assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 384 assert len(embeddings.data[0].embedding) == 384
assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 11 assert embeddings.usage.prompt_tokens == len(input_tokens)
assert embeddings.usage.total_tokens == 11 assert embeddings.usage.total_tokens == len(input_tokens)
vllm_outputs = [d.embedding for d in embeddings.data] vllm_outputs = [d.embedding for d in embeddings.data]
run_embedding_correctness_test(hf_model, input_texts, vllm_outputs) run_embedding_correctness_test(hf_model, [input_text], vllm_outputs)
# test using token IDs # test input: list[int]
input_tokens = [1, 1, 1, 1, 1]
embedding_response = await client.embeddings.create( embedding_response = await client.embeddings.create(
model=model_name, model=model_name,
input=input_tokens, input=input_tokens,
...@@ -119,19 +158,22 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name ...@@ -119,19 +158,22 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name
assert len(embeddings.data) == 1 assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 384 assert len(embeddings.data[0].embedding) == 384
assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 5 assert embeddings.usage.prompt_tokens == len(input_tokens)
assert embeddings.usage.total_tokens == 5 assert embeddings.usage.total_tokens == len(input_tokens)
vllm_outputs = [d.embedding for d in embeddings.data]
run_embedding_correctness_test(hf_model, [input_text], vllm_outputs)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): async def test_completion_request_batched(
# test list[str] client: openai.AsyncOpenAI, model_name: str, hf_model
input_texts = [ ):
"The cat sat on the mat.", N = 10
"A feline was resting on a rug.", input_texts = [input_text] * N
"Stars twinkle brightly in the night sky.",
] # test input: list[str]
embedding_response = await client.embeddings.create( embedding_response = await client.embeddings.create(
model=model_name, model=model_name,
input=input_texts, input=input_texts,
...@@ -142,25 +184,19 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name: ...@@ -142,25 +184,19 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name:
) )
assert embeddings.id is not None assert embeddings.id is not None
assert len(embeddings.data) == 3 assert len(embeddings.data) == N
assert len(embeddings.data[0].embedding) == 384 assert len(embeddings.data[0].embedding) == 384
assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 33 assert embeddings.usage.prompt_tokens == len(input_tokens) * N
assert embeddings.usage.total_tokens == 33 assert embeddings.usage.total_tokens == len(input_tokens) * N
vllm_outputs = [d.embedding for d in embeddings.data] vllm_outputs = [d.embedding for d in embeddings.data]
run_embedding_correctness_test(hf_model, input_texts, vllm_outputs) run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)
# test list[list[int]] # test list[list[int]]
input_tokens = [
[4, 5, 7, 9, 20],
[15, 29, 499],
[24, 24, 24, 24, 24],
[25, 32, 64, 77],
]
embedding_response = await client.embeddings.create( embedding_response = await client.embeddings.create(
model=model_name, model=model_name,
input=input_tokens, input=[input_tokens] * N,
encoding_format="float", encoding_format="float",
) )
embeddings = EmbeddingResponse.model_validate( embeddings = EmbeddingResponse.model_validate(
...@@ -168,11 +204,14 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name: ...@@ -168,11 +204,14 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name:
) )
assert embeddings.id is not None assert embeddings.id is not None
assert len(embeddings.data) == 4 assert len(embeddings.data) == N
assert len(embeddings.data[0].embedding) == 384 assert len(embeddings.data[0].embedding) == 384
assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 17 assert embeddings.usage.prompt_tokens == len(input_tokens) * N
assert embeddings.usage.total_tokens == 17 assert embeddings.usage.total_tokens == len(input_tokens) * N
vllm_outputs = [d.embedding for d in embeddings.data]
run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -235,9 +274,162 @@ async def test_conversation_embedding( ...@@ -235,9 +274,162 @@ async def test_conversation_embedding(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_base64_embedding( async def test_truncate_prompt_tokens(client: openai.AsyncOpenAI, model_name: str):
hf_model, client: openai.AsyncOpenAI, model_name: str input_texts = [
): "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
]
# test single embedding
embedding_response = await client.embeddings.create(
model=model_name, input=input_texts, extra_body={"truncate_prompt_tokens": 10}
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json")
)
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 384
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 10
assert embeddings.usage.total_tokens == 10
input_tokens = [
1,
24428,
289,
18341,
26165,
285,
19323,
283,
289,
26789,
3871,
28728,
9901,
340,
2229,
385,
340,
315,
28741,
28804,
2,
]
embedding_response = await client.embeddings.create(
model=model_name, input=input_tokens, extra_body={"truncate_prompt_tokens": 10}
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json")
)
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 384
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 10
assert embeddings.usage.total_tokens == 10
# invalid_truncate_prompt_tokens
input_texts = [
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
]
with pytest.raises(openai.BadRequestError):
response = await client.embeddings.create(
model=model_name,
input=input_texts,
extra_body={"truncate_prompt_tokens": 8193},
)
assert "error" in response.object
assert (
"truncate_prompt_tokens value is greater than max_model_len. "
"Please, select a smaller truncation size." in response.message
)
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI):
request_args = {
"model": MODEL_NAME,
"input": input_text,
"encoding_format": "float",
}
completion_response = await client.embeddings.create(**request_args)
invocation_response = requests.post(
server.url_for("invocations"), json=request_args
)
invocation_response.raise_for_status()
completion_output = completion_response.model_dump()
invocation_output = invocation_response.json()
assert completion_output.keys() == invocation_output.keys()
for completion_data, invocation_data in zip(
completion_output["data"], invocation_output["data"]
):
assert completion_data.keys() == invocation_data.keys()
check_embeddings_close(
embeddings_0_lst=[completion_data["embedding"]],
embeddings_1_lst=[invocation_data["embedding"]],
name_0="completion",
name_1="invocation",
)
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
messages = [
{
"role": "user",
"content": "The cat sat on the mat.",
},
{
"role": "assistant",
"content": "A feline was resting on a rug.",
},
{
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
},
]
request_args = {
"model": MODEL_NAME,
"messages": messages,
"encoding_format": "float",
}
chat_response = requests.post(server.url_for("v1/embeddings"), json=request_args)
chat_response.raise_for_status()
invocation_response = requests.post(
server.url_for("invocations"), json=request_args
)
invocation_response.raise_for_status()
chat_output = chat_response.json()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
for chat_data, invocation_data in zip(
chat_output["data"], invocation_output["data"]
):
assert chat_data.keys() == invocation_data.keys()
check_embeddings_close(
embeddings_0_lst=[chat_data["embedding"]],
embeddings_1_lst=[invocation_data["embedding"]],
name_0="chat",
name_1="invocation",
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str):
input_texts = [ input_texts = [
"Hello my name is", "Hello my name is",
"The best thing about vLLM is that it supports many different models", "The best thing about vLLM is that it supports many different models",
...@@ -273,10 +465,7 @@ async def test_batch_base64_embedding( ...@@ -273,10 +465,7 @@ async def test_batch_base64_embedding(
async def test_base64_embed_dtype_and_endianness( async def test_base64_embed_dtype_and_endianness(
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
): ):
input_texts = [ input_texts = [input_text] * 3
"The best thing about vLLM is that it supports many different models",
]
responses_float = await client.embeddings.create( responses_float = await client.embeddings.create(
input=input_texts, model=model_name, encoding_format="float" input=input_texts, model=model_name, encoding_format="float"
) )
...@@ -315,10 +504,7 @@ async def test_base64_embed_dtype_and_endianness( ...@@ -315,10 +504,7 @@ async def test_base64_embed_dtype_and_endianness(
async def test_bytes_embed_dtype_and_endianness( async def test_bytes_embed_dtype_and_endianness(
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
): ):
input_texts = [ input_texts = [input_text] * 3
"The best thing about vLLM is that it supports many different models",
]
responses_float = await client.embeddings.create( responses_float = await client.embeddings.create(
input=input_texts, model=model_name, encoding_format="float" input=input_texts, model=model_name, encoding_format="float"
) )
...@@ -408,15 +594,11 @@ async def test_bytes_only_embed_dtype_and_endianness( ...@@ -408,15 +594,11 @@ async def test_bytes_only_embed_dtype_and_endianness(
async def test_params_not_supported( async def test_params_not_supported(
server: RemoteOpenAIServer, model_name: str, param_name: str server: RemoteOpenAIServer, model_name: str, param_name: str
): ):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]
responses_base64 = requests.post( responses_base64 = requests.post(
server.url_for("/v1/embeddings"), server.url_for("/v1/embeddings"),
json={ json={
"model": model_name, "model": model_name,
"input": input_texts, "input": input_text,
"encoding_format": "base64", "encoding_format": "base64",
param_name: f"bad_{param_name}", param_name: f"bad_{param_name}",
}, },
...@@ -427,175 +609,9 @@ async def test_params_not_supported( ...@@ -427,175 +609,9 @@ async def test_params_not_supported(
assert f"bad_{param_name}" in responses_base64.json()["error"]["message"] assert f"bad_{param_name}" in responses_base64.json()["error"]["message"]
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str):
input_texts = [
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
]
# test single embedding
embedding_response = await client.embeddings.create(
model=model_name, input=input_texts, extra_body={"truncate_prompt_tokens": 10}
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json")
)
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 384
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 10
assert embeddings.usage.total_tokens == 10
input_tokens = [
1,
24428,
289,
18341,
26165,
285,
19323,
283,
289,
26789,
3871,
28728,
9901,
340,
2229,
385,
340,
315,
28741,
28804,
2,
]
embedding_response = await client.embeddings.create(
model=model_name, input=input_tokens, extra_body={"truncate_prompt_tokens": 10}
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json")
)
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 384
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 10
assert embeddings.usage.total_tokens == 10
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation_invalid(
client: openai.AsyncOpenAI, model_name: str
):
input_texts = [
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
]
with pytest.raises(openai.BadRequestError):
response = await client.embeddings.create(
model=model_name,
input=input_texts,
extra_body={"truncate_prompt_tokens": 8193},
)
assert "error" in response.object
assert (
"truncate_prompt_tokens value is greater than max_model_len. "
"Please, select a smaller truncation size." in response.message
)
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI):
input_texts = [
"The chef prepared a delicious meal.",
]
request_args = {
"model": MODEL_NAME,
"input": input_texts,
"encoding_format": "float",
}
completion_response = await client.embeddings.create(**request_args)
invocation_response = requests.post(
server.url_for("invocations"), json=request_args
)
invocation_response.raise_for_status()
completion_output = completion_response.model_dump()
invocation_output = invocation_response.json()
assert completion_output.keys() == invocation_output.keys()
for completion_data, invocation_data in zip(
completion_output["data"], invocation_output["data"]
):
assert completion_data.keys() == invocation_data.keys()
check_embeddings_close(
embeddings_0_lst=[completion_data["embedding"]],
embeddings_1_lst=[invocation_data["embedding"]],
name_0="completion",
name_1="invocation",
)
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
messages = [
{
"role": "user",
"content": "The cat sat on the mat.",
},
{
"role": "assistant",
"content": "A feline was resting on a rug.",
},
{
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
},
]
request_args = {
"model": MODEL_NAME,
"messages": messages,
"encoding_format": "float",
}
chat_response = requests.post(server.url_for("v1/embeddings"), json=request_args)
chat_response.raise_for_status()
invocation_response = requests.post(
server.url_for("invocations"), json=request_args
)
invocation_response.raise_for_status()
chat_output = chat_response.json()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
for chat_data, invocation_data in zip(
chat_output["data"], invocation_output["data"]
):
assert chat_data.keys() == invocation_data.keys()
check_embeddings_close(
embeddings_0_lst=[chat_data["embedding"]],
embeddings_1_lst=[invocation_data["embedding"]],
name_0="chat",
name_1="invocation",
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_normalize(server: RemoteOpenAIServer, model_name: str): async def test_normalize(server: RemoteOpenAIServer, model_name: str):
input_text = ["The chef prepared a delicious meal."]
async def get_outputs(normalize): async def get_outputs(normalize):
request_args = { request_args = {
"model": MODEL_NAME, "model": MODEL_NAME,
...@@ -626,8 +642,6 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str): ...@@ -626,8 +642,6 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str): async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
task = "embed" task = "embed"
input_text = ["The chef prepared a delicious meal."]
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
...@@ -648,8 +662,6 @@ async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str): ...@@ -648,8 +662,6 @@ async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str): async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
task = "token_embed" task = "token_embed"
input_text = ["The chef prepared a delicious meal."]
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
...@@ -663,7 +675,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str): ...@@ -663,7 +675,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
poolings = PoolingResponse.model_validate(response.json()) poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1 assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 11 assert len(poolings.data[0].data) == len(input_tokens)
assert len(poolings.data[0].data[0]) == 384 assert len(poolings.data[0].data[0]) == 384
......
...@@ -24,6 +24,8 @@ from vllm.utils.serial_utils import ( ...@@ -24,6 +24,8 @@ from vllm.utils.serial_utils import (
MODEL_NAME = "internlm/internlm2-1_8b-reward" MODEL_NAME = "internlm/internlm2-1_8b-reward"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
input_text = "The chef prepared a delicious meal."
input_tokens = [1, 918, 29981, 10166, 395, 18067, 15265, 281]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -46,30 +48,40 @@ def server(): ...@@ -46,30 +48,40 @@ def server():
yield remote_server yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): def test_basic(server: RemoteOpenAIServer, model_name: str):
input_texts = [ # test /v1/models
"The chef prepared a delicious meal.", response = requests.get(server.url_for("/v1/models"))
] served_model = response.json()["data"][0]["id"]
assert served_model == MODEL_NAME
# test single pooling # test /tokenize
response = requests.post(
server.url_for("/tokenize"),
json={"model": model_name, "prompt": input_text},
)
assert response.json()["tokens"] == input_tokens
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_completion_request(server: RemoteOpenAIServer, model_name: str):
# test input: str
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={"model": model_name, "input": input_texts, "encoding_format": "float"}, json={"model": model_name, "input": input_text, "encoding_format": "float"},
) )
response.raise_for_status() response.raise_for_status()
poolings = PoolingResponse.model_validate(response.json()) poolings = PoolingResponse.model_validate(response.json())
assert poolings.id is not None assert poolings.id is not None
assert len(poolings.data) == 1 assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 8 assert len(poolings.data[0].data) == len(input_tokens)
assert poolings.usage.completion_tokens == 0 assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 8 assert poolings.usage.prompt_tokens == len(input_tokens)
assert poolings.usage.total_tokens == 8 assert poolings.usage.total_tokens == len(input_tokens)
# test using token IDs # test input: list[int]
input_tokens = [1, 1, 1, 1, 1]
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={"model": model_name, "input": input_tokens, "encoding_format": "float"}, json={"model": model_name, "input": input_tokens, "encoding_format": "float"},
...@@ -79,21 +91,17 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): ...@@ -79,21 +91,17 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
assert poolings.id is not None assert poolings.id is not None
assert len(poolings.data) == 1 assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 5 assert len(poolings.data[0].data) == len(input_tokens)
assert poolings.usage.completion_tokens == 0 assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 5 assert poolings.usage.prompt_tokens == len(input_tokens)
assert poolings.usage.total_tokens == 5 assert poolings.usage.total_tokens == len(input_tokens)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): def test_completion_request_batched(server: RemoteOpenAIServer, model_name: str):
# test list[str] N = 10
input_texts = [ input_texts = [input_text] * N
"The cat sat on the mat.",
"A feline was resting on a rug.",
"Stars twinkle brightly in the night sky.",
]
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={"model": model_name, "input": input_texts, "encoding_format": "float"}, json={"model": model_name, "input": input_texts, "encoding_format": "float"},
...@@ -102,32 +110,30 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): ...@@ -102,32 +110,30 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
poolings = PoolingResponse.model_validate(response.json()) poolings = PoolingResponse.model_validate(response.json())
assert poolings.id is not None assert poolings.id is not None
assert len(poolings.data) == 3 assert len(poolings.data) == N
assert len(poolings.data[0].data) == 8 assert len(poolings.data[0].data) == len(input_tokens)
assert poolings.usage.completion_tokens == 0 assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 29 assert poolings.usage.prompt_tokens == len(input_tokens) * N
assert poolings.usage.total_tokens == 29 assert poolings.usage.total_tokens == len(input_tokens) * N
# test list[list[int]] # test list[list[int]]
input_tokens = [
[4, 5, 7, 9, 20],
[15, 29, 499],
[24, 24, 24, 24, 24],
[25, 32, 64, 77],
]
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={"model": model_name, "input": input_tokens, "encoding_format": "float"}, json={
"model": model_name,
"input": [input_tokens] * N,
"encoding_format": "float",
},
) )
response.raise_for_status() response.raise_for_status()
poolings = PoolingResponse.model_validate(response.json()) poolings = PoolingResponse.model_validate(response.json())
assert poolings.id is not None assert poolings.id is not None
assert len(poolings.data) == 4 assert len(poolings.data) == N
assert len(poolings.data[0].data) == 5 assert len(poolings.data[0].data) == len(input_tokens)
assert poolings.usage.completion_tokens == 0 assert poolings.usage.completion_tokens == 0
assert poolings.usage.prompt_tokens == 17 assert poolings.usage.prompt_tokens == len(input_tokens) * N
assert poolings.usage.total_tokens == 17 assert poolings.usage.total_tokens == len(input_tokens) * N
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -259,9 +265,7 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str) ...@@ -259,9 +265,7 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str)
async def test_base64_embed_dtype_and_endianness( async def test_base64_embed_dtype_and_endianness(
server: RemoteOpenAIServer, model_name: str server: RemoteOpenAIServer, model_name: str
): ):
input_texts = [ input_texts = [input_text] * 3
"The best thing about vLLM is that it supports many different models",
]
url = server.url_for("pooling") url = server.url_for("pooling")
float_response = requests.post( float_response = requests.post(
...@@ -308,9 +312,7 @@ async def test_base64_embed_dtype_and_endianness( ...@@ -308,9 +312,7 @@ async def test_base64_embed_dtype_and_endianness(
async def test_bytes_embed_dtype_and_endianness( async def test_bytes_embed_dtype_and_endianness(
server: RemoteOpenAIServer, model_name: str server: RemoteOpenAIServer, model_name: str
): ):
input_texts = [ input_texts = [input_text] * 3
"The best thing about vLLM is that it supports many different models",
]
url = server.url_for("pooling") url = server.url_for("pooling")
float_response = requests.post( float_response = requests.post(
...@@ -358,9 +360,7 @@ async def test_bytes_embed_dtype_and_endianness( ...@@ -358,9 +360,7 @@ async def test_bytes_embed_dtype_and_endianness(
async def test_bytes_only_embed_dtype_and_endianness( async def test_bytes_only_embed_dtype_and_endianness(
server: RemoteOpenAIServer, model_name: str server: RemoteOpenAIServer, model_name: str
): ):
input_texts = [ input_texts = [input_text] * 3
"The best thing about vLLM is that it supports many different models",
] * 2
url = server.url_for("pooling") url = server.url_for("pooling")
float_response = requests.post( float_response = requests.post(
...@@ -414,15 +414,11 @@ async def test_bytes_only_embed_dtype_and_endianness( ...@@ -414,15 +414,11 @@ async def test_bytes_only_embed_dtype_and_endianness(
async def test_params_not_supported( async def test_params_not_supported(
server: RemoteOpenAIServer, model_name: str, param_name: str server: RemoteOpenAIServer, model_name: str, param_name: str
): ):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]
responses_base64 = requests.post( responses_base64 = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
"model": model_name, "model": model_name,
"input": input_texts, "input": input_text,
"encoding_format": "base64", "encoding_format": "base64",
param_name: f"bad_{param_name}", param_name: f"bad_{param_name}",
}, },
...@@ -435,13 +431,9 @@ async def test_params_not_supported( ...@@ -435,13 +431,9 @@ async def test_params_not_supported(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer): async def test_invocations(server: RemoteOpenAIServer):
input_texts = [
"The chef prepared a delicious meal.",
]
request_args = { request_args = {
"model": MODEL_NAME, "model": MODEL_NAME,
"input": input_texts, "input": input_text,
"encoding_format": "float", "encoding_format": "float",
} }
......
...@@ -13,6 +13,8 @@ from vllm.platforms import current_platform ...@@ -13,6 +13,8 @@ from vllm.platforms import current_platform
MODEL_NAME = "BAAI/bge-reranker-base" MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16" DTYPE = "bfloat16"
input_text = "This product was excellent and exceeded my expectations"
input_tokens = [0, 3293, 12996, 509, 40881, 136, 204839, 297, 759, 202702, 2]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -27,6 +29,21 @@ def server(): ...@@ -27,6 +29,21 @@ def server():
yield remote_server yield remote_server
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_basic(server: RemoteOpenAIServer, model_name: str):
# test /v1/models
response = requests.get(server.url_for("/v1/models"))
served_model = response.json()["data"][0]["id"]
assert served_model == MODEL_NAME
# test /tokenize
response = requests.post(
server.url_for("/tokenize"),
json={"model": model_name, "prompt": input_text},
)
assert response.json()["tokens"] == input_tokens
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str): def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
query = "What is the capital of France?" query = "What is the capital of France?"
...@@ -170,7 +187,6 @@ async def test_use_activation(server: RemoteOpenAIServer, model_name: str): ...@@ -170,7 +187,6 @@ async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
input_text = "This product was excellent and exceeded my expectations"
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
...@@ -188,8 +204,6 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): ...@@ -188,8 +204,6 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str): async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
input_text = ["The chef prepared a delicious meal."]
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={"model": model_name, "input": input_text, "encoding_format": "float"}, json={"model": model_name, "input": input_text, "encoding_format": "float"},
...@@ -198,7 +212,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st ...@@ -198,7 +212,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st
poolings = PoolingResponse.model_validate(response.json()) poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1 assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 11 assert len(poolings.data[0].data) == len(input_tokens)
assert len(poolings.data[0].data[0]) == 1 assert len(poolings.data[0].data[0]) == 1
...@@ -212,7 +226,7 @@ async def test_pooling_not_supported( ...@@ -212,7 +226,7 @@ async def test_pooling_not_supported(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
"model": model_name, "model": model_name,
"input": "test", "input": input_text,
"encoding_format": "float", "encoding_format": "float",
"task": task, "task": task,
}, },
......
...@@ -7,7 +7,7 @@ import pytest ...@@ -7,7 +7,7 @@ import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
from vllm.entrypoints.score_utils import get_score_prompt from vllm.entrypoints.pooling.score.utils import get_score_prompt
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
...@@ -212,7 +212,7 @@ class TestGetScorePrompt: ...@@ -212,7 +212,7 @@ class TestGetScorePrompt:
return_value=mock_model_no_score_template, return_value=mock_model_no_score_template,
), ),
patch( patch(
"vllm.entrypoints.score_utils.apply_hf_chat_template", "vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
return_value="test querytest doc", return_value="test querytest doc",
), ),
): ):
...@@ -245,7 +245,7 @@ class TestGetScorePrompt: ...@@ -245,7 +245,7 @@ class TestGetScorePrompt:
return_value=mock_model_no_score_template, return_value=mock_model_no_score_template,
), ),
patch( patch(
"vllm.entrypoints.score_utils.apply_hf_chat_template", "vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
side_effect=ChatTemplateResolutionError("No template"), side_effect=ChatTemplateResolutionError("No template"),
), ),
): ):
...@@ -296,7 +296,7 @@ class TestGetScorePrompt: ...@@ -296,7 +296,7 @@ class TestGetScorePrompt:
return_value=mock_model_no_score_template, return_value=mock_model_no_score_template,
), ),
patch( patch(
"vllm.entrypoints.score_utils.apply_hf_chat_template", "vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
side_effect=ChatTemplateResolutionError("No template"), side_effect=ChatTemplateResolutionError("No template"),
), ),
): ):
...@@ -331,7 +331,7 @@ class TestGetScorePrompt: ...@@ -331,7 +331,7 @@ class TestGetScorePrompt:
return_value=mock_model_with_score_template, return_value=mock_model_with_score_template,
), ),
patch( patch(
"vllm.entrypoints.score_utils.apply_hf_chat_template", "vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
side_effect=ChatTemplateResolutionError("No template"), side_effect=ChatTemplateResolutionError("No template"),
), ),
): ):
......
...@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import ( ...@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import (
ChatCompletionContentPartImageParam, ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam, ChatCompletionContentPartTextParam,
) )
from vllm.entrypoints.score_utils import ScoreMultiModalParam from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
from ....conftest import HfRunner, VllmRunner from ....conftest import HfRunner, VllmRunner
......
...@@ -42,7 +42,7 @@ from vllm.entrypoints.chat_utils import ( ...@@ -42,7 +42,7 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages, parse_chat_messages,
resolve_chat_template_content_format, resolve_chat_template_content_format,
) )
from vllm.entrypoints.score_utils import ( from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam, ScoreContentPartParam,
ScoreMultiModalParam, ScoreMultiModalParam,
_cosine_similarity, _cosine_similarity,
......
...@@ -54,10 +54,6 @@ from vllm.entrypoints.openai.translations.serving import ( ...@@ -54,10 +54,6 @@ from vllm.entrypoints.openai.translations.serving import (
OpenAIServingTranscription, OpenAIServingTranscription,
OpenAIServingTranslation, OpenAIServingTranslation,
) )
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.serve.disagg.serving import ServingTokens from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import ( from vllm.entrypoints.serve.elastic_ep.middleware import (
ScalingMiddleware, ScalingMiddleware,
...@@ -73,7 +69,6 @@ from vllm.entrypoints.utils import ( ...@@ -73,7 +69,6 @@ from vllm.entrypoints.utils import (
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS
from vllm.tool_parsers import ToolParserManager from vllm.tool_parsers import ToolParserManager
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -761,59 +756,6 @@ async def init_app_state( ...@@ -761,59 +756,6 @@ async def init_app_state(
if "generate" in supported_tasks if "generate" in supported_tasks
else None else None
) )
state.openai_serving_pooling = (
(
OpenAIServingPooling(
engine_client,
state.openai_serving_models,
supported_tasks=supported_tasks,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
)
if any(task in POOLING_TASKS for task in supported_tasks)
else None
)
state.openai_serving_embedding = (
OpenAIServingEmbedding(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
if "embed" in supported_tasks
else None
)
state.openai_serving_classification = (
ServingClassification(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
if "classify" in supported_tasks
else None
)
state.openai_serving_scores = (
ServingScores(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
)
if ("embed" in supported_tasks or "score" in supported_tasks)
else None
)
state.openai_serving_tokenization = OpenAIServingTokenization( state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
...@@ -878,6 +820,10 @@ async def init_app_state( ...@@ -878,6 +820,10 @@ async def init_app_state(
else None else None
) )
from vllm.entrypoints.pooling import init_pooling_state
await init_pooling_state(engine_client, state, args)
state.enable_server_load_tracking = args.enable_server_load_tracking state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0 state.server_load_metrics = 0
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from fastapi import FastAPI from fastapi import FastAPI
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
def register_pooling_api_routers(app: FastAPI): def register_pooling_api_routers(app: FastAPI):
from vllm.entrypoints.pooling.classify.api_router import router as classify_router from vllm.entrypoints.pooling.classify.api_router import router as classify_router
...@@ -14,3 +23,82 @@ def register_pooling_api_routers(app: FastAPI): ...@@ -14,3 +23,82 @@ def register_pooling_api_routers(app: FastAPI):
app.include_router(embed_router) app.include_router(embed_router)
app.include_router(score_router) app.include_router(score_router)
app.include_router(pooling_router) app.include_router(pooling_router)
async def init_pooling_state(
engine_client: "EngineClient", state: "State", args: "Namespace"
):
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.utils import process_chat_template
from vllm.tasks import POOLING_TASKS
supported_tasks = await engine_client.get_supported_tasks()
vllm_config = engine_client.vllm_config
resolved_chat_template = await process_chat_template(
args.chat_template, engine_client, vllm_config.model_config
)
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
state.openai_serving_pooling = (
(
OpenAIServingPooling(
engine_client,
state.openai_serving_models,
supported_tasks=supported_tasks,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
)
if any(task in POOLING_TASKS for task in supported_tasks)
else None
)
state.openai_serving_embedding = (
OpenAIServingEmbedding(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
if "embed" in supported_tasks
else None
)
state.openai_serving_classification = (
ServingClassification(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
if "classify" in supported_tasks
else None
)
state.openai_serving_scores = (
ServingScores(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
)
if ("embed" in supported_tasks or "score" in supported_tasks)
else None
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Annotated
from pydantic import Field
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.utils import random_uuid
class PoolingBasicRequestMixin(OpenAIBaseModel):
model: str | None = None
user: str | None = None
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
class CompletionRequestMixin(OpenAIBaseModel):
input: list[int] | list[list[int]] | str | list[str]
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from typing import Annotated, Any, TypeAlias from typing import Any, TypeAlias
from pydantic import ( from pydantic import (
Field, Field,
...@@ -12,39 +12,15 @@ from vllm import PoolingParams ...@@ -12,39 +12,15 @@ from vllm import PoolingParams
from vllm.config.pooler import get_use_activation from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.utils import random_uuid from vllm.utils import random_uuid
class ClassificationCompletionRequest(OpenAIBaseModel): class ClassificationCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixin):
model: str | None = None
input: list[str] | str
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
user: str | None = None
# --8<-- [start:classification-extra-params] # --8<-- [start:classification-extra-params]
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
softmax: bool | None = Field( softmax: bool | None = Field(
default=None, default=None,
description="softmax will be deprecated, please use use_activation instead.", description="softmax will be deprecated, please use use_activation instead.",
...@@ -69,11 +45,8 @@ class ClassificationCompletionRequest(OpenAIBaseModel): ...@@ -69,11 +45,8 @@ class ClassificationCompletionRequest(OpenAIBaseModel):
) )
class ClassificationChatRequest(OpenAIBaseModel): class ClassificationChatRequest(PoolingBasicRequestMixin):
model: str | None = None
messages: list[ChatCompletionMessageParam] messages: list[ChatCompletionMessageParam]
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
user: str | None = None
# --8<-- [start:chat-classification-extra-params] # --8<-- [start:chat-classification-extra-params]
add_generation_prompt: bool = Field( add_generation_prompt: bool = Field(
...@@ -119,23 +92,6 @@ class ClassificationChatRequest(OpenAIBaseModel): ...@@ -119,23 +92,6 @@ class ClassificationChatRequest(OpenAIBaseModel):
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
softmax: bool | None = Field( softmax: bool | None = Field(
default=None, default=None,
description="softmax will be deprecated, please use use_activation instead.", description="softmax will be deprecated, please use use_activation instead.",
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from typing import Annotated, Any, TypeAlias from typing import Any, TypeAlias
from pydantic import ( from pydantic import (
Field, Field,
...@@ -11,44 +11,22 @@ from pydantic import ( ...@@ -11,44 +11,22 @@ from pydantic import (
from vllm import PoolingParams from vllm import PoolingParams
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
class EmbeddingCompletionRequest(OpenAIBaseModel): class EmbeddingCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixin):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings # https://platform.openai.com/docs/api-reference/embeddings
model: str | None = None
input: list[int] | list[list[int]] | str | list[str]
encoding_format: EncodingFormat = "float" encoding_format: EncodingFormat = "float"
dimensions: int | None = None dimensions: int | None = None
user: str | None = None
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
# --8<-- [start:embedding-extra-params] # --8<-- [start:embedding-extra-params]
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
normalize: bool | None = Field( normalize: bool | None = Field(
default=None, default=None,
description="Whether to normalize the embeddings outputs. Default is True.", description="Whether to normalize the embeddings outputs. Default is True.",
...@@ -73,20 +51,17 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): ...@@ -73,20 +51,17 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
dimensions=self.dimensions, dimensions=self.dimensions,
use_activation=self.normalize, use_activation=self.normalize,
truncate_prompt_tokens=self.truncate_prompt_tokens,
) )
class EmbeddingChatRequest(OpenAIBaseModel): class EmbeddingChatRequest(PoolingBasicRequestMixin):
model: str | None = None
messages: list[ChatCompletionMessageParam] messages: list[ChatCompletionMessageParam]
encoding_format: EncodingFormat = "float" encoding_format: EncodingFormat = "float"
dimensions: int | None = None dimensions: int | None = None
user: str | None = None
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
# --8<-- [start:chat-embedding-extra-params] # --8<-- [start:chat-embedding-extra-params]
add_generation_prompt: bool = Field( add_generation_prompt: bool = Field(
...@@ -137,22 +112,6 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -137,22 +112,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
default=None, default=None,
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
normalize: bool | None = Field( normalize: bool | None = Field(
default=None, default=None,
description="Whether to normalize the embeddings outputs. Default is True.", description="Whether to normalize the embeddings outputs. Default is True.",
......
...@@ -10,6 +10,7 @@ from pydantic import ( ...@@ -10,6 +10,7 @@ from pydantic import (
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config.pooler import get_use_activation from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import PoolingBasicRequestMixin
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingChatRequest, EmbeddingChatRequest,
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
...@@ -72,17 +73,8 @@ class PoolingChatRequest(EmbeddingChatRequest): ...@@ -72,17 +73,8 @@ class PoolingChatRequest(EmbeddingChatRequest):
T = TypeVar("T") T = TypeVar("T")
class IOProcessorRequest(OpenAIBaseModel, Generic[T]): class IOProcessorRequest(PoolingBasicRequestMixin, Generic[T]):
model: str | None = None
priority: int = Field(default=0)
"""
The priority of the request (lower means earlier handling;
default: 0). Any priority other than 0 will raise an error
if the served model does not use priority scheduling.
"""
data: T data: T
task: PoolingTask = "plugin" task: PoolingTask = "plugin"
encoding_format: EncodingFormat = "float" encoding_format: EncodingFormat = "float"
embed_dtype: EmbedDType = Field( embed_dtype: EmbedDType = Field(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from typing import Annotated, Any from typing import Any
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
...@@ -11,32 +11,24 @@ from pydantic import ( ...@@ -11,32 +11,24 @@ from pydantic import (
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config.pooler import get_use_activation from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam from vllm.entrypoints.pooling.base.protocol import PoolingBasicRequestMixin
from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam,
ScoreMultiModalParam,
)
from vllm.utils import random_uuid from vllm.utils import random_uuid
class ScoreRequest(OpenAIBaseModel): class ScoreRequest(PoolingBasicRequestMixin):
model: str | None = None
text_1: list[str] | str | ScoreMultiModalParam text_1: list[str] | str | ScoreMultiModalParam
text_2: list[str] | str | ScoreMultiModalParam text_2: list[str] | str | ScoreMultiModalParam
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
# --8<-- [start:score-extra-params] # --8<-- [start:score-extra-params]
mm_processor_kwargs: dict[str, Any] | None = Field( mm_processor_kwargs: dict[str, Any] | None = Field(
default=None, default=None,
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
softmax: bool | None = Field( softmax: bool | None = Field(
default=None, default=None,
description="softmax will be deprecated, please use use_activation instead.", description="softmax will be deprecated, please use use_activation instead.",
...@@ -61,29 +53,16 @@ class ScoreRequest(OpenAIBaseModel): ...@@ -61,29 +53,16 @@ class ScoreRequest(OpenAIBaseModel):
) )
class RerankRequest(OpenAIBaseModel): class RerankRequest(PoolingBasicRequestMixin):
model: str | None = None
query: str | ScoreMultiModalParam query: str | ScoreMultiModalParam
documents: list[str] | ScoreMultiModalParam documents: list[str] | ScoreMultiModalParam
top_n: int = Field(default_factory=lambda: 0) top_n: int = Field(default_factory=lambda: 0)
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
# --8<-- [start:rerank-extra-params] # --8<-- [start:rerank-extra-params]
mm_processor_kwargs: dict[str, Any] | None = Field( mm_processor_kwargs: dict[str, Any] | None = Field(
default=None, default=None,
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
) )
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
softmax: bool | None = Field( softmax: bool | None = Field(
default=None, default=None,
description="softmax will be deprecated, please use use_activation instead.", description="softmax will be deprecated, please use use_activation instead.",
......
...@@ -25,7 +25,7 @@ from vllm.entrypoints.pooling.score.protocol import ( ...@@ -25,7 +25,7 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreResponse, ScoreResponse,
ScoreResponseData, ScoreResponseData,
) )
from vllm.entrypoints.score_utils import ( from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam, ScoreContentPartParam,
ScoreMultiModalParam, ScoreMultiModalParam,
_cosine_similarity, _cosine_similarity,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment