Unverified Commit 1c3c9757 authored by Gabriel Marinho's avatar Gabriel Marinho Committed by GitHub
Browse files

[FEATURE] Enables /score endpoint for embedding models (#12846)

parent 1cdc8861
...@@ -108,8 +108,7 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/clas ...@@ -108,8 +108,7 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/clas
### `LLM.score` ### `LLM.score`
The {class}`~vllm.LLM.score` method outputs similarity scores between sentence pairs. The {class}`~vllm.LLM.score` method outputs similarity scores between sentence pairs.
It is primarily designed for [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html). It is designed for embedding models and cross encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
These types of models serve as rerankers between candidate query-document pairs in RAG systems.
:::{note} :::{note}
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG. vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
......
...@@ -51,7 +51,7 @@ In addition, we have the following custom APIs: ...@@ -51,7 +51,7 @@ In addition, we have the following custom APIs:
- [Pooling API](#pooling-api) (`/pooling`) - [Pooling API](#pooling-api) (`/pooling`)
- Applicable to all [pooling models](../models/pooling_models.md). - Applicable to all [pooling models](../models/pooling_models.md).
- [Score API](#score-api) (`/score`) - [Score API](#score-api) (`/score`)
- Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`). - Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`).
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`) - [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
- Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/) - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
- Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank) - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
...@@ -333,10 +333,10 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py> ...@@ -333,10 +333,10 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
### Score API ### Score API
Our Score API applies a cross-encoder model to predict scores for sentence pairs. Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair.
Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1. Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1.
You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
Code example: <gh-file:examples/online_serving/openai_cross_encoder_score.py> Code example: <gh-file:examples/online_serving/openai_cross_encoder_score.py>
...@@ -496,11 +496,11 @@ The following extra parameters are supported: ...@@ -496,11 +496,11 @@ The following extra parameters are supported:
### Re-rank API ### Re-rank API
Our Re-rank API applies a cross-encoder model to predict relevant scores between a single query, and Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and
each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on
a scale of 0 to 1. a scale of 0 to 1.
You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the
`score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank` `score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank`
......
...@@ -8,17 +8,17 @@ from vllm.entrypoints.openai.protocol import RerankResponse ...@@ -8,17 +8,17 @@ from vllm.entrypoints.openai.protocol import RerankResponse
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "BAAI/bge-reranker-base" MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = ["--enforce-eager", "--max-model-len", "100"] args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str): def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
query = "What is the capital of France?" query = "What is the capital of France?"
...@@ -42,7 +42,6 @@ def test_rerank_texts(server: RemoteOpenAIServer, model_name: str): ...@@ -42,7 +42,6 @@ def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
assert rerank.results[1].relevance_score <= 0.01 assert rerank.results[1].relevance_score <= 0.01
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_top_n(server: RemoteOpenAIServer, model_name: str): def test_top_n(server: RemoteOpenAIServer, model_name: str):
query = "What is the capital of France?" query = "What is the capital of France?"
...@@ -68,7 +67,6 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str): ...@@ -68,7 +67,6 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str):
assert rerank.results[1].relevance_score <= 0.01 assert rerank.results[1].relevance_score <= 0.01
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str): def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math
from typing import Any
import pytest import pytest
import requests import requests
import torch.nn.functional as F
from torch import tensor
from vllm.entrypoints.openai.protocol import ScoreResponse from vllm.entrypoints.openai.protocol import ScoreResponse
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "BAAI/bge-reranker-v2-m3" MODELS = [
{
"name": "BAAI/bge-reranker-v2-m3",
@pytest.fixture(scope="module") "is_cross_encoder": True
def server(): },
args = ["--enforce-eager", "--max-model-len", "100"] {
"name": "BAAI/bge-base-en-v1.5",
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: "is_cross_encoder": False
},
]
DTYPE = "half"
def run_transformers(hf_model, model, text_pairs):
if model["is_cross_encoder"]:
return hf_model.predict(text_pairs).tolist()
else:
hf_embeddings = [
hf_model.encode(text_pair) for text_pair in text_pairs
]
return [
F.cosine_similarity(tensor(pair[0]), tensor(pair[1]), dim=0)
for pair in hf_embeddings
]
@pytest.fixture(scope="class", params=MODELS)
def model(request):
yield request.param
@pytest.fixture(scope="class")
def server(model: dict[str, Any]):
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
with RemoteOpenAIServer(model["name"], args) as remote_server:
yield remote_server yield remote_server
@pytest.mark.asyncio @pytest.fixture(scope="class")
@pytest.mark.parametrize("model_name", [MODEL_NAME]) def runner(model: dict[str, Any], hf_runner):
def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str): kwargs = {
text_1 = "What is the capital of France?" "dtype": DTYPE,
text_2 = [ "is_cross_encoder" if model["is_cross_encoder"]\
"The capital of Brazil is Brasilia.", "The capital of France is Paris." else "is_sentence_transformer": True
] }
score_response = requests.post(server.url_for("score"), with hf_runner(model["name"], **kwargs) as hf_model:
json={ yield hf_model
"model": model_name,
"text_1": text_1,
"text_2": text_2, class TestModel:
})
score_response.raise_for_status() def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer,
score = ScoreResponse.model_validate(score_response.json()) model: dict[str, Any], runner):
text_1 = "What is the capital of France?"
assert score.id is not None text_2 = [
assert score.data is not None "The capital of Brazil is Brasilia.",
assert len(score.data) == 2 "The capital of France is Paris."
assert score.data[0].score <= 0.01 ]
assert score.data[1].score >= 0.9
score_response = requests.post(server.url_for("score"),
json={
@pytest.mark.asyncio "model": model["name"],
@pytest.mark.parametrize("model_name", [MODEL_NAME]) "text_1": text_1,
def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str): "text_2": text_2,
text_1 = [ })
"What is the capital of the United States?", score_response.raise_for_status()
"What is the capital of France?" score = ScoreResponse.model_validate(score_response.json())
]
text_2 = [ assert score.id is not None
"The capital of Brazil is Brasilia.", "The capital of France is Paris." assert score.data is not None
] assert len(score.data) == 2
score_response = requests.post(server.url_for("score"), vllm_outputs = [d.score for d in score.data]
json={
"model": model_name, text_pairs = [[text_1, text_2[0]], [text_1, text_2[1]]]
"text_1": text_1, hf_outputs = run_transformers(runner, model, text_pairs)
"text_2": text_2,
}) for i in range(len(vllm_outputs)):
score_response.raise_for_status() assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
score = ScoreResponse.model_validate(score_response.json())
def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer,
assert score.id is not None model: dict[str, Any], runner):
assert score.data is not None text_1 = [
assert len(score.data) == 2 "What is the capital of the United States?",
assert score.data[0].score <= 0.01 "What is the capital of France?"
assert score.data[1].score >= 0.9 ]
text_2 = [
"The capital of Brazil is Brasilia.",
@pytest.mark.asyncio "The capital of France is Paris."
@pytest.mark.parametrize("model_name", [MODEL_NAME]) ]
def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
text_1 = "What is the capital of France?" score_response = requests.post(server.url_for("score"),
text_2 = "The capital of France is Paris." json={
"model": model["name"],
score_response = requests.post(server.url_for("score"), "text_1": text_1,
json={ "text_2": text_2,
"model": model_name, })
"text_1": text_1, score_response.raise_for_status()
"text_2": text_2, score = ScoreResponse.model_validate(score_response.json())
})
score_response.raise_for_status() assert score.id is not None
score = ScoreResponse.model_validate(score_response.json()) assert score.data is not None
assert len(score.data) == 2
assert score.id is not None
assert score.data is not None vllm_outputs = [d.score for d in score.data]
assert len(score.data) == 1
assert score.data[0].score >= 0.9 text_pairs = [[text_1[0], text_2[0]], [text_1[1], text_2[1]]]
hf_outputs = run_transformers(runner, model, text_pairs)
@pytest.mark.asyncio for i in range(len(vllm_outputs)):
@pytest.mark.parametrize("model_name", [MODEL_NAME]) assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer,
text_1 = "What is the capital of France?" * 20 model: dict[str, Any], runner):
text_2 = [ text_1 = "What is the capital of France?"
"The capital of Brazil is Brasilia.", "The capital of France is Paris." text_2 = "The capital of France is Paris."
]
score_response = requests.post(server.url_for("score"),
score_response = requests.post(server.url_for("score"), json={
json={ "model": model["name"],
"model": model_name, "text_1": text_1,
"text_1": text_1, "text_2": text_2,
"text_2": text_2, })
}) score_response.raise_for_status()
assert score_response.status_code == 400 score = ScoreResponse.model_validate(score_response.json())
# Assert just a small fragments of the response
assert "Please reduce the length of the input." in \ assert score.id is not None
score_response.text assert score.data is not None
assert len(score.data) == 1
# Test truncation
score_response = requests.post(server.url_for("score"), vllm_outputs = [d.score for d in score.data]
json={
"model": model_name, text_pairs = [[text_1, text_2]]
"text_1": text_1, hf_outputs = run_transformers(runner, model, text_pairs)
"text_2": text_2,
"truncate_prompt_tokens": 101 for i in range(len(vllm_outputs)):
}) assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
assert score_response.status_code == 400
assert "Please, select a smaller truncation size." in \ def test_score_max_model_len(self, server: RemoteOpenAIServer,
score_response.text model: dict[str, Any]):
text_1 = "What is the capital of France?" * 20
text_2 = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris."
]
score_response = requests.post(server.url_for("score"),
json={
"model": model["name"],
"text_1": text_1,
"text_2": text_2,
})
assert score_response.status_code == 400
# Assert just a small fragments of the response
assert "Please reduce the length of the input." in \
score_response.text
# Test truncation
score_response = requests.post(server.url_for("score"),
json={
"model": model["name"],
"text_1": text_1,
"text_2": text_2,
"truncate_prompt_tokens": 101
})
assert score_response.status_code == 400
assert "Please, select a smaller truncation size." in \
score_response.text
...@@ -7,7 +7,6 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence, ...@@ -7,7 +7,6 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload) Tuple, Type, Union, cast, overload)
import cloudpickle import cloudpickle
import torch
import torch.nn as nn import torch.nn as nn
from tqdm import tqdm from tqdm import tqdm
from typing_extensions import TypeVar, deprecated from typing_extensions import TypeVar, deprecated
...@@ -25,6 +24,8 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, ...@@ -25,6 +24,8 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_mistral_chat_template, apply_mistral_chat_template,
parse_chat_messages, parse_chat_messages,
resolve_chat_template_content_format) resolve_chat_template_content_format)
from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -1010,40 +1011,25 @@ class LLM: ...@@ -1010,40 +1011,25 @@ class LLM:
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]: ) -> List[ScoringRequestOutput]:
encoded_output = self.encode( encoded_output: List[PoolingRequestOutput] = self.encode(
text_1 + text_2, text_1 + text_2,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
encoded_output_1 = encoded_output[0:len(text_1)]
encoded_output_2 = encoded_output[len(text_1):] encoded_output_1: List[PoolingRequestOutput] = encoded_output[
0:len(text_1)]
encoded_output_2: List[PoolingRequestOutput] = encoded_output[
len(text_1):]
if len(encoded_output_1) == 1: if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2) encoded_output_1 = encoded_output_1 * len(encoded_output_2)
output_pairs = [(t1, t2) scores: List[PoolingRequestOutput] = []
for t1, t2 in zip(encoded_output_1, encoded_output_2)]
scores = []
scorer = torch.nn.CosineSimilarity(0)
for embed_1, embed_2 in output_pairs:
pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)
if (pad_token_id := getattr(tokenizer, "pad_token_id", scores = _cosine_similarity(tokenizer=tokenizer,
None)) is not None: embed_1=encoded_output_1,
tokens = embed_1.prompt_token_ids + [ embed_2=encoded_output_2)
pad_token_id
] + embed_2.prompt_token_ids
else:
tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{embed_1.request_id}_{embed_2.request_id}",
outputs=pair_score,
prompt_token_ids=tokens,
finished=True))
items = self.engine_class.validate_outputs(scores, items = self.engine_class.validate_outputs(scores,
PoolingRequestOutput) PoolingRequestOutput)
...@@ -1183,12 +1169,7 @@ class LLM: ...@@ -1183,12 +1169,7 @@ class LLM:
text_2 = [text_2] text_2 = [text_2]
input_text_2: List[str] = [ensure_str(t) for t in text_2] input_text_2: List[str] = [ensure_str(t) for t in text_2]
if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2): _validate_score_input_lens(input_text_1, input_text_2)
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(input_text_1) == 0:
raise ValueError("At least one text element must be given")
if len(input_text_2) == 0:
raise ValueError("At least one text_pair element must be given")
if self.llm_engine.model_config.is_cross_encoder: if self.llm_engine.model_config.is_cross_encoder:
return self._cross_encoding_score(tokenizer, input_text_1, return self._cross_encoding_score(tokenizer, input_text_1,
...@@ -1197,7 +1178,6 @@ class LLM: ...@@ -1197,7 +1178,6 @@ class LLM:
lora_request, lora_request,
prompt_adapter_request) prompt_adapter_request)
else: else:
return self._embedding_score( return self._embedding_score(
tokenizer, tokenizer,
input_text_1, # type: ignore[arg-type] input_text_1, # type: ignore[arg-type]
......
...@@ -73,8 +73,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing ...@@ -73,8 +73,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath, from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels) OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import ( from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization) OpenAIServingTokenization)
from vllm.entrypoints.openai.serving_transcription import ( from vllm.entrypoints.openai.serving_transcription import (
...@@ -320,12 +319,12 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: ...@@ -320,12 +319,12 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding return request.app.state.openai_serving_embedding
def score(request: Request) -> Optional[OpenAIServingScores]: def score(request: Request) -> Optional[ServingScores]:
return request.app.state.openai_serving_scores return request.app.state.openai_serving_scores
def rerank(request: Request) -> Optional[JinaAIServingRerank]: def rerank(request: Request) -> Optional[ServingScores]:
return request.app.state.jinaai_serving_reranking return request.app.state.openai_serving_scores
def tokenization(request: Request) -> OpenAIServingTokenization: def tokenization(request: Request) -> OpenAIServingTokenization:
...@@ -866,13 +865,13 @@ async def init_app_state( ...@@ -866,13 +865,13 @@ async def init_app_state(
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None ) if model_config.task == "embed" else None
state.openai_serving_scores = OpenAIServingScores( state.openai_serving_scores = ServingScores(
engine_client, engine_client,
model_config, model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger request_logger=request_logger) if model_config.task in (
) if model_config.task == "score" else None "score", "embed", "pooling") else None
state.jinaai_serving_reranking = JinaAIServingRerank( state.jinaai_serving_reranking = ServingScores(
engine_client, engine_client,
model_config, model_config,
state.openai_serving_models, state.openai_serving_models,
......
...@@ -26,7 +26,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat ...@@ -26,7 +26,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import (BaseModelPath, from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels) OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
...@@ -342,7 +342,7 @@ async def main(args): ...@@ -342,7 +342,7 @@ async def main(args):
chat_template=None, chat_template=None,
chat_template_content_format="auto", chat_template_content_format="auto",
) if model_config.task == "embed" else None ) if model_config.task == "embed" else None
openai_serving_scores = (OpenAIServingScores( openai_serving_scores = (ServingScores(
engine, engine,
model_config, model_config,
openai_serving_models, openai_serving_models,
...@@ -364,9 +364,9 @@ async def main(args): ...@@ -364,9 +364,9 @@ async def main(args):
# Determine the type of request and run it. # Determine the type of request and run it.
if request.url == "/v1/chat/completions": if request.url == "/v1/chat/completions":
handler_fn = (None if openai_serving_chat is None else chat_handler_fn = (None if openai_serving_chat is None else
openai_serving_chat.create_chat_completion) openai_serving_chat.create_chat_completion)
if handler_fn is None: if chat_handler_fn is None:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
request, request,
...@@ -375,12 +375,13 @@ async def main(args): ...@@ -375,12 +375,13 @@ async def main(args):
)) ))
continue continue
response_futures.append(run_request(handler_fn, request, tracker)) response_futures.append(
run_request(chat_handler_fn, request, tracker))
tracker.submitted() tracker.submitted()
elif request.url == "/v1/embeddings": elif request.url == "/v1/embeddings":
handler_fn = (None if openai_serving_embedding is None else embed_handler_fn = (None if openai_serving_embedding is None else
openai_serving_embedding.create_embedding) openai_serving_embedding.create_embedding)
if handler_fn is None: if embed_handler_fn is None:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
request, request,
...@@ -388,12 +389,13 @@ async def main(args): ...@@ -388,12 +389,13 @@ async def main(args):
)) ))
continue continue
response_futures.append(run_request(handler_fn, request, tracker)) response_futures.append(
run_request(embed_handler_fn, request, tracker))
tracker.submitted() tracker.submitted()
elif request.url == "/v1/score": elif request.url == "/v1/score":
handler_fn = (None if openai_serving_scores is None else score_handler_fn = (None if openai_serving_scores is None else
openai_serving_scores.create_score) openai_serving_scores.create_score)
if handler_fn is None: if score_handler_fn is None:
response_futures.append( response_futures.append(
make_async_error_request_output( make_async_error_request_output(
request, request,
...@@ -401,7 +403,8 @@ async def main(args): ...@@ -401,7 +403,8 @@ async def main(args):
)) ))
continue continue
response_futures.append(run_request(handler_fn, request, tracker)) response_futures.append(
run_request(score_handler_fn, request, tracker))
tracker.submitted() tracker.submitted()
else: else:
response_futures.append( response_futures.append(
......
...@@ -52,8 +52,8 @@ from vllm.utils import is_list_of, make_async, random_uuid ...@@ -52,8 +52,8 @@ from vllm.utils import is_list_of, make_async, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
EmbeddingCompletionRequest, ScoreRequest, EmbeddingCompletionRequest, RerankRequest,
TokenizeCompletionRequest] ScoreRequest, TokenizeCompletionRequest]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest] TokenizeChatRequest]
......
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
RerankRequest, RerankResponse,
RerankResult, RerankUsage)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__)
class JinaAIServingRerank(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger)
async def do_rerank(
self,
request: RerankRequest,
raw_request: Optional[Request] = None
) -> Union[RerankResponse, ErrorResponse]:
"""
Rerank API based on JinaAI's rerank API; implements the same
API interface. Designed for compatibility with off-the-shelf
tooling, since this is a common standard for reranking APIs
See example client implementations at
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
numerous clients use this standard.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"rerank-{self._base_request_id(raw_request)}"
truncate_prompt_tokens = request.truncate_prompt_tokens
query = request.query
documents = request.documents
request_prompts = []
engine_prompts = []
top_n = request.top_n if request.top_n > 0 else len(documents)
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
if not self.model_config.is_cross_encoder:
raise ValueError("Model is not cross encoder.")
if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({self.max_model_len})."
f" Please, select a smaller truncation size.")
for doc in documents:
request_prompt = f"{query}{tokenizer.sep_token}{doc}"
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=query,
text_pair=doc,
**tokenization_kwargs)
input_ids = prompt_inputs["input_ids"]
text_token_prompt = \
self._validate_input(request, input_ids, request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = self.request_output_to_rerank_response(
final_res_batch_checked, request_id, model_name, documents,
top_n)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def request_output_to_rerank_response(
self, final_res_batch: List[PoolingRequestOutput], request_id: str,
model_name: str, documents: List[str],
top_n: int) -> RerankResponse:
"""
Convert the output of do_rank to a RerankResponse
"""
results: List[RerankResult] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
result = RerankResult(
index=idx,
document=RerankDocument(text=documents[idx]),
relevance_score=classify_res.outputs.score,
)
results.append(result)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
# sort by relevance, then return the top n if set
results.sort(key=lambda x: x.relevance_score, reverse=True)
if top_n < len(documents):
results = results[:top_n]
return RerankResponse(
id=request_id,
model=model_name,
results=results,
usage=RerankUsage(total_tokens=num_prompt_tokens))
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
from typing import List, Union
from torch.nn import CosineSimilarity
from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
PreTrainedTokenizerFast)
def _cosine_similarity(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
embed_1: List[PoolingRequestOutput],
embed_2: List[PoolingRequestOutput],
) -> List[PoolingRequestOutput]:
scorer = CosineSimilarity(0)
scores: Union[List[PoolingRequestOutput]] = []
for emb_1, emb_2 in zip(embed_1, embed_2):
pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)
padding = []
if (pad_token_id := getattr(tokenizer, "pad_token_id",
None)) is not None:
padding = [pad_token_id]
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=pair_score,
prompt_token_ids=tokens,
finished=True))
return scores
def _validate_score_input_lens(
texts_1: Union[List[str], List[dict]],
texts_2: Union[List[str], List[dict]],
):
if len(texts_1) > 1 and len(texts_1) != len(texts_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(texts_1) == 0:
raise ValueError("At least one text element must be given")
if len(texts_2) == 0:
raise ValueError("At least one text_pair element must be given")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment