Commit 1c3c9757, authored by Gabriel Marinho, committed by GitHub

[FEATURE] Enables /score endpoint for embedding models (#12846)

parent 1cdc8861
@@ -108,8 +108,7 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/clas
 ### `LLM.score`
 The {class}`~vllm.LLM.score` method outputs similarity scores between sentence pairs.
-It is primarily designed for [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html).
-These types of models serve as rerankers between candidate query-document pairs in RAG systems.
+It is designed for embedding models and cross encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
 :::{note}
 vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
......
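For embedding models, the score described in the `LLM.score` text is the cosine similarity of the two pooled embeddings. The following is a minimal, dependency-free sketch of that computation. The function name `cosine_similarity` and the plain-list inputs are illustrative assumptions, not vLLM's internal helper.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Return the cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("Embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("Cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Hypothetical 3-dimensional "embeddings", for illustration only.
query = [1.0, 2.0, 3.0]
same_direction = [2.0, 4.0, 6.0]
print(cosine_similarity(query, same_direction))
```

Vectors that point in the same direction score close to 1 and orthogonal vectors score 0, regardless of their magnitudes.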
@@ -51,7 +51,7 @@ In addition, we have the following custom APIs:
 - [Pooling API](#pooling-api) (`/pooling`)
   - Applicable to all [pooling models](../models/pooling_models.md).
 - [Score API](#score-api) (`/score`)
-  - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
+  - Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`).
 - [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
   - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
   - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
@@ -333,10 +333,10 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
 ### Score API
-Our Score API applies a cross-encoder model to predict scores for sentence pairs.
+Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair.
 Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1.
-You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
+You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
 Code example: <gh-file:examples/online_serving/openai_cross_encoder_score.py>
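A request to the Score API might be assembled as in the sketch below. The `/score` path and the `model`, `text_1`, and `text_2` fields come from the tests in this commit. The helper names `build_score_payload` and `post_score`, and the base URL in the comment, are assumptions made for illustration.

```python
import json
import urllib.request
from typing import List, Union

TextInput = Union[str, List[str]]


def build_score_payload(model: str, text_1: TextInput,
                        text_2: TextInput) -> dict:
    """Build the JSON body; each text field is a string or a list of strings."""
    return {"model": model, "text_1": text_1, "text_2": text_2}


def post_score(base_url: str, payload: dict) -> dict:
    """POST the payload to `<base_url>/score` and decode the JSON reply."""
    request = urllib.request.Request(
        url=f"{base_url}/score",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)


# One query scored against two candidate sentences (the 1:N form).
payload = build_score_payload(
    "BAAI/bge-base-en-v1.5",
    "What is the capital of France?",
    ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
)
# With a server running (the address here is hypothetical):
# reply = post_score("http://localhost:8000", payload)
# scores = [item["score"] for item in reply["data"]]
```

The tests read the response through `score.data[i].score`, so each pair is expected to produce one entry in `data`.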
@@ -496,11 +496,11 @@ The following extra parameters are supported:
 ### Re-rank API
-Our Re-rank API applies a cross-encoder model to predict relevant scores between a single query, and
+Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and
 each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on
 a scale of 0 to 1.
-You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
+You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
 The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the
 `score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank`
......
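Once each (query, document) pair has a score, re-ranking reduces to sorting by score in descending order and keeping the first `top_n` results. The sketch below assumes the scores are already computed. `RankedDocument` and `rank_documents` are hypothetical names. Treating a non-positive `top_n` as "return everything" mirrors the handler removed in this commit.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class RankedDocument:
    index: int  # position of the document in the original request
    text: str
    relevance_score: float


def rank_documents(documents: Sequence[str], scores: Sequence[float],
                   top_n: int = 0) -> List[RankedDocument]:
    """Sort documents by descending score and keep the best `top_n`."""
    if len(documents) != len(scores):
        raise ValueError("Each document needs exactly one score")
    results = [
        RankedDocument(index=i, text=doc, relevance_score=score)
        for i, (doc, score) in enumerate(zip(documents, scores))
    ]
    # list.sort is stable, so tied scores keep their request order.
    results.sort(key=lambda r: r.relevance_score, reverse=True)
    limit = top_n if top_n > 0 else len(results)
    return results[:limit]


docs = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
ranked = rank_documents(docs, [0.02, 0.97], top_n=1)
print(ranked[0].text)
```

Keeping the original `index` on each result lets a client map the reordered list back to the documents it sent.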
@@ -8,17 +8,17 @@ from vllm.entrypoints.openai.protocol import RerankResponse
 from ...utils import RemoteOpenAIServer
 MODEL_NAME = "BAAI/bge-reranker-base"
+DTYPE = "bfloat16"
 @pytest.fixture(scope="module")
 def server():
-    args = ["--enforce-eager", "--max-model-len", "100"]
+    args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
-@pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
     query = "What is the capital of France?"
@@ -42,7 +42,6 @@ def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
     assert rerank.results[1].relevance_score <= 0.01
-@pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_top_n(server: RemoteOpenAIServer, model_name: str):
     query = "What is the capital of France?"
@@ -68,7 +67,6 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str):
     assert rerank.results[1].relevance_score <= 0.01
-@pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math
from typing import Any
import pytest import pytest
import requests import requests
import torch.nn.functional as F
from torch import tensor
from vllm.entrypoints.openai.protocol import ScoreResponse from vllm.entrypoints.openai.protocol import ScoreResponse
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "BAAI/bge-reranker-v2-m3" MODELS = [
{
"name": "BAAI/bge-reranker-v2-m3",
"is_cross_encoder": True
},
{
"name": "BAAI/bge-base-en-v1.5",
"is_cross_encoder": False
},
]
DTYPE = "half"
def run_transformers(hf_model, model, text_pairs):
if model["is_cross_encoder"]:
return hf_model.predict(text_pairs).tolist()
else:
hf_embeddings = [
hf_model.encode(text_pair) for text_pair in text_pairs
]
return [
F.cosine_similarity(tensor(pair[0]), tensor(pair[1]), dim=0)
for pair in hf_embeddings
]
@pytest.fixture(scope="class", params=MODELS)
def model(request):
yield request.param
@pytest.fixture(scope="module")
def server():
args = ["--enforce-eager", "--max-model-len", "100"]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @pytest.fixture(scope="class")
def server(model: dict[str, Any]):
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
with RemoteOpenAIServer(model["name"], args) as remote_server:
yield remote_server yield remote_server
@pytest.mark.asyncio @pytest.fixture(scope="class")
@pytest.mark.parametrize("model_name", [MODEL_NAME]) def runner(model: dict[str, Any], hf_runner):
def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str): kwargs = {
"dtype": DTYPE,
"is_cross_encoder" if model["is_cross_encoder"]\
else "is_sentence_transformer": True
}
with hf_runner(model["name"], **kwargs) as hf_model:
yield hf_model
class TestModel:
def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer,
model: dict[str, Any], runner):
text_1 = "What is the capital of France?" text_1 = "What is the capital of France?"
text_2 = [ text_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris." "The capital of Brazil is Brasilia.",
"The capital of France is Paris."
] ]
score_response = requests.post(server.url_for("score"), score_response = requests.post(server.url_for("score"),
json={ json={
"model": model_name, "model": model["name"],
"text_1": text_1, "text_1": text_1,
"text_2": text_2, "text_2": text_2,
}) })
...@@ -38,24 +85,29 @@ def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str): ...@@ -38,24 +85,29 @@ def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str):
assert score.id is not None assert score.id is not None
assert score.data is not None assert score.data is not None
assert len(score.data) == 2 assert len(score.data) == 2
assert score.data[0].score <= 0.01
assert score.data[1].score >= 0.9
vllm_outputs = [d.score for d in score.data]
text_pairs = [[text_1, text_2[0]], [text_1, text_2[1]]]
hf_outputs = run_transformers(runner, model, text_pairs)
@pytest.mark.asyncio for i in range(len(vllm_outputs)):
@pytest.mark.parametrize("model_name", [MODEL_NAME]) assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer,
model: dict[str, Any], runner):
text_1 = [ text_1 = [
"What is the capital of the United States?", "What is the capital of the United States?",
"What is the capital of France?" "What is the capital of France?"
] ]
text_2 = [ text_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris." "The capital of Brazil is Brasilia.",
"The capital of France is Paris."
] ]
score_response = requests.post(server.url_for("score"), score_response = requests.post(server.url_for("score"),
json={ json={
"model": model_name, "model": model["name"],
"text_1": text_1, "text_1": text_1,
"text_2": text_2, "text_2": text_2,
}) })
...@@ -65,19 +117,23 @@ def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str): ...@@ -65,19 +117,23 @@ def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
assert score.id is not None assert score.id is not None
assert score.data is not None assert score.data is not None
assert len(score.data) == 2 assert len(score.data) == 2
assert score.data[0].score <= 0.01
assert score.data[1].score >= 0.9
vllm_outputs = [d.score for d in score.data]
text_pairs = [[text_1[0], text_2[0]], [text_1[1], text_2[1]]]
hf_outputs = run_transformers(runner, model, text_pairs)
@pytest.mark.asyncio for i in range(len(vllm_outputs)):
@pytest.mark.parametrize("model_name", [MODEL_NAME]) assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer,
model: dict[str, Any], runner):
text_1 = "What is the capital of France?" text_1 = "What is the capital of France?"
text_2 = "The capital of France is Paris." text_2 = "The capital of France is Paris."
score_response = requests.post(server.url_for("score"), score_response = requests.post(server.url_for("score"),
json={ json={
"model": model_name, "model": model["name"],
"text_1": text_1, "text_1": text_1,
"text_2": text_2, "text_2": text_2,
}) })
...@@ -87,21 +143,27 @@ def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str): ...@@ -87,21 +143,27 @@ def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
assert score.id is not None assert score.id is not None
assert score.data is not None assert score.data is not None
assert len(score.data) == 1 assert len(score.data) == 1
assert score.data[0].score >= 0.9
vllm_outputs = [d.score for d in score.data]
text_pairs = [[text_1, text_2]]
hf_outputs = run_transformers(runner, model, text_pairs)
for i in range(len(vllm_outputs)):
assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
@pytest.mark.asyncio def test_score_max_model_len(self, server: RemoteOpenAIServer,
@pytest.mark.parametrize("model_name", [MODEL_NAME]) model: dict[str, Any]):
def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
text_1 = "What is the capital of France?" * 20 text_1 = "What is the capital of France?" * 20
text_2 = [ text_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris." "The capital of Brazil is Brasilia.",
"The capital of France is Paris."
] ]
score_response = requests.post(server.url_for("score"), score_response = requests.post(server.url_for("score"),
json={ json={
"model": model_name, "model": model["name"],
"text_1": text_1, "text_1": text_1,
"text_2": text_2, "text_2": text_2,
}) })
...@@ -113,7 +175,7 @@ def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str): ...@@ -113,7 +175,7 @@ def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
# Test truncation # Test truncation
score_response = requests.post(server.url_for("score"), score_response = requests.post(server.url_for("score"),
json={ json={
"model": model_name, "model": model["name"],
"text_1": text_1, "text_1": text_1,
"text_2": text_2, "text_2": text_2,
"truncate_prompt_tokens": 101 "truncate_prompt_tokens": 101
......
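The updated score tests replace fixed thresholds with a comparison against reference scores using `math.isclose(..., rel_tol=0.01)`. A small sketch of that check follows. The helper name `scores_match` is an assumption, and the numbers are made up for illustration.

```python
import math
from typing import Sequence


def scores_match(reference: Sequence[float], candidate: Sequence[float],
                 rel_tol: float = 0.01) -> bool:
    """True when both lists have equal length and agree pairwise."""
    if len(reference) != len(candidate):
        return False
    return all(
        math.isclose(ref, cand, rel_tol=rel_tol)
        for ref, cand in zip(reference, candidate))


# Illustrative values only; real tests would use model outputs.
reference_scores = [0.500, 0.900]
served_scores = [0.502, 0.905]
print(scores_match(reference_scores, served_scores))
```

Because `math.isclose` defaults to `abs_tol=0.0`, a purely relative tolerance becomes very strict for scores near zero. A small `abs_tol` is one way to relax that if it matters for a given model.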
@@ -7,7 +7,6 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
                     Tuple, Type, Union, cast, overload)
 import cloudpickle
-import torch
 import torch.nn as nn
 from tqdm import tqdm
 from typing_extensions import TypeVar, deprecated
@@ -25,6 +24,8 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_mistral_chat_template,
                                          parse_chat_messages,
                                          resolve_chat_template_content_format)
+from vllm.entrypoints.score_utils import (_cosine_similarity,
+                                          _validate_score_input_lens)
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
 from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -1010,40 +1011,25 @@ class LLM:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> List[ScoringRequestOutput]:
-        encoded_output = self.encode(
+        encoded_output: List[PoolingRequestOutput] = self.encode(
             text_1 + text_2,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request)
-        encoded_output_1 = encoded_output[0:len(text_1)]
-        encoded_output_2 = encoded_output[len(text_1):]
+        encoded_output_1: List[PoolingRequestOutput] = encoded_output[
+            0:len(text_1)]
+        encoded_output_2: List[PoolingRequestOutput] = encoded_output[
+            len(text_1):]
         if len(encoded_output_1) == 1:
             encoded_output_1 = encoded_output_1 * len(encoded_output_2)
-        output_pairs = [(t1, t2)
-                        for t1, t2 in zip(encoded_output_1, encoded_output_2)]
-        scores = []
-        scorer = torch.nn.CosineSimilarity(0)
-        for embed_1, embed_2 in output_pairs:
-            pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)
-            if (pad_token_id := getattr(tokenizer, "pad_token_id",
-                                        None)) is not None:
-                tokens = embed_1.prompt_token_ids + [
-                    pad_token_id
-                ] + embed_2.prompt_token_ids
-            else:
-                tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids
-            scores.append(
-                PoolingRequestOutput(
-                    request_id=f"{embed_1.request_id}_{embed_2.request_id}",
-                    outputs=pair_score,
-                    prompt_token_ids=tokens,
-                    finished=True))
+        scores: List[PoolingRequestOutput] = []
+        scores = _cosine_similarity(tokenizer=tokenizer,
+                                    embed_1=encoded_output_1,
+                                    embed_2=encoded_output_2)
         items = self.engine_class.validate_outputs(scores,
                                                    PoolingRequestOutput)
@@ -1183,12 +1169,7 @@ class LLM:
             text_2 = [text_2]
         input_text_2: List[str] = [ensure_str(t) for t in text_2]
-        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
-            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
-        if len(input_text_1) == 0:
-            raise ValueError("At least one text element must be given")
-        if len(input_text_2) == 0:
-            raise ValueError("At least one text_pair element must be given")
+        _validate_score_input_lens(input_text_1, input_text_2)
         if self.llm_engine.model_config.is_cross_encoder:
             return self._cross_encoding_score(tokenizer, input_text_1,
@@ -1197,7 +1178,6 @@ class LLM:
                                               lora_request,
                                               prompt_adapter_request)
         else:
             return self._embedding_score(
                 tokenizer,
                 input_text_1,  # type: ignore[arg-type]
......
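The length checks that this commit moves out of `LLM.score`, together with the 1:N broadcast in `_embedding_score`, can be sketched on plain strings as below. The error messages are copied from the removed inline checks. The function names `validate_score_input_lens` and `pair_inputs` are stand-ins chosen for this example and are not claimed to match the body of vLLM's `_validate_score_input_lens`.

```python
from typing import List, Tuple


def validate_score_input_lens(texts_1: List[str], texts_2: List[str]) -> None:
    """Accept only 1:1, 1:N or N:N shaped inputs."""
    if len(texts_1) > 1 and len(texts_1) != len(texts_2):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if len(texts_1) == 0:
        raise ValueError("At least one text element must be given")
    if len(texts_2) == 0:
        raise ValueError("At least one text_pair element must be given")


def pair_inputs(texts_1: List[str],
                texts_2: List[str]) -> List[Tuple[str, str]]:
    """Validate, then repeat a single first text across every second text."""
    validate_score_input_lens(texts_1, texts_2)
    if len(texts_1) == 1:
        texts_1 = texts_1 * len(texts_2)
    return list(zip(texts_1, texts_2))


pairs = pair_inputs(
    ["What is the capital of France?"],
    ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
)
print(len(pairs))
```

Validating before broadcasting means a mismatched N:M request fails early instead of being silently truncated by `zip`.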
@@ -73,8 +73,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                     OpenAIServingModels)
 from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
-from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
-from vllm.entrypoints.openai.serving_score import OpenAIServingScores
+from vllm.entrypoints.openai.serving_score import ServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.serving_transcription import (
@@ -320,12 +319,12 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
     return request.app.state.openai_serving_embedding
-def score(request: Request) -> Optional[OpenAIServingScores]:
+def score(request: Request) -> Optional[ServingScores]:
     return request.app.state.openai_serving_scores
-def rerank(request: Request) -> Optional[JinaAIServingRerank]:
-    return request.app.state.jinaai_serving_reranking
+def rerank(request: Request) -> Optional[ServingScores]:
+    return request.app.state.openai_serving_scores
 def tokenization(request: Request) -> OpenAIServingTokenization:
@@ -866,13 +865,13 @@ async def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     ) if model_config.task == "embed" else None
-    state.openai_serving_scores = OpenAIServingScores(
+    state.openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
-        request_logger=request_logger
-    ) if model_config.task == "score" else None
-    state.jinaai_serving_reranking = JinaAIServingRerank(
+        request_logger=request_logger) if model_config.task in (
+            "score", "embed", "pooling") else None
+    state.jinaai_serving_reranking = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
......
...@@ -26,7 +26,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat ...@@ -26,7 +26,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import (BaseModelPath, from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels) OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
...@@ -342,7 +342,7 @@ async def main(args): ...@@ -342,7 +342,7 @@ async def main(args):
chat_template=None, chat_template=None,
chat_template_content_format="auto", chat_template_content_format="auto",
) if model_config.task == "embed" else None ) if model_config.task == "embed" else None
openai_serving_scores = (OpenAIServingScores( openai_serving_scores = (ServingScores(
engine, engine,
model_config, model_config,
openai_serving_models, openai_serving_models,
...@@ -364,9 +364,9 @@ async def main(args): ...@@ -364,9 +364,9 @@ async def main(args):
# Determine the type of request and run it. # Determine the type of request and run it.
if request.url == "/v1/chat/completions": if request.url == "/v1/chat/completions":
handler_fn = (None if openai_serving_chat is None else chat_handler_fn = (None if openai_serving_chat is None else
openai_serving_chat.create_chat_completion) openai_serving_chat.create_chat_completion)
if handler_fn is None: if chat_handler_fn is None:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
request, request,
...@@ -375,12 +375,13 @@ async def main(args): ...@@ -375,12 +375,13 @@ async def main(args):
)) ))
continue continue
response_futures.append(run_request(handler_fn, request, tracker)) response_futures.append(
run_request(chat_handler_fn, request, tracker))
tracker.submitted() tracker.submitted()
elif request.url == "/v1/embeddings": elif request.url == "/v1/embeddings":
handler_fn = (None if openai_serving_embedding is None else embed_handler_fn = (None if openai_serving_embedding is None else
openai_serving_embedding.create_embedding) openai_serving_embedding.create_embedding)
if handler_fn is None: if embed_handler_fn is None:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
request, request,
...@@ -388,12 +389,13 @@ async def main(args): ...@@ -388,12 +389,13 @@ async def main(args):
)) ))
continue continue
response_futures.append(run_request(handler_fn, request, tracker)) response_futures.append(
run_request(embed_handler_fn, request, tracker))
tracker.submitted() tracker.submitted()
elif request.url == "/v1/score": elif request.url == "/v1/score":
handler_fn = (None if openai_serving_scores is None else score_handler_fn = (None if openai_serving_scores is None else
openai_serving_scores.create_score) openai_serving_scores.create_score)
if handler_fn is None: if score_handler_fn is None:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
request, request,
...@@ -401,7 +403,8 @@ async def main(args): ...@@ -401,7 +403,8 @@ async def main(args):
)) ))
continue continue
response_futures.append(run_request(handler_fn, request, tracker)) response_futures.append(
run_request(score_handler_fn, request, tracker))
tracker.submitted() tracker.submitted()
else: else:
response_futures.append( response_futures.append(
......
@@ -52,8 +52,8 @@ from vllm.utils import is_list_of, make_async, random_uuid
 logger = init_logger(__name__)
 CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
-                              EmbeddingCompletionRequest, ScoreRequest,
-                              TokenizeCompletionRequest]
+                              EmbeddingCompletionRequest, RerankRequest,
+                              ScoreRequest, TokenizeCompletionRequest]
 ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                         TokenizeChatRequest]
......
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
RerankRequest, RerankResponse,
RerankResult, RerankUsage)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__)
class JinaAIServingRerank(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger)
async def do_rerank(
self,
request: RerankRequest,
raw_request: Optional[Request] = None
) -> Union[RerankResponse, ErrorResponse]:
"""
Rerank API based on JinaAI's rerank API; implements the same
API interface. Designed for compatibility with off-the-shelf
tooling, since this is a common standard for reranking APIs
See example client implementations at
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
numerous clients use this standard.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"rerank-{self._base_request_id(raw_request)}"
truncate_prompt_tokens = request.truncate_prompt_tokens
query = request.query
documents = request.documents
request_prompts = []
engine_prompts = []
top_n = request.top_n if request.top_n > 0 else len(documents)
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
if not self.model_config.is_cross_encoder:
raise ValueError("Model is not cross encoder.")
if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({self.max_model_len})."
f" Please, select a smaller truncation size.")
for doc in documents:
request_prompt = f"{query}{tokenizer.sep_token}{doc}"
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=query,
text_pair=doc,
**tokenization_kwargs)
input_ids = prompt_inputs["input_ids"]
text_token_prompt = \
self._validate_input(request, input_ids, request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = self.request_output_to_rerank_response(
final_res_batch_checked, request_id, model_name, documents,
top_n)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def request_output_to_rerank_response(
self, final_res_batch: List[PoolingRequestOutput], request_id: str,
model_name: str, documents: List[str],
top_n: int) -> RerankResponse:
"""
Convert the output of do_rank to a RerankResponse
"""
results: List[RerankResult] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
result = RerankResult(
index=idx,
document=RerankDocument(text=documents[idx]),
relevance_score=classify_res.outputs.score,
)
results.append(result)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
# sort by relevance, then return the top n if set
results.sort(key=lambda x: x.relevance_score, reverse=True)
if top_n < len(documents):
results = results[:top_n]
return RerankResponse(
id=request_id,
model=model_name,
results=results,
usage=RerankUsage(total_tokens=num_prompt_tokens))
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Union
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest, from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
ScoreResponse, ScoreResponseData, RerankRequest, RerankResponse,
UsageInfo) RerankResult, RerankUsage,
ScoreRequest, ScoreResponse,
ScoreResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.utils import make_async, merge_async_iterators from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str], class ServingScores(OpenAIServing):
str]) -> List:
if isinstance(text_1, (str, dict)):
# Convert a single prompt to a list.
text_1 = [text_1]
text_1 = [t for t in text_1]
if isinstance(text_2, (str, dict)):
# Convert a single prompt to a list.
text_2 = [text_2]
text_2 = [t for t in text_2]
if len(text_1) > 1 and len(text_1) != len(text_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(text_1) == 0:
raise ValueError("At least one text element must be given")
if len(text_2) == 0:
raise ValueError("At least one text_pair element must be given")
if len(text_1) == 1:
text_1 = text_1 * len(text_2)
return [(t1, t2) for t1, t2 in zip(text_1, text_2)]
class OpenAIServingScores(OpenAIServing):
def __init__( def __init__(
self, self,
...@@ -62,68 +45,134 @@ class OpenAIServingScores(OpenAIServing): ...@@ -62,68 +45,134 @@ class OpenAIServingScores(OpenAIServing):
models=models, models=models,
request_logger=request_logger) request_logger=request_logger)
async def create_score( async def _embedding_score(
self, self,
request: ScoreRequest, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
raw_request: Optional[Request] = None, texts_1: List[str],
) -> Union[ScoreResponse, ErrorResponse]: texts_2: List[str],
""" request: Union[RerankRequest, ScoreRequest],
Score API similar to Sentence Transformers cross encoder request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> List[PoolingRequestOutput]:
input_texts = texts_1 + texts_2
engine_prompts: List[TokensPrompt] = []
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
See https://sbert.net/docs/package_reference/cross_encoder tokenization_kwargs = tokenization_kwargs or {}
""" tokenized_prompts = await asyncio.gather(
error_check_ret = await self._check_model(request) *(tokenize_async(t, **tokenization_kwargs) for t in input_texts))
if error_check_ret is not None:
return error_check_ret
model_name = request.model for tok_result, input_text in zip(tokenized_prompts, input_texts):
request_id = f"score-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = request.truncate_prompt_tokens
request_prompts = [] text_token_prompt = \
engine_prompts = [] self._validate_input(
request,
tok_result["input_ids"],
input_text)
try: engine_prompts.append(
( TokensPrompt(
lora_request, prompt_token_ids=text_token_prompt["prompt_token_ids"]))
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request) # Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params()
if prompt_adapter_request is not None: for i, engine_prompt in enumerate(engine_prompts):
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
if isinstance(tokenizer, MistralTokenizer): request_id_item = f"{request_id}-{i}"
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
if not self.model_config.is_cross_encoder: self._log_inputs(request_id_item,
raise ValueError("Model is not cross encoder.") input_texts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if truncate_prompt_tokens is not None and \ generators.append(
truncate_prompt_tokens > self.max_model_len: self.engine_client.encode(
raise ValueError( engine_prompt,
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " pooling_params,
f"is greater than max_model_len ({self.max_model_len})." request_id_item,
f" Please, select a smaller truncation size.") lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
))
input_pairs = make_pairs(request.text_1, request.text_2) result_generator = merge_async_iterators(*generators)
for q, t in input_pairs:
request_prompt = f"{q}{tokenizer.sep_token}{t}"
tokenization_kwargs: Dict[str, Any] = {} # Non-streaming response
if truncate_prompt_tokens is not None: final_res_batch: List[PoolingRequestOutput] = []
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens embeddings: List[Optional[PoolingRequestOutput]] =\
[None] * len(engine_prompts)
async for i, res in result_generator:
embeddings[i] = res
emb_texts_1: List[PoolingRequestOutput] = []
emb_texts_2: List[PoolingRequestOutput] = []
for i in range(0, len(texts_1)):
assert (emb := embeddings[i]) is not None
emb_texts_1.append(emb)
for i in range(len(texts_1), len(embeddings)):
assert (emb := embeddings[i]) is not None
emb_texts_2.append(emb)
if len(emb_texts_1) == 1:
emb_texts_1 = emb_texts_1 * len(emb_texts_2)
final_res_batch = _cosine_similarity(tokenizer=tokenizer,
embed_1=emb_texts_1,
embed_2=emb_texts_2)
return final_res_batch
async def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
texts_1: List[str],
texts_2: List[str],
request: Union[RerankRequest, ScoreRequest],
request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> List[PoolingRequestOutput]:
request_prompts: List[str] = []
engine_prompts: List[TokensPrompt] = []
if len(texts_1) == 1:
texts_1 = texts_1 * len(texts_2)
input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)]
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
tokenize_async = make_async(tokenizer.__call__, tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor) executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(q,
text_pair=t, tokenization_kwargs = tokenization_kwargs or {}
**tokenization_kwargs) tokenized_prompts = await asyncio.gather(
*(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs)
for t1, t2 in input_pairs))
for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
request_prompt = f"{t1}{tokenizer.sep_token}{t2}"
input_ids = prompt_inputs["input_ids"] input_ids = prompt_inputs["input_ids"]
text_token_prompt = \ text_token_prompt = \
...@@ -135,14 +184,9 @@ class OpenAIServingScores(OpenAIServing): ...@@ -135,14 +184,9 @@ class OpenAIServingScores(OpenAIServing):
request_prompts.append(request_prompt) request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt) engine_prompts.append(engine_prompt)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
...@@ -154,9 +198,6 @@ class OpenAIServingScores(OpenAIServing): ...@@ -154,9 +198,6 @@ class OpenAIServingScores(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_prompt,
pooling_params, pooling_params,
...@@ -167,32 +208,117 @@ class OpenAIServingScores(OpenAIServing): ...@@ -167,32 +208,117 @@ class OpenAIServingScores(OpenAIServing):
) )
generators.append(generator) generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)
# Non-streaming response # Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]] final_res_batch: List[
final_res_batch = [None] * num_prompts Optional[PoolingRequestOutput]] = [None] * len(engine_prompts)
try:
async for i, res in result_generator: async for i, res in result_generator:
final_res_batch[i] = res final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch) return [out for out in final_res_batch if out is not None]
async def _run_scoring(
self,
texts_1: Union[str, list[str]],
texts_2: Union[str, list[str]],
request: Union[ScoreRequest, RerankRequest],
request_id: str,
raw_request: Optional[Request] = None,
truncate_prompt_tokens: Optional[int] = None,
) -> List[PoolingRequestOutput]:
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({self.max_model_len})."
f" Please, select a smaller truncation size.")
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
if isinstance(texts_1, str):
texts_1 = [texts_1]
if isinstance(texts_2, str):
texts_2 = [texts_2]
_validate_score_input_lens(texts_1, texts_2)
if self.model_config.is_cross_encoder:
return await self._cross_encoding_score(
tokenizer=tokenizer,
texts_1=texts_1,
texts_2=texts_2,
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers)
else:
return await self._embedding_score(
tokenizer=tokenizer,
texts_1=texts_1,
texts_2=texts_2,
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers)
async def create_score(
self,
request: ScoreRequest,
raw_request: Optional[Request] = None,
) -> Union[ScoreResponse, ErrorResponse]:
"""
Score API similar to Sentence Transformers cross encoder
See https://sbert.net/docs/package_reference/cross_encoder
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
final_res_batch_checked = cast(List[PoolingRequestOutput], request_id = f"score-{self._base_request_id(raw_request)}"
final_res_batch) created_time = int(time.time())
try:
final_res_batch = await self._run_scoring(
request.text_1,
request.text_2,
request,
request_id,
raw_request,
request.truncate_prompt_tokens,
)
response = self.request_output_to_score_response( return self.request_output_to_score_response(
final_res_batch_checked, final_res_batch,
request_id, request_id,
created_time, created_time,
model_name, request.model,
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
...@@ -200,7 +326,44 @@ class OpenAIServingScores(OpenAIServing): ...@@ -200,7 +326,44 @@ class OpenAIServingScores(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return response async def do_rerank(
self,
request: RerankRequest,
raw_request: Optional[Request] = None
) -> Union[RerankResponse, ErrorResponse]:
"""
Rerank API based on JinaAI's rerank API; implements the same
API interface. Designed for compatibility with off-the-shelf
tooling, since this is a common standard for reranking APIs that
numerous clients use. See example client implementations at
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"rerank-{self._base_request_id(raw_request)}"
documents = request.documents
top_n = request.top_n if request.top_n > 0 else len(documents)
try:
final_res_batch = await self._run_scoring(
request.query,
documents,
request,
request_id,
raw_request,
request.truncate_prompt_tokens,
)
return self.request_output_to_rerank_response(
final_res_batch, request_id, request.model, documents, top_n)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def request_output_to_score_response( def request_output_to_score_response(
self, self,
...@@ -236,3 +399,35 @@ class OpenAIServingScores(OpenAIServing): ...@@ -236,3 +399,35 @@ class OpenAIServingScores(OpenAIServing):
data=items, data=items,
usage=usage, usage=usage,
) )
    def request_output_to_rerank_response(
            self, final_res_batch: List[PoolingRequestOutput], request_id: str,
            model_name: str, documents: List[str],
            top_n: int) -> RerankResponse:
        """
        Convert the output of do_rerank to a RerankResponse
        """
        results: List[RerankResult] = []
        num_prompt_tokens = 0
        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            result = RerankResult(
                index=idx,
                document=RerankDocument(text=documents[idx]),
                relevance_score=classify_res.outputs.score,
            )
            results.append(result)
            prompt_token_ids = final_res.prompt_token_ids
            num_prompt_tokens += len(prompt_token_ids)

        # sort by relevance, then return the top n if set
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        if top_n < len(documents):
            results = results[:top_n]

        return RerankResponse(
            id=request_id,
            model=model_name,
            results=results,
            usage=RerankUsage(total_tokens=num_prompt_tokens))
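The ranking step in `request_output_to_rerank_response`, together with the `top_n` normalization done in `do_rerank`, can be illustrated in isolation. The following is only a minimal sketch: `Result` and `rank` are hypothetical stand-ins introduced for illustration, not vLLM classes or functions, and the scores are made-up inputs.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Result:
    """Hypothetical stand-in for RerankResult (not the vLLM class)."""
    index: int  # position of the document in the original request
    text: str
    relevance_score: float


def rank(documents: List[str], scores: List[float],
         top_n: int) -> List[Result]:
    # As in do_rerank: a non-positive top_n means "return everything".
    if top_n <= 0:
        top_n = len(documents)

    results = [
        Result(index=i, text=doc, relevance_score=score)
        for i, (doc, score) in enumerate(zip(documents, scores))
    ]

    # Sort by relevance (highest first), then truncate to top_n if needed.
    results.sort(key=lambda r: r.relevance_score, reverse=True)
    if top_n < len(documents):
        results = results[:top_n]
    return results


if __name__ == "__main__":
    for r in rank(["a", "b", "c"], [0.1, 0.9, 0.5], top_n=2):
        print(r.index, r.text, r.relevance_score)
```

Each result keeps the `index` of its document in the original request, so a client can map the sorted output back to its input order.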
# SPDX-License-Identifier: Apache-2.0
from typing import List, Union

from torch.nn import CosineSimilarity

from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
                                               PreTrainedTokenizerFast)


def _cosine_similarity(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    embed_1: List[PoolingRequestOutput],
    embed_2: List[PoolingRequestOutput],
) -> List[PoolingRequestOutput]:

    scorer = CosineSimilarity(0)
    scores: List[PoolingRequestOutput] = []

    for emb_1, emb_2 in zip(embed_1, embed_2):
        pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)

        # Join both prompts' token IDs, separated by the pad token
        # when the tokenizer defines one.
        padding = []
        if (pad_token_id := getattr(tokenizer, "pad_token_id",
                                    None)) is not None:
            padding = [pad_token_id]

        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=pair_score,
                prompt_token_ids=tokens,
                finished=True))

    return scores


def _validate_score_input_lens(
    texts_1: Union[List[str], List[dict]],
    texts_2: Union[List[str], List[dict]],
):
    # Accepted pairings are 1:1, 1:N and N:N; both sides must be non-empty.
    if len(texts_1) > 1 and len(texts_1) != len(texts_2):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if len(texts_1) == 0:
        raise ValueError("At least one text element must be given")
    if len(texts_2) == 0:
        raise ValueError("At least one text_pair element must be given")
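To make the embedding path easier to follow, here is a rough, dependency-free sketch of how the length validation, the 1:N broadcast from `_embedding_score`, and pairwise cosine similarity fit together. It works on plain Python lists rather than `PoolingRequestOutput` tensors; `validate_input_lens`, `cosine`, `score_pairs`, and the `eps` guard are illustrative assumptions, not part of vLLM or of `torch.nn.CosineSimilarity`.

```python
import math
from typing import List, Sequence


def validate_input_lens(items_1: Sequence, items_2: Sequence) -> None:
    """Same checks as _validate_score_input_lens: allow 1:1, 1:N or N:N."""
    if len(items_1) > 1 and len(items_1) != len(items_2):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if len(items_1) == 0:
        raise ValueError("At least one text element must be given")
    if len(items_2) == 0:
        raise ValueError("At least one text_pair element must be given")


def cosine(a: Sequence[float], b: Sequence[float],
           eps: float = 1e-8) -> float:
    """Cosine similarity of two vectors; eps (an assumption) avoids 0/0."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def score_pairs(embeds_1: List[List[float]],
                embeds_2: List[List[float]]) -> List[float]:
    validate_input_lens(embeds_1, embeds_2)
    # 1:N case: repeat the single left-hand embedding for every right-hand one.
    if len(embeds_1) == 1:
        embeds_1 = embeds_1 * len(embeds_2)
    return [cosine(a, b) for a, b in zip(embeds_1, embeds_2)]


if __name__ == "__main__":
    query = [[1.0, 0.0]]
    docs = [[1.0, 0.0], [0.0, 1.0]]
    print(score_pairs(query, docs))
```

In the 1:N case a single query embedding is compared against every document embedding, which is the shape a rerank request produces.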