Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0034b09c
Unverified
Commit
0034b09c
authored
Jan 26, 2025
by
Kyle Mistele
Committed by
GitHub
Jan 26, 2025
Browse files
[Frontend] Rerank API (Jina- and Cohere-compatible API) (#12376)
Signed-off-by:
Kyle Mistele
<
kyle@mistele.com
>
parent
72bac730
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
552 additions
and
11 deletions
+552
-11
docs/source/serving/openai_compatible_server.md
docs/source/serving/openai_compatible_server.md
+92
-0
examples/online_serving/cohere_rerank_client.py
examples/online_serving/cohere_rerank_client.py
+32
-0
examples/online_serving/jinaai_rerank_client.py
examples/online_serving/jinaai_rerank_client.py
+33
-0
tests/entrypoints/openai/test_rerank.py
tests/entrypoints/openai/test_rerank.py
+87
-0
tests/entrypoints/openai/test_score.py
tests/entrypoints/openai/test_score.py
+1
-6
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+50
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+46
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+5
-4
vllm/entrypoints/openai/serving_rerank.py
vllm/entrypoints/openai/serving_rerank.py
+206
-0
No files found.
docs/source/serving/openai_compatible_server.md
View file @
0034b09c
...
...
@@ -50,6 +50,11 @@ In addition, we have the following custom APIs:
- Applicable to all [pooling models](../models/pooling_models.md).
- [Score API](#score-api) (`
/score
`)
- Only applicable to [cross-encoder models](../models/pooling_models.md) (`
--task score
`).
- [Re-rank API](#rerank-api) (`
/rerank
`, `
/v1/rerank
`, `
/v2/rerank
`)
- Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
- Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
- Jina and Cohere's APIs are very similar; Jina's includes extra information in the rerank endpoint's response.
- Only applicable to [cross-encoder models](../models/pooling_models.md) (`
--task score
`).
(chat-template)=
...
...
@@ -473,3 +478,90 @@ The following extra parameters are supported:
:start-after: begin-score-extra-params
:end-before: end-score-extra-params
```
(rerank-api)=
### Re-rank API
Our Re-rank API applies a cross-encoder model to predict relevant scores between a single query, and
each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on
a scale of 0 to 1.
You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the
`score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank`
endpoints are compatible with both [Jina AI's re-rank API interface](https://jina.ai/reranker/) and
[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with
popular open-source tools.
Code example: <gh-file:examples/online_serving/jinaai_rerank_client.py>
#### Example Request
Note that the `top_n` request parameter is optional and will default to the length of the `documents` field.
Result documents will be sorted by relevance, and the `index` property can be used to determine original order.
Request:
```
bash
curl -X 'POST'
\
'http://127.0.0.1:8000/v1/rerank'
\
-H 'accept: application/json'
\
-H 'Content-Type: application/json'
\
-d '{
"model": "BAAI/bge-reranker-base",
"query": "What is the capital of France?",
"documents": [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
"Horses and cows are both animals"
]
}'
```
Response:
```
bash
{
"id": "rerank-fae51b2b664d4ed38f5969b612edff77",
"model": "BAAI/bge-reranker-base",
"usage": {
"total_tokens": 56
},
"results": [
{
"index": 1,
"document": {
"text": "The capital of France is Paris."
},
"relevance_score": 0.99853515625
},
{
"index": 0,
"document": {
"text": "The capital of Brazil is Brasilia."
},
"relevance_score": 0.0005860328674316406
}
]
}
```
#### Extra parameters
The following [pooling parameters](#pooling-params) are supported.
```
{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-rerank-pooling-params
:end-before: end-rerank-pooling-params
```
The following extra parameters are supported:
```
{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-rerank-extra-params
:end-before: end-rerank-extra-params
```
examples/online_serving/cohere_rerank_client.py
0 → 100644
View file @
0034b09c
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python
run: vllm serve BAAI/bge-reranker-base
"""
import
cohere
# cohere v1 client
co
=
cohere
.
Client
(
base_url
=
"http://localhost:8000"
,
api_key
=
"sk-fake-key"
)
rerank_v1_result
=
co
.
rerank
(
model
=
"BAAI/bge-reranker-base"
,
query
=
"What is the capital of France?"
,
documents
=
[
"The capital of France is Paris"
,
"Reranking is fun!"
,
"vLLM is an open-source framework for fast AI serving"
])
print
(
rerank_v1_result
)
# or the v2
co2
=
cohere
.
ClientV2
(
"sk-fake-key"
,
base_url
=
"http://localhost:8000"
)
v2_rerank_result
=
co2
.
rerank
(
model
=
"BAAI/bge-reranker-base"
,
query
=
"What is the capital of France?"
,
documents
=
[
"The capital of France is Paris"
,
"Reranking is fun!"
,
"vLLM is an open-source framework for fast AI serving"
])
print
(
v2_rerank_result
)
examples/online_serving/jinaai_rerank_client.py
0 → 100644
View file @
0034b09c
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
Jina and Cohere https://jina.ai/reranker
run: vllm serve BAAI/bge-reranker-base
"""
import
json
import
requests
url
=
"http://127.0.0.1:8000/rerank"
headers
=
{
"accept"
:
"application/json"
,
"Content-Type"
:
"application/json"
}
data
=
{
"model"
:
"BAAI/bge-reranker-base"
,
"query"
:
"What is the capital of France?"
,
"documents"
:
[
"The capital of Brazil is Brasilia."
,
"The capital of France is Paris."
,
"Horses and cows are both animals"
]
}
response
=
requests
.
post
(
url
,
headers
=
headers
,
json
=
data
)
# Check the response
if
response
.
status_code
==
200
:
print
(
"Request successful!"
)
print
(
json
.
dumps
(
response
.
json
(),
indent
=
2
))
else
:
print
(
f
"Request failed with status code:
{
response
.
status_code
}
"
)
print
(
response
.
text
)
tests/entrypoints/openai/test_rerank.py
0 → 100644
View file @
0034b09c
import
pytest
import
requests
from
vllm.entrypoints.openai.protocol
import
RerankResponse
from
...utils
import
RemoteOpenAIServer
MODEL_NAME
=
"BAAI/bge-reranker-base"
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
args
=
[
"--enforce-eager"
,
"--max-model-len"
,
"100"
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_rerank_texts
(
server
:
RemoteOpenAIServer
,
model_name
:
str
):
query
=
"What is the capital of France?"
documents
=
[
"The capital of Brazil is Brasilia."
,
"The capital of France is Paris."
]
rerank_response
=
requests
.
post
(
server
.
url_for
(
"rerank"
),
json
=
{
"model"
:
model_name
,
"query"
:
query
,
"documents"
:
documents
,
})
rerank_response
.
raise_for_status
()
rerank
=
RerankResponse
.
model_validate
(
rerank_response
.
json
())
assert
rerank
.
id
is
not
None
assert
rerank
.
results
is
not
None
assert
len
(
rerank
.
results
)
==
2
assert
rerank
.
results
[
0
].
relevance_score
>=
0.9
assert
rerank
.
results
[
1
].
relevance_score
<=
0.01
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_top_n
(
server
:
RemoteOpenAIServer
,
model_name
:
str
):
query
=
"What is the capital of France?"
documents
=
[
"The capital of Brazil is Brasilia."
,
"The capital of France is Paris."
,
"Cross-encoder models are neat"
]
rerank_response
=
requests
.
post
(
server
.
url_for
(
"rerank"
),
json
=
{
"model"
:
model_name
,
"query"
:
query
,
"documents"
:
documents
,
"top_n"
:
2
})
rerank_response
.
raise_for_status
()
rerank
=
RerankResponse
.
model_validate
(
rerank_response
.
json
())
assert
rerank
.
id
is
not
None
assert
rerank
.
results
is
not
None
assert
len
(
rerank
.
results
)
==
2
assert
rerank
.
results
[
0
].
relevance_score
>=
0.9
assert
rerank
.
results
[
1
].
relevance_score
<=
0.01
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
def
test_rerank_max_model_len
(
server
:
RemoteOpenAIServer
,
model_name
:
str
):
query
=
"What is the capital of France?"
*
100
documents
=
[
"The capital of Brazil is Brasilia."
,
"The capital of France is Paris."
]
rerank_response
=
requests
.
post
(
server
.
url_for
(
"rerank"
),
json
=
{
"model"
:
model_name
,
"query"
:
query
,
"documents"
:
documents
})
assert
rerank_response
.
status_code
==
400
# Assert just a small fragments of the response
assert
"Please reduce the length of the input."
in
\
rerank_response
.
text
\ No newline at end of file
tests/entrypoints/openai/test_score.py
View file @
0034b09c
...
...
@@ -10,12 +10,7 @@ MODEL_NAME = "BAAI/bge-reranker-v2-m3"
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
args
=
[
"--enforce-eager"
,
# Will be used on tests to compare prompt input length
"--max-model-len"
,
"100"
]
args
=
[
"--enforce-eager"
,
"--max-model-len"
,
"100"
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
...
...
vllm/entrypoints/openai/api_server.py
View file @
0034b09c
...
...
@@ -56,6 +56,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
PoolingChatRequest
,
PoolingCompletionRequest
,
PoolingRequest
,
PoolingResponse
,
RerankRequest
,
RerankResponse
,
ScoreRequest
,
ScoreResponse
,
TokenizeRequest
,
TokenizeResponse
,
...
...
@@ -68,6 +69,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
(
BaseModelPath
,
OpenAIServingModels
)
from
vllm.entrypoints.openai.serving_pooling
import
OpenAIServingPooling
from
vllm.entrypoints.openai.serving_rerank
import
JinaAIServingRerank
from
vllm.entrypoints.openai.serving_score
import
OpenAIServingScores
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
...
...
@@ -306,6 +308,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
return
request
.
app
.
state
.
openai_serving_scores
def
rerank
(
request
:
Request
)
->
Optional
[
JinaAIServingRerank
]:
return
request
.
app
.
state
.
jinaai_serving_reranking
def
tokenization
(
request
:
Request
)
->
OpenAIServingTokenization
:
return
request
.
app
.
state
.
openai_serving_tokenization
...
...
@@ -502,6 +508,40 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return
await
create_score
(
request
,
raw_request
)
@
router
.
post
(
"/rerank"
)
@
with_cancellation
async
def
do_rerank
(
request
:
RerankRequest
,
raw_request
:
Request
):
handler
=
rerank
(
raw_request
)
if
handler
is
None
:
return
base
(
raw_request
).
create_error_response
(
message
=
"The model does not support Rerank (Score) API"
)
generator
=
await
handler
.
do_rerank
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
elif
isinstance
(
generator
,
RerankResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
@
router
.
post
(
"/v1/rerank"
)
@
with_cancellation
async
def
do_rerank_v1
(
request
:
RerankRequest
,
raw_request
:
Request
):
logger
.
warning
(
"To indicate that the rerank API is not part of the standard OpenAI"
" API, we have located it at `/rerank`. Please update your client"
"accordingly. (Note: Conforms to JinaAI rerank API)"
)
return
await
do_rerank
(
request
,
raw_request
)
@
router
.
post
(
"/v2/rerank"
)
@
with_cancellation
async
def
do_rerank_v2
(
request
:
RerankRequest
,
raw_request
:
Request
):
return
await
do_rerank
(
request
,
raw_request
)
TASK_HANDLERS
:
Dict
[
str
,
Dict
[
str
,
tuple
]]
=
{
"generate"
:
{
"messages"
:
(
ChatCompletionRequest
,
create_chat_completion
),
...
...
@@ -512,7 +552,10 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
"default"
:
(
EmbeddingCompletionRequest
,
create_embedding
),
},
"score"
:
{
"default"
:
(
ScoreRequest
,
create_score
),
"default"
:
(
RerankRequest
,
do_rerank
)
},
"rerank"
:
{
"default"
:
(
RerankRequest
,
do_rerank
)
},
"reward"
:
{
"messages"
:
(
PoolingChatRequest
,
create_pooling
),
...
...
@@ -759,6 +802,12 @@ async def init_app_state(
state
.
openai_serving_models
,
request_logger
=
request_logger
)
if
model_config
.
task
==
"score"
else
None
state
.
jinaai_serving_reranking
=
JinaAIServingRerank
(
engine_client
,
model_config
,
state
.
openai_serving_models
,
request_logger
=
request_logger
)
if
model_config
.
task
==
"score"
else
None
state
.
openai_serving_tokenization
=
OpenAIServingTokenization
(
engine_client
,
model_config
,
...
...
vllm/entrypoints/openai/protocol.py
View file @
0034b09c
...
...
@@ -1018,6 +1018,52 @@ class ScoreRequest(OpenAIBaseModel):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
class
RerankRequest
(
OpenAIBaseModel
):
model
:
str
query
:
str
documents
:
List
[
str
]
top_n
:
int
=
Field
(
default_factory
=
lambda
:
0
)
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
# doc: begin-rerank-pooling-params
additional_data
:
Optional
[
Any
]
=
None
# doc: end-rerank-pooling-params
# doc: begin-rerank-extra-params
priority
:
int
=
Field
(
default
=
0
,
description
=
(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
))
# doc: end-rerank-extra-params
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
class
RerankDocument
(
BaseModel
):
text
:
str
class
RerankResult
(
BaseModel
):
index
:
int
document
:
RerankDocument
relevance_score
:
float
class
RerankUsage
(
BaseModel
):
total_tokens
:
int
class
RerankResponse
(
OpenAIBaseModel
):
id
:
str
model
:
str
usage
:
RerankUsage
results
:
List
[
RerankResult
]
class
CompletionLogProbs
(
OpenAIBaseModel
):
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
0034b09c
...
...
@@ -26,7 +26,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeRequest
,
EmbeddingChatRequest
,
EmbeddingCompletionRequest
,
ErrorResponse
,
ScoreRequest
,
ErrorResponse
,
RerankRequest
,
ScoreRequest
,
TokenizeChatRequest
,
TokenizeCompletionRequest
)
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
...
...
@@ -204,9 +205,9 @@ class OpenAIServing:
token_num
=
len
(
input_ids
)
# Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
if
isinstance
(
r
equest
,
(
EmbeddingChatRequest
,
EmbeddingCompletionRequest
,
Score
Request
)):
if
isinstance
(
request
,
(
EmbeddingChatRequest
,
EmbeddingCompletionR
equest
,
ScoreRequest
,
Rerank
Request
)):
operation
=
"score"
if
isinstance
(
request
,
ScoreRequest
)
\
else
"embedding generation"
...
...
vllm/entrypoints/openai/serving_rerank.py
0 → 100644
View file @
0034b09c
import
asyncio
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
List
,
Optional
,
Union
,
cast
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
RerankDocument
,
RerankRequest
,
RerankResponse
,
RerankResult
,
RerankUsage
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.utils
import
make_async
,
merge_async_iterators
logger
=
init_logger
(
__name__
)
class
JinaAIServingRerank
(
OpenAIServing
):
def
__init__
(
self
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
Optional
[
RequestLogger
],
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
models
=
models
,
request_logger
=
request_logger
)
async
def
do_rerank
(
self
,
request
:
RerankRequest
,
raw_request
:
Optional
[
Request
]
=
None
)
->
Union
[
RerankResponse
,
ErrorResponse
]:
"""
Rerank API based on JinaAI's rerank API; implements the same
API interface. Designed for compatibility with off-the-shelf
tooling, since this is a common standard for reranking APIs
See example client implementations at
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
numerous clients use this standard.
"""
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
model_name
=
request
.
model
request_id
=
f
"rerank-
{
self
.
_base_request_id
(
raw_request
)
}
"
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
query
=
request
.
query
documents
=
request
.
documents
request_prompts
=
[]
engine_prompts
=
[]
top_n
=
request
.
top_n
if
request
.
top_n
>
0
else
len
(
documents
)
try
:
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
if
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for scoring models"
)
if
isinstance
(
tokenizer
,
MistralTokenizer
):
raise
ValueError
(
"MistralTokenizer not supported for cross-encoding"
)
if
not
self
.
model_config
.
is_cross_encoder
:
raise
ValueError
(
"Model is not cross encoder."
)
if
truncate_prompt_tokens
is
not
None
and
\
truncate_prompt_tokens
>
self
.
max_model_len
:
raise
ValueError
(
f
"truncate_prompt_tokens value (
{
truncate_prompt_tokens
}
) "
f
"is greater than max_model_len (
{
self
.
max_model_len
}
)."
f
" Please, select a smaller truncation size."
)
for
doc
in
documents
:
request_prompt
=
f
"
{
query
}{
tokenizer
.
sep_token
}{
doc
}
"
tokenization_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
truncate_prompt_tokens
is
not
None
:
tokenization_kwargs
[
"truncation"
]
=
True
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
tokenize_async
=
make_async
(
tokenizer
.
__call__
,
executor
=
self
.
_tokenizer_executor
)
prompt_inputs
=
await
tokenize_async
(
text
=
query
,
text_pair
=
doc
,
**
tokenization_kwargs
)
input_ids
=
prompt_inputs
[
"input_ids"
]
text_token_prompt
=
\
self
.
_validate_input
(
request
,
input_ids
,
request_prompt
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
text_token_prompt
[
"prompt_token_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
request_prompts
.
append
(
request_prompt
)
engine_prompts
.
append
(
engine_prompt
)
except
ValueError
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
str
(
e
))
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
pooling_params
=
request
.
to_pooling_params
()
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
request_prompts
[
i
],
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
))
generator
=
self
.
engine_client
.
encode
(
engine_prompt
,
pooling_params
,
request_id_item
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
generators
.
append
(
generator
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
result_generator
=
merge_async_iterators
(
*
generators
)
num_prompts
=
len
(
engine_prompts
)
# Non-streaming response
final_res_batch
:
List
[
Optional
[
PoolingRequestOutput
]]
final_res_batch
=
[
None
]
*
num_prompts
try
:
async
for
i
,
res
in
result_generator
:
final_res_batch
[
i
]
=
res
assert
all
(
final_res
is
not
None
for
final_res
in
final_res_batch
)
final_res_batch_checked
=
cast
(
List
[
PoolingRequestOutput
],
final_res_batch
)
response
=
self
.
request_output_to_rerank_response
(
final_res_batch_checked
,
request_id
,
model_name
,
documents
,
top_n
)
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
return
response
def
request_output_to_rerank_response
(
self
,
final_res_batch
:
List
[
PoolingRequestOutput
],
request_id
:
str
,
model_name
:
str
,
documents
:
List
[
str
],
top_n
:
int
)
->
RerankResponse
:
"""
Convert the output of do_rank to a RerankResponse
"""
results
:
List
[
RerankResult
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
classify_res
=
ScoringRequestOutput
.
from_base
(
final_res
)
result
=
RerankResult
(
index
=
idx
,
document
=
RerankDocument
(
text
=
documents
[
idx
]),
relevance_score
=
classify_res
.
outputs
.
score
,
)
results
.
append
(
result
)
prompt_token_ids
=
final_res
.
prompt_token_ids
num_prompt_tokens
+=
len
(
prompt_token_ids
)
# sort by relevance, then return the top n if set
results
.
sort
(
key
=
lambda
x
:
x
.
relevance_score
,
reverse
=
True
)
if
top_n
<
len
(
documents
):
results
=
results
[:
top_n
]
return
RerankResponse
(
id
=
request_id
,
model
=
model_name
,
results
=
results
,
usage
=
RerankUsage
(
total_tokens
=
num_prompt_tokens
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment