Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d9d21eb8
Unverified
Commit
d9d21eb8
authored
Mar 31, 2026
by
wang.yuqi
Committed by
GitHub
Mar 31, 2026
Browse files
[Frontend][3/n] Improve pooling entrypoints | scoring. (#28631)
Signed-off-by:
wang.yuqi
<
yuqi.wang@daocloud.io
>
parent
f09daea2
Changes
37
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
854 additions
and
1154 deletions
+854
-1154
vllm/entrypoints/openai/engine/serving.py
vllm/entrypoints/openai/engine/serving.py
+4
-104
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+27
-27
vllm/entrypoints/pooling/__init__.py
vllm/entrypoints/pooling/__init__.py
+6
-21
vllm/entrypoints/pooling/base/io_processor.py
vllm/entrypoints/pooling/base/io_processor.py
+21
-25
vllm/entrypoints/pooling/base/serving.py
vllm/entrypoints/pooling/base/serving.py
+18
-4
vllm/entrypoints/pooling/io_processor_factories.py
vllm/entrypoints/pooling/io_processor_factories.py
+7
-0
vllm/entrypoints/pooling/score/serving.py
vllm/entrypoints/pooling/score/serving.py
+0
-667
vllm/entrypoints/pooling/scoring/__init__.py
vllm/entrypoints/pooling/scoring/__init__.py
+0
-0
vllm/entrypoints/pooling/scoring/api_router.py
vllm/entrypoints/pooling/scoring/api_router.py
+5
-29
vllm/entrypoints/pooling/scoring/io_processor.py
vllm/entrypoints/pooling/scoring/io_processor.py
+419
-0
vllm/entrypoints/pooling/scoring/protocol.py
vllm/entrypoints/pooling/scoring/protocol.py
+17
-14
vllm/entrypoints/pooling/scoring/serving.py
vllm/entrypoints/pooling/scoring/serving.py
+160
-0
vllm/entrypoints/pooling/scoring/typing.py
vllm/entrypoints/pooling/scoring/typing.py
+46
-0
vllm/entrypoints/pooling/scoring/utils.py
vllm/entrypoints/pooling/scoring/utils.py
+71
-246
vllm/entrypoints/pooling/typing.py
vllm/entrypoints/pooling/typing.py
+29
-12
vllm/entrypoints/pooling/utils.py
vllm/entrypoints/pooling/utils.py
+19
-0
vllm/entrypoints/sagemaker/api_router.py
vllm/entrypoints/sagemaker/api_router.py
+5
-5
No files found.
vllm/entrypoints/openai/engine/serving.py
View file @
d9d21eb8
...
@@ -11,9 +11,7 @@ from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
...
@@ -11,9 +11,7 @@ from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
import
numpy
as
np
import
numpy
as
np
from
fastapi
import
Request
from
fastapi
import
Request
from
openai.types.responses
import
(
from
openai.types.responses
import
ToolChoiceFunction
ToolChoiceFunction
,
)
from
pydantic
import
ConfigDict
,
TypeAdapter
,
ValidationError
from
pydantic
import
ConfigDict
,
TypeAdapter
,
ValidationError
from
starlette.datastructures
import
Headers
from
starlette.datastructures
import
Headers
...
@@ -21,9 +19,7 @@ import vllm.envs as envs
...
@@ -21,9 +19,7 @@ import vllm.envs as envs
from
vllm.beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
ChatTemplateContentFormatOption
,
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
BatchChatCompletionRequest
,
BatchChatCompletionRequest
,
...
@@ -42,9 +38,7 @@ from vllm.entrypoints.openai.engine.protocol import (
...
@@ -42,9 +38,7 @@ from vllm.entrypoints.openai.engine.protocol import (
GenerationError
,
GenerationError
,
)
)
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.openai.responses.protocol
import
(
from
vllm.entrypoints.openai.responses.protocol
import
ResponsesRequest
ResponsesRequest
,
)
from
vllm.entrypoints.openai.speech_to_text.protocol
import
(
from
vllm.entrypoints.openai.speech_to_text.protocol
import
(
TranscriptionRequest
,
TranscriptionRequest
,
TranscriptionResponse
,
TranscriptionResponse
,
...
@@ -56,14 +50,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
...
@@ -56,14 +50,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingCompletionRequest
,
PoolingCompletionRequest
,
PoolingResponse
,
PoolingResponse
,
)
)
from
vllm.entrypoints.pooling.score.protocol
import
(
RerankRequest
,
ScoreDataRequest
,
ScoreQueriesDocumentsRequest
,
ScoreRequest
,
ScoreResponse
,
ScoreTextRequest
,
)
from
vllm.entrypoints.serve.disagg.protocol
import
GenerateRequest
,
GenerateResponse
from
vllm.entrypoints.serve.disagg.protocol
import
GenerateRequest
,
GenerateResponse
from
vllm.entrypoints.serve.tokenize.protocol
import
(
from
vllm.entrypoints.serve.tokenize.protocol
import
(
DetokenizeRequest
,
DetokenizeRequest
,
...
@@ -72,8 +58,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
...
@@ -72,8 +58,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeResponse
,
TokenizeResponse
,
)
)
from
vllm.entrypoints.utils
import
create_error_response
from
vllm.entrypoints.utils
import
create_error_response
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs
import
EngineInput
,
PromptType
from
vllm.inputs
import
EngineInput
,
PromptType
,
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
,
PromptLogprobs
from
vllm.logprobs
import
Logprob
,
PromptLogprobs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -119,8 +104,6 @@ CompletionLikeRequest: TypeAlias = (
...
@@ -119,8 +104,6 @@ CompletionLikeRequest: TypeAlias = (
CompletionRequest
CompletionRequest
|
TokenizeCompletionRequest
|
TokenizeCompletionRequest
|
DetokenizeRequest
|
DetokenizeRequest
|
RerankRequest
|
ScoreRequest
|
PoolingCompletionRequest
|
PoolingCompletionRequest
)
)
...
@@ -148,7 +131,6 @@ AnyResponse: TypeAlias = (
...
@@ -148,7 +131,6 @@ AnyResponse: TypeAlias = (
|
TranscriptionResponse
|
TranscriptionResponse
|
TokenizeResponse
|
TokenizeResponse
|
PoolingResponse
|
PoolingResponse
|
ScoreResponse
|
GenerateResponse
|
GenerateResponse
)
)
...
@@ -692,88 +674,6 @@ class OpenAIServing:
...
@@ -692,88 +674,6 @@ class OpenAIServing:
message_types
.
add
(
content_dict
[
"type"
].
split
(
"_"
)[
0
])
message_types
.
add
(
content_dict
[
"type"
].
split
(
"_"
)[
0
])
return
message_types
return
message_types
def
_validate_input
(
self
,
request
:
object
,
input_ids
:
list
[
int
],
input_text
:
str
,
)
->
TokensPrompt
:
token_num
=
len
(
input_ids
)
max_model_len
=
self
.
model_config
.
max_model_len
# Note: ScoreRequest doesn't have max_tokens
if
isinstance
(
request
,
(
ScoreDataRequest
,
ScoreTextRequest
,
ScoreQueriesDocumentsRequest
,
RerankRequest
,
),
):
# Note: input length can be up to the entire model context length
# since these requests don't generate tokens.
if
token_num
>
max_model_len
:
operations
:
dict
[
type
[
AnyRequest
],
str
]
=
{
ScoreDataRequest
:
"score"
,
ScoreTextRequest
:
"score"
,
ScoreQueriesDocumentsRequest
:
"score"
,
}
operation
=
operations
.
get
(
type
(
request
),
"embedding generation"
)
raise
VLLMValidationError
(
f
"This model's maximum context length is "
f
"
{
max_model_len
}
tokens. However, you requested "
f
"
{
token_num
}
tokens in the input for
{
operation
}
. "
f
"Please reduce the length of the input prompt."
,
parameter
=
"input_tokens"
,
value
=
token_num
,
)
return
TokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
if
isinstance
(
request
,
(
TokenizeCompletionRequest
,
TokenizeChatRequest
,
DetokenizeRequest
),
):
return
TokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
# chat completion endpoint supports max_completion_tokens
if
isinstance
(
request
,
ChatCompletionRequest
):
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
max_tokens
=
request
.
max_completion_tokens
or
request
.
max_tokens
else
:
max_tokens
=
getattr
(
request
,
"max_tokens"
,
None
)
# Note: input length can be up to model context length - 1 for
# completion-like requests.
if
token_num
>=
max_model_len
:
raise
VLLMValidationError
(
f
"This model's maximum context length is "
f
"
{
max_model_len
}
tokens. However, your request has "
f
"
{
token_num
}
input tokens. Please reduce the length of "
"the input messages."
,
parameter
=
"input_tokens"
,
value
=
token_num
,
)
if
max_tokens
is
not
None
and
token_num
+
max_tokens
>
max_model_len
:
raise
VLLMValidationError
(
f
"This model's maximum context length is "
f
"
{
max_model_len
}
tokens. However, you requested "
f
"
{
max_tokens
}
output tokens and your prompt contains "
f
"
{
token_num
}
input tokens, for a total of "
f
"
{
token_num
+
max_tokens
}
tokens "
f
"(
{
token_num
}
+
{
max_tokens
}
= "
f
"
{
token_num
+
max_tokens
}
>
{
max_model_len
}
). "
f
"Please reduce the length of the input prompt or the "
f
"number of requested output tokens."
,
parameter
=
"max_tokens"
,
value
=
max_tokens
,
)
return
TokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
def
_validate_chat_template
(
def
_validate_chat_template
(
self
,
self
,
request_chat_template
:
str
|
None
,
request_chat_template
:
str
|
None
,
...
...
vllm/entrypoints/openai/run_batch.py
View file @
d9d21eb8
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
asyncio
import
contextlib
import
json
import
sys
import
sys
import
tempfile
import
tempfile
from
argparse
import
Namespace
from
argparse
import
Namespace
...
@@ -13,12 +15,14 @@ from urllib.parse import urlparse
...
@@ -13,12 +15,14 @@ from urllib.parse import urlparse
import
aiohttp
import
aiohttp
import
pybase64
as
base64
import
pybase64
as
base64
import
pydantic
import
torch
import
torch
from
fastapi
import
UploadFile
from
fastapi
import
UploadFile
from
prometheus_client
import
start_http_server
from
prometheus_client
import
start_http_server
from
pydantic
import
Field
,
TypeAdapter
,
field_validator
,
model_validator
from
pydantic
import
Field
,
TypeAdapter
,
field_validator
,
model_validator
from
pydantic_core.core_schema
import
ValidationInfo
from
pydantic_core.core_schema
import
ValidationInfo
from
starlette.datastructures
import
State
from
starlette.datastructures
import
State
from
starlette.responses
import
JSONResponse
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
urllib3.util
import
parse_url
from
urllib3.util
import
parse_url
...
@@ -49,7 +53,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
...
@@ -49,7 +53,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingRequest
,
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponse
,
)
)
from
vllm.entrypoints.pooling.scor
e
.protocol
import
(
from
vllm.entrypoints.pooling.scor
ing
.protocol
import
(
RerankRequest
,
RerankRequest
,
RerankResponse
,
RerankResponse
,
ScoreRequest
,
ScoreRequest
,
...
@@ -180,6 +184,18 @@ class BatchRequestInput(OpenAIBaseModel):
...
@@ -180,6 +184,18 @@ class BatchRequestInput(OpenAIBaseModel):
return
TypeAdapter
(
BatchRequestInputBody
).
validate_python
(
value
)
return
TypeAdapter
(
BatchRequestInputBody
).
validate_python
(
value
)
AllResponse
:
TypeAlias
=
(
ChatCompletionResponse
|
EmbeddingResponse
|
ScoreResponse
|
RerankResponse
|
TranscriptionResponse
|
TranscriptionResponseVerbose
|
TranslationResponse
|
TranslationResponseVerbose
)
class
BatchResponseData
(
OpenAIBaseModel
):
class
BatchResponseData
(
OpenAIBaseModel
):
# HTTP status code of the response.
# HTTP status code of the response.
status_code
:
int
=
200
status_code
:
int
=
200
...
@@ -188,17 +204,7 @@ class BatchResponseData(OpenAIBaseModel):
...
@@ -188,17 +204,7 @@ class BatchResponseData(OpenAIBaseModel):
request_id
:
str
request_id
:
str
# The body of the response.
# The body of the response.
body
:
(
body
:
AllResponse
|
None
=
None
ChatCompletionResponse
|
EmbeddingResponse
|
ScoreResponse
|
RerankResponse
|
TranscriptionResponse
|
TranscriptionResponseVerbose
|
TranslationResponse
|
TranslationResponseVerbose
|
None
)
=
None
class
BatchRequestOutput
(
OpenAIBaseModel
):
class
BatchRequestOutput
(
OpenAIBaseModel
):
...
@@ -536,19 +542,13 @@ async def run_request(
...
@@ -536,19 +542,13 @@ async def run_request(
except
Exception
as
e
:
except
Exception
as
e
:
response
=
create_error_response
(
e
)
response
=
create_error_response
(
e
)
if
isinstance
(
if
isinstance
(
response
,
JSONResponse
):
response
,
with
contextlib
.
suppress
(
pydantic
.
ValidationError
):
(
response
=
TypeAdapter
(
AllResponse
|
ErrorResponse
).
validate_python
(
ChatCompletionResponse
,
json
.
loads
(
response
.
body
)
EmbeddingResponse
,
)
ScoreResponse
,
RerankResponse
,
if
isinstance
(
response
,
AllResponse
):
TranscriptionResponse
,
TranscriptionResponseVerbose
,
TranslationResponse
,
TranslationResponseVerbose
,
),
):
batch_output
=
BatchRequestOutput
(
batch_output
=
BatchRequestOutput
(
id
=
f
"vllm-
{
random_uuid
()
}
"
,
id
=
f
"vllm-
{
random_uuid
()
}
"
,
custom_id
=
request
.
custom_id
,
custom_id
=
request
.
custom_id
,
...
@@ -745,14 +745,14 @@ async def build_endpoint_registry(
...
@@ -745,14 +745,14 @@ async def build_endpoint_registry(
"score"
:
{
"score"
:
{
"url_matcher"
:
lambda
url
:
url
.
endswith
(
"/score"
),
"url_matcher"
:
lambda
url
:
url
.
endswith
(
"/score"
),
"handler_getter"
:
lambda
:
(
"handler_getter"
:
lambda
:
(
serving_scores
.
create_score
if
serving_scores
is
not
None
else
None
serving_scores
if
serving_scores
is
not
None
else
None
),
),
"wrapper_fn"
:
None
,
"wrapper_fn"
:
None
,
},
},
"rerank"
:
{
"rerank"
:
{
"url_matcher"
:
lambda
url
:
url
.
endswith
(
"/rerank"
),
"url_matcher"
:
lambda
url
:
url
.
endswith
(
"/rerank"
),
"handler_getter"
:
lambda
:
(
"handler_getter"
:
lambda
:
(
serving_scores
.
do_rerank
if
serving_scores
is
not
None
else
None
serving_scores
if
serving_scores
is
not
None
else
None
),
),
"wrapper_fn"
:
None
,
"wrapper_fn"
:
None
,
},
},
...
...
vllm/entrypoints/pooling/__init__.py
View file @
d9d21eb8
...
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
...
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
from
fastapi
import
FastAPI
from
fastapi
import
FastAPI
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.pooling.utils
import
enable_scoring_api
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -23,23 +24,6 @@ else:
...
@@ -23,23 +24,6 @@ else:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
enable_scoring_api
(
supported_tasks
:
tuple
[
"SupportedTask"
,
...],
model_config
:
ModelConfig
|
None
=
None
,
)
->
bool
:
if
any
(
t
in
supported_tasks
for
t
in
(
"embed"
,
"token_embed"
)):
return
True
if
model_config
is
not
None
and
"classify"
in
supported_tasks
:
num_labels
=
getattr
(
model_config
.
hf_config
,
"num_labels"
,
0
)
if
num_labels
!=
1
:
logger
.
debug_once
(
"Score API is only enabled for num_labels == 1."
)
return
False
return
True
return
False
def
register_pooling_api_routers
(
def
register_pooling_api_routers
(
app
:
FastAPI
,
app
:
FastAPI
,
supported_tasks
:
tuple
[
"SupportedTask"
,
...],
supported_tasks
:
tuple
[
"SupportedTask"
,
...],
...
@@ -68,7 +52,7 @@ def register_pooling_api_routers(
...
@@ -68,7 +52,7 @@ def register_pooling_api_routers(
app
.
include_router
(
embed_router
)
app
.
include_router
(
embed_router
)
if
enable_scoring_api
(
supported_tasks
,
model_config
):
if
enable_scoring_api
(
supported_tasks
,
model_config
):
from
vllm.entrypoints.pooling.scor
e
.api_router
import
router
as
score_router
from
vllm.entrypoints.pooling.scor
ing
.api_router
import
router
as
score_router
app
.
include_router
(
score_router
)
app
.
include_router
(
score_router
)
...
@@ -84,7 +68,7 @@ def init_pooling_state(
...
@@ -84,7 +68,7 @@ def init_pooling_state(
from
vllm.entrypoints.pooling.classify.serving
import
ServingClassification
from
vllm.entrypoints.pooling.classify.serving
import
ServingClassification
from
vllm.entrypoints.pooling.embed.serving
import
ServingEmbedding
from
vllm.entrypoints.pooling.embed.serving
import
ServingEmbedding
from
vllm.entrypoints.pooling.pooling.serving
import
OpenAIServingPooling
from
vllm.entrypoints.pooling.pooling.serving
import
OpenAIServingPooling
from
vllm.entrypoints.pooling.scor
e
.serving
import
ServingScores
from
vllm.entrypoints.pooling.scor
ing
.serving
import
ServingScores
from
vllm.tasks
import
POOLING_TASKS
from
vllm.tasks
import
POOLING_TASKS
model_config
=
engine_client
.
model_config
model_config
=
engine_client
.
model_config
...
@@ -136,8 +120,9 @@ def init_pooling_state(
...
@@ -136,8 +120,9 @@ def init_pooling_state(
engine_client
,
engine_client
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
score_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
log_error_stack
=
args
.
log_error_stack
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
)
)
if
enable_scoring_api
(
supported_tasks
,
model_config
)
if
enable_scoring_api
(
supported_tasks
,
model_config
)
else
None
else
None
...
...
vllm/entrypoints/pooling/base/io_processor.py
View file @
d9d21eb8
...
@@ -13,13 +13,16 @@ from vllm.entrypoints.chat_utils import (
...
@@ -13,13 +13,16 @@ from vllm.entrypoints.chat_utils import (
ConversationMessage
,
ConversationMessage
,
)
)
from
vllm.entrypoints.openai.engine.serving
import
RendererChatRequest
,
RendererRequest
from
vllm.entrypoints.openai.engine.serving
import
RendererChatRequest
,
RendererRequest
from
vllm.entrypoints.pooling.scoring.typing
import
ScoringData
from
vllm.entrypoints.pooling.typing
import
(
from
vllm.entrypoints.pooling.typing
import
(
OfflineInputsContext
,
OfflineOutputsContext
,
PoolingChatLikeRequest
,
PoolingChatLikeRequest
,
PoolingCompletionLikeRequest
,
PoolingCompletionLikeRequest
,
PoolingServeContext
,
PoolingServeContext
,
)
)
from
vllm.inputs
import
EngineInput
,
SingletonPrompt
from
vllm.inputs
import
EngineInput
,
SingletonPrompt
from
vllm.renderers
import
BaseRenderer
,
merge_kwargs
from
vllm.renderers
import
BaseRenderer
,
TokenizeParams
,
merge_kwargs
from
vllm.renderers.inputs.preprocess
import
parse_model_prompt
,
prompt_to_seq
from
vllm.renderers.inputs.preprocess
import
parse_model_prompt
,
prompt_to_seq
from
vllm.tool_parsers
import
ToolParser
from
vllm.tool_parsers
import
ToolParser
from
vllm.utils.mistral
import
is_mistral_tokenizer
from
vllm.utils.mistral
import
is_mistral_tokenizer
...
@@ -96,29 +99,29 @@ class PoolingIOProcessor:
...
@@ -96,29 +99,29 @@ class PoolingIOProcessor:
#######################################
#######################################
# offline APIs
# offline APIs
def
pre_process_offline
(
def
pre_process_offline
(
self
,
ctx
:
OfflineInputsContext
)
->
Sequence
[
EngineInput
]:
self
,
assert
not
isinstance
(
ctx
.
prompts
,
ScoringData
)
prompts
:
PromptType
|
Sequence
[
PromptType
],
tok_params
=
self
.
renderer
.
default_cmpl_tok_params
.
with_kwargs
(
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
**
(
ctx
.
tokenization_kwargs
or
{})
)
->
Sequence
[
EngineInput
]:
)
return
self
.
_preprocess_completion_offline
(
return
self
.
_preprocess_completion_offline
(
prompts
=
prompts
,
tok
enization_kwargs
=
tokenization_kwarg
s
prompts
=
ctx
.
prompts
,
tok
_params
=
tok_param
s
)
)
async
def
pre_process_offline_async
(
self
,
*
args
,
**
kwargs
):
async
def
pre_process_offline_async
(
self
,
ctx
:
OfflineInputsContext
):
return
self
.
pre_process_offline
(
*
args
,
**
kwargs
)
return
self
.
pre_process_offline
(
ctx
)
def
post_process_offline
(
def
post_process_offline
(
self
,
self
,
outputs
:
list
[
PoolingRequestOutput
]
,
ctx
:
OfflineOutputsContext
,
)
->
list
[
PoolingRequestOutput
]:
)
->
list
[
PoolingRequestOutput
]:
return
outputs
return
ctx
.
outputs
async
def
post_process_offline_async
(
async
def
post_process_offline_async
(
self
,
self
,
outputs
:
list
[
PoolingRequestOutput
]
,
ctx
:
OfflineOutputsContext
,
)
->
list
[
PoolingRequestOutput
]:
)
->
list
[
PoolingRequestOutput
]:
return
self
.
post_process_offline
(
outputs
)
return
self
.
post_process_offline
(
ctx
)
#######################################
#######################################
# helpers
# helpers
...
@@ -204,28 +207,21 @@ class PoolingIOProcessor:
...
@@ -204,28 +207,21 @@ class PoolingIOProcessor:
def
_preprocess_completion_offline
(
def
_preprocess_completion_offline
(
self
,
self
,
prompts
:
PromptType
|
Sequence
[
PromptType
],
prompts
:
PromptType
|
Sequence
[
PromptType
],
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tok_params
:
TokenizeParams
,
prompt_extras
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
Sequence
[
EngineInput
]:
)
->
Sequence
[
EngineInput
]:
renderer
=
self
.
renderer
model_config
=
self
.
model_config
prompts
=
prompt_to_seq
(
prompts
)
prompts
=
prompt_to_seq
(
prompts
)
parsed_prompts
=
[
parsed_prompts
=
[
(
(
prompt
prompt
if
isinstance
(
prompt
,
bytes
)
if
isinstance
(
prompt
,
bytes
)
else
parse_model_prompt
(
model_config
,
prompt
)
else
parse_model_prompt
(
self
.
model_config
,
prompt
)
)
)
for
prompt
in
prompts
for
prompt
in
prompts
]
]
tok_params
=
renderer
.
default_cmpl_tok_params
.
with_kwargs
(
**
(
tokenization_kwargs
or
{})
)
return
renderer
.
render_cmpl
(
return
self
.
renderer
.
render_cmpl
(
parsed_prompts
,
parsed_prompts
,
tok_params
,
prompt_extras
=
prompt_extras
tok_params
,
)
)
def
_validate_chat_template
(
def
_validate_chat_template
(
...
...
vllm/entrypoints/pooling/base/serving.py
View file @
d9d21eb8
...
@@ -117,8 +117,16 @@ class PoolingServing:
...
@@ -117,8 +117,16 @@ class PoolingServing:
else
await
self
.
_get_trace_headers
(
ctx
.
raw_request
.
headers
)
else
await
self
.
_get_trace_headers
(
ctx
.
raw_request
.
headers
)
)
)
pooling_params
=
self
.
io_processor
.
create_pooling_params
(
ctx
.
request
)
if
ctx
.
pooling_params
is
None
:
pooling_params
.
verify
(
self
.
model_config
)
pooling_params
=
self
.
io_processor
.
create_pooling_params
(
ctx
.
request
)
else
:
pooling_params
=
ctx
.
pooling_params
if
isinstance
(
pooling_params
,
list
):
for
params
in
pooling_params
:
params
.
verify
(
self
.
model_config
)
else
:
pooling_params
.
verify
(
self
.
model_config
)
for
i
,
engine_input
in
enumerate
(
ctx
.
engine_inputs
):
for
i
,
engine_input
in
enumerate
(
ctx
.
engine_inputs
):
prompt_request_id
=
(
prompt_request_id
=
(
...
@@ -127,16 +135,22 @@ class PoolingServing:
...
@@ -127,16 +135,22 @@ class PoolingServing:
else
ctx
.
prompt_request_ids
[
i
]
else
ctx
.
prompt_request_ids
[
i
]
)
)
params
=
(
pooling_params
[
i
]
if
isinstance
(
pooling_params
,
list
)
else
pooling_params
)
self
.
_log_inputs
(
self
.
_log_inputs
(
prompt_request_id
,
prompt_request_id
,
engine_input
,
engine_input
,
params
=
pooling_
params
,
params
=
params
,
lora_request
=
ctx
.
lora_request
,
lora_request
=
ctx
.
lora_request
,
)
)
generator
=
self
.
engine_client
.
encode
(
generator
=
self
.
engine_client
.
encode
(
engine_input
,
engine_input
,
pooling_
params
,
params
,
prompt_request_id
,
prompt_request_id
,
lora_request
=
ctx
.
lora_request
,
lora_request
=
ctx
.
lora_request
,
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
...
...
vllm/entrypoints/pooling/io_processor_factories.py
View file @
d9d21eb8
...
@@ -5,6 +5,8 @@
...
@@ -5,6 +5,8 @@
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
ChatTemplateConfig
from
vllm.entrypoints.chat_utils
import
ChatTemplateConfig
from
vllm.entrypoints.pooling.base.io_processor
import
PoolingIOProcessor
from
vllm.entrypoints.pooling.base.io_processor
import
PoolingIOProcessor
from
vllm.entrypoints.pooling.scoring.io_processor
import
ScoringIOProcessors
from
vllm.entrypoints.pooling.utils
import
enable_scoring_api
from
vllm.renderers
import
BaseRenderer
from
vllm.renderers
import
BaseRenderer
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
...
@@ -25,6 +27,11 @@ def init_pooling_io_processors(
...
@@ -25,6 +27,11 @@ def init_pooling_io_processors(
processors
.
append
((
"embed"
,
EmbedIOProcessor
))
processors
.
append
((
"embed"
,
EmbedIOProcessor
))
if
enable_scoring_api
(
supported_tasks
,
model_config
):
score_type
=
model_config
.
score_type
if
score_type
is
not
None
and
score_type
in
ScoringIOProcessors
:
processors
.
append
((
score_type
,
ScoringIOProcessors
[
score_type
]))
return
{
return
{
task
:
processor_cls
(
task
:
processor_cls
(
model_config
=
model_config
,
model_config
=
model_config
,
...
...
vllm/entrypoints/pooling/score/serving.py
deleted
100644 → 0
View file @
f09daea2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
time
from
collections.abc
import
AsyncGenerator
,
Mapping
from
concurrent.futures
import
ThreadPoolExecutor
from
typing
import
Any
from
fastapi
import
Request
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorResponse
,
UsageInfo
,
)
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.score.protocol
import
(
RerankDocument
,
RerankRequest
,
RerankResponse
,
RerankResult
,
RerankUsage
,
ScoreRequest
,
ScoreResponse
,
ScoreResponseData
,
)
from
vllm.entrypoints.pooling.score.utils
import
(
ScoreData
,
ScoreInputs
,
_cosine_similarity
,
compress_token_type_ids
,
get_score_prompt
,
parse_score_data_single
,
validate_score_input
,
)
from
vllm.inputs
import
EngineInput
,
TokensPrompt
,
tokens_input
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.tokenizers
import
TokenizerLike
from
vllm.utils.async_utils
import
make_async
,
merge_async_iterators
from
vllm.utils.mistral
import
is_mistral_tokenizer
from
vllm.v1.pool.late_interaction
import
(
build_late_interaction_doc_params
,
build_late_interaction_query_params
,
)
logger
=
init_logger
(
__name__
)
class
ServingScores
(
OpenAIServing
):
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
score_template
:
str
|
None
=
None
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
)
self
.
score_template
=
score_template
self
.
_tokenizer_executor
=
ThreadPoolExecutor
(
max_workers
=
1
)
self
.
score_type
=
self
.
model_config
.
score_type
self
.
architecture
=
self
.
model_config
.
architecture
self
.
is_multimodal_model
=
self
.
model_config
.
is_multimodal_model
if
self
.
score_type
==
"cross-encoder"
:
self
.
_score_func
=
self
.
_cross_encoding_score
elif
self
.
score_type
==
"late-interaction"
:
self
.
_score_func
=
self
.
_late_interaction_score
else
:
# "bi-encoder"
self
.
_score_func
=
self
.
_embedding_score
async
def
_embedding_score
(
self
,
data_1
:
list
[
ScoreData
],
data_2
:
list
[
ScoreData
],
request
:
RerankRequest
|
ScoreRequest
,
request_id
:
str
,
lora_request
:
LoRARequest
|
None
|
None
=
None
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
=
None
,
)
->
list
[
PoolingRequestOutput
]
|
ErrorResponse
:
input_texts
:
list
[
str
]
=
[]
for
text
in
data_1
+
data_2
:
if
not
isinstance
(
text
,
str
):
raise
NotImplementedError
(
"Embedding scores currently do not support multimodal input."
)
input_texts
.
append
(
text
)
model_config
=
self
.
model_config
tokenizer
=
self
.
renderer
.
get_tokenizer
()
encode_async
=
make_async
(
tokenizer
.
encode
,
executor
=
self
.
_tokenizer_executor
,
)
tokenization_kwargs
=
request
.
build_tok_params
(
model_config
).
get_encode_kwargs
()
tokenized_prompts
=
await
asyncio
.
gather
(
*
(
encode_async
(
t
,
**
tokenization_kwargs
)
for
t
in
input_texts
)
)
engine_inputs
:
list
[
EngineInput
]
=
[]
for
tok_result
,
input_text
in
zip
(
tokenized_prompts
,
input_texts
):
text_token_prompt
=
self
.
_validate_input
(
request
,
tok_result
,
input_text
)
engine_inputs
.
append
(
tokens_input
(
text_token_prompt
[
"prompt_token_ids"
],
prompt
=
input_text
,
)
)
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
pooling_params
=
request
.
to_pooling_params
(
"embed"
)
for
i
,
engine_input
in
enumerate
(
engine_inputs
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
engine_input
,
params
=
pooling_params
,
lora_request
=
lora_request
,
)
generators
.
append
(
self
.
engine_client
.
encode
(
engine_input
,
pooling_params
,
request_id_item
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
)
result_generator
=
merge_async_iterators
(
*
generators
)
# Non-streaming response
final_res_batch
:
list
[
PoolingRequestOutput
]
=
[]
embeddings
:
list
[
PoolingRequestOutput
|
None
]
=
[
None
]
*
len
(
engine_inputs
)
async
for
i
,
res
in
result_generator
:
embeddings
[
i
]
=
res
emb_data_1
:
list
[
PoolingRequestOutput
]
=
[]
emb_data_2
:
list
[
PoolingRequestOutput
]
=
[]
for
i
in
range
(
0
,
len
(
data_1
)):
assert
(
emb
:
=
embeddings
[
i
])
is
not
None
emb_data_1
.
append
(
emb
)
for
i
in
range
(
len
(
data_1
),
len
(
embeddings
)):
assert
(
emb
:
=
embeddings
[
i
])
is
not
None
emb_data_2
.
append
(
emb
)
if
len
(
emb_data_1
)
==
1
:
emb_data_1
=
emb_data_1
*
len
(
emb_data_2
)
final_res_batch
=
_cosine_similarity
(
tokenizer
=
tokenizer
,
embed_1
=
emb_data_1
,
embed_2
=
emb_data_2
)
return
final_res_batch
def
_preprocess_late_interaction_item
(
self
,
data
:
ScoreData
,
role
:
str
,
request
:
RerankRequest
|
ScoreRequest
,
tokenizer
:
TokenizerLike
,
tokenization_kwargs
:
dict
[
str
,
Any
],
)
->
TokensPrompt
:
"""Parse a single ScoreData into a text + optional multimodal
TokensPrompt for late-interaction encoding.
For plain strings, tokenises directly.
For multimodal content parts, extracts text and multi_modal_data.
"""
model_config
=
self
.
model_config
if
isinstance
(
data
,
str
):
text
,
mm_data
,
mm_uuids
=
data
,
None
,
None
else
:
text
,
mm_data
,
mm_uuids
=
parse_score_data_single
(
data
,
role
,
model_config
)
prompt_ids
=
tokenizer
.
encode
(
text
,
**
tokenization_kwargs
)
self
.
_validate_input
(
request
,
prompt_ids
,
text
)
tok_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_ids
,
prompt
=
text
,
)
if
mm_data
is
not
None
:
tok_prompt
[
"multi_modal_data"
]
=
mm_data
if
mm_uuids
is
not
None
:
tok_prompt
[
"multi_modal_uuids"
]
=
mm_uuids
if
request
.
mm_processor_kwargs
is
not
None
:
tok_prompt
[
"mm_processor_kwargs"
]
=
request
.
mm_processor_kwargs
return
tok_prompt
async def _late_interaction_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """
    Late interaction scoring (ColBERT MaxSim).

    Runs in two stages against the engine: queries are encoded first and
    registered under per-query keys, then documents are encoded with a
    reference to the matching query key. The outputs of the document
    requests are returned as the pair scores, with the prompt token ids of
    query and document concatenated (separated by the pad token if any).
    """

    async def collect(generators, expected):
        # Drain the merged stream into positional slots; every slot must
        # have been filled once the stream is exhausted.
        slots: list[PoolingRequestOutput | None] = [None] * expected
        if generators:
            async for position, item in merge_async_iterators(*generators):
                slots[position] = item
        assert all(item is not None for item in slots)
        return [item for item in slots if item is not None]

    tokenizer = self.renderer.get_tokenizer()
    tokenization_kwargs = request.build_tok_params(
        self.model_config
    ).get_encode_kwargs()

    # Tokenise queries and documents together on the tokenizer executor.
    tagged_items = zip(
        data_1 + data_2,
        ["query"] * len(data_1) + ["document"] * len(data_2),
    )
    preprocess_async = make_async(
        self._preprocess_late_interaction_item,
        executor=self._tokenizer_executor,
    )
    tok_prompts = await asyncio.gather(
        *(
            preprocess_async(
                data=item,
                role=item_role,
                request=request,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )
            for item, item_role in tagged_items
        )
    )

    split_at = len(data_1)
    query_prompts = tok_prompts[:split_at]
    doc_prompts = tok_prompts[split_at:]
    single_query = len(query_prompts) == 1

    default_pooling_params = request.to_pooling_params("token_embed")

    # stage 1: encode queries and cache token embeddings on workers.
    query_keys = [f"{request_id}-query-{i}" for i in range(len(query_prompts))]
    # A lone query is reused by every document; otherwise each is used once.
    uses_per_query = len(doc_prompts) if single_query else 1

    query_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    for query_key, query_prompt in zip(query_keys, query_prompts):
        query_params = default_pooling_params.clone()
        query_params.late_interaction_params = build_late_interaction_query_params(
            query_key=query_key,
            query_uses=uses_per_query,
        )
        # The per-query request id is the same string as its cache key.
        self._log_inputs(
            query_key,
            query_prompt,
            params=query_params,
            lora_request=lora_request,
        )
        query_generators.append(
            self.engine_client.encode(
                query_prompt,
                query_params,
                query_key,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )
    query_results = await collect(query_generators, len(query_prompts))

    # stage 2: encode docs and return scalar scores from workers.
    doc_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    for doc_index, doc_prompt in enumerate(doc_prompts):
        doc_request_id = f"{request_id}-doc-{doc_index}"
        paired_query = 0 if single_query else doc_index
        doc_params = default_pooling_params.clone()
        doc_params.late_interaction_params = build_late_interaction_doc_params(
            query_key=query_keys[paired_query]
        )
        self._log_inputs(
            doc_request_id,
            doc_prompt,
            params=doc_params,
            lora_request=lora_request,
        )
        doc_generators.append(
            self.engine_client.encode(
                doc_prompt,
                doc_params,
                doc_request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )
    doc_results = await collect(doc_generators, len(doc_prompts))

    pad_token_id = tokenizer.pad_token_id
    padding: list[int] = [] if pad_token_id is None else [pad_token_id]

    # Broadcast a single query result across all documents (1:N scoring).
    if len(query_results) == 1:
        query_results = query_results * len(doc_results)

    scores: list[PoolingRequestOutput] = []
    for query_result, doc_result in zip(query_results, doc_results):
        joined_tokens = (
            query_result.prompt_token_ids + padding + doc_result.prompt_token_ids
        )
        cached_total = query_result.num_cached_tokens + doc_result.num_cached_tokens
        scores.append(
            PoolingRequestOutput(
                request_id=f"{query_result.request_id}_{doc_result.request_id}",
                outputs=doc_result.outputs,
                prompt_token_ids=joined_tokens,
                num_cached_tokens=cached_total,
                finished=True,
            )
        )
    return scores
async def _cross_encoding_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """Score (data_1, data_2) pairs with a cross-encoder ("classify" task).

    A single ``data_1`` item is broadcast across all of ``data_2``. Each pair
    is turned into one prompt by ``self._preprocess_score`` on the tokenizer
    executor, submitted to the engine as its own request
    (``"{request_id}-{i}"``), and the results are returned in pair order.

    Raises:
        ValueError: if the renderer's tokenizer is a Mistral tokenizer.
    """
    tokenizer = self.renderer.get_tokenizer()
    if is_mistral_tokenizer(tokenizer):
        raise ValueError("MistralTokenizer not supported for cross-encoding")

    model_config = self.model_config

    # 1:N scoring: repeat the single query so it pairs with every document.
    if len(data_1) == 1:
        data_1 = data_1 * len(data_2)

    tok_kwargs = request.build_tok_params(model_config).get_encode_kwargs()

    preprocess_async = make_async(
        self._preprocess_score,
        executor=self._tokenizer_executor,
    )
    preprocessed_prompts = await asyncio.gather(
        *(
            preprocess_async(
                request=request,
                tokenizer=tokenizer,
                tokenization_kwargs=tok_kwargs,
                data_1=t1,
                data_2=t2,
            )
            for t1, t2 in zip(data_1, data_2)
        )
    )

    # Schedule the request and get the result generator.
    generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    default_pooling_params = request.to_pooling_params("classify")

    for i, (full_prompt, tok_prompt) in enumerate(preprocessed_prompts):
        request_id_item = f"{request_id}-{i}"

        self._log_inputs(
            request_id_item,
            full_prompt,
            params=default_pooling_params,
            lora_request=lora_request,
        )

        # token_type_ids are removed from the prompt and forwarded in
        # compressed form through the pooling params instead; only prompts
        # that carry them get their own cloned params.
        if token_type_ids := tok_prompt.pop("token_type_ids", None):
            pooling_params = default_pooling_params.clone()
            compressed = compress_token_type_ids(token_type_ids)
            pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed}
        else:
            pooling_params = default_pooling_params

        generator = self.engine_client.encode(
            tok_prompt,
            pooling_params,
            request_id_item,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=request.priority,
        )
        generators.append(generator)

    result_generator = merge_async_iterators(*generators)

    # Non-streaming response
    final_res_batch: list[PoolingRequestOutput | None] = [None] * len(
        preprocessed_prompts
    )
    async for i, res in result_generator:
        final_res_batch[i] = res

    return [out for out in final_res_batch if out is not None]
def _preprocess_score(
    self,
    request: RerankRequest | ScoreRequest,
    tokenizer: TokenizerLike,
    tokenization_kwargs: dict[str, Any],
    data_1: ScoreData,
    data_2: ScoreData,
) -> tuple[str, TokensPrompt]:
    """Build and validate the cross-encoder prompt for one (data_1, data_2) pair.

    Returns the full prompt text together with the engine prompt produced by
    ``get_score_prompt``; the request's ``mm_processor_kwargs`` are attached
    to the engine prompt when set.
    """
    prompt_text, prompt = get_score_prompt(
        model_config=self.model_config,
        data_1=data_1,
        data_2=data_2,
        tokenizer=tokenizer,
        tokenization_kwargs=tokenization_kwargs,
        score_template=self.score_template,
    )
    self._validate_input(request, prompt["prompt_token_ids"], prompt_text)

    mm_kwargs = request.mm_processor_kwargs
    if mm_kwargs is not None:
        prompt["mm_processor_kwargs"] = mm_kwargs

    return prompt_text, prompt
async def _run_scoring(
    self,
    data_1: ScoreInputs,
    data_2: ScoreInputs,
    request: ScoreRequest | RerankRequest,
    request_id: str,
    raw_request: Request | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """Validate the two input sides and dispatch them to ``self._score_func``.

    LoRA adapters are resolved from the request, and trace headers are read
    from ``raw_request`` when one is supplied.
    """
    adapters = self._maybe_get_adapters(request)

    if raw_request is None:
        headers = None
    else:
        headers = await self._get_trace_headers(raw_request.headers)

    left, right = validate_score_input(
        data_1,
        data_2,
        is_multimodal_model=self.is_multimodal_model,
        architecture=self.architecture,
    )

    return await self._score_func(
        data_1=left,
        data_2=right,
        request=request,
        request_id=request_id,
        lora_request=adapters,
        trace_headers=headers,
    )
async def create_score(
    self,
    request: ScoreRequest,
    raw_request: Request | None = None,
) -> ScoreResponse | ErrorResponse:
    """
    Score API similar to Sentence Transformers cross encoder.

    See https://sbert.net/docs/package_reference/cross_encoder

    Returns an ErrorResponse when the model check fails, when scoring
    itself yields one, or when the task is cancelled.
    """
    if (model_error := await self._check_model(request)) is not None:
        return model_error

    request_id = f"score-{self._base_request_id(raw_request)}"
    created_time = int(time.time())

    try:
        outcome = await self._run_scoring(
            request.data_1,
            request.data_2,
            request,
            request_id,
            raw_request,
        )
        if not isinstance(outcome, ErrorResponse):
            return self.request_output_to_score_response(
                outcome,
                request_id,
                created_time,
                self.models.model_name(),
            )
        return outcome
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")
async def do_rerank(
    self, request: RerankRequest, raw_request: Request | None = None
) -> RerankResponse | ErrorResponse:
    """
    Rerank API following the interface of JinaAI's rerank API.

    The shape is kept compatible with off-the-shelf tooling, as numerous
    clients use this standard. Example client implementations:
    https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py

    A non-positive ``request.top_n`` means "return every result".
    """
    if (model_error := await self._check_model(request)) is not None:
        return model_error

    request_id = f"rerank-{self._base_request_id(raw_request)}"
    documents = request.documents

    try:
        outcome = await self._run_scoring(
            request.query,
            documents,
            request,
            request_id,
            raw_request,
        )
        if isinstance(outcome, ErrorResponse):
            return outcome

        if request.top_n > 0:
            limit = request.top_n
        else:
            limit = len(outcome)

        return self.request_output_to_rerank_response(
            outcome,
            request_id,
            self.models.model_name(),
            documents,
            limit,
        )
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")
def request_output_to_score_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
) -> ScoreResponse:
    """Convert pooled outputs into a ScoreResponse.

    Each output becomes one ScoreResponseData indexed by its position in
    ``final_res_batch``; usage counts the prompt tokens of all outputs.
    """
    data = [
        ScoreResponseData(
            index=position,
            score=ScoringRequestOutput.from_base(output).outputs.score,
        )
        for position, output in enumerate(final_res_batch)
    ]
    token_total = sum(len(output.prompt_token_ids) for output in final_res_batch)

    return ScoreResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=UsageInfo(prompt_tokens=token_total, total_tokens=token_total),
    )
def request_output_to_rerank_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    model_name: str,
    documents: ScoreInputs,
    top_n: int,
) -> RerankResponse:
    """
    Convert the output of do_rerank to a RerankResponse.

    Results are ordered by descending relevance score and truncated to
    ``top_n`` when that is smaller than the number of documents.
    """
    doc_list = documents if isinstance(documents, list) else [documents]

    ranked: list[RerankResult] = []
    token_total = 0
    for position, output in enumerate(final_res_batch):
        scored = ScoringRequestOutput.from_base(output)
        source = doc_list[position]
        # Plain strings are echoed back as text; anything else is treated as
        # a mapping whose "content" entry holds the multimodal parts.
        echoed = (
            RerankDocument(text=source)
            if isinstance(source, str)
            else RerankDocument(multi_modal=source.get("content", []))
        )
        ranked.append(
            RerankResult(
                index=position,
                document=echoed,
                relevance_score=scored.outputs.score,
            )
        )
        token_total += len(output.prompt_token_ids)

    # sort by relevance, then return the top n if set
    ranked = sorted(ranked, key=lambda entry: entry.relevance_score, reverse=True)
    if top_n < len(doc_list):
        ranked = ranked[:top_n]

    return RerankResponse(
        id=request_id,
        model=model_name,
        results=ranked,
        usage=RerankUsage(total_tokens=token_total, prompt_tokens=token_total),
    )
vllm/entrypoints/pooling/scor
e
/__init__.py
→
vllm/entrypoints/pooling/scor
ing
/__init__.py
View file @
d9d21eb8
File moved
vllm/entrypoints/pooling/scor
e
/api_router.py
→
vllm/entrypoints/pooling/scor
ing
/api_router.py
View file @
d9d21eb8
...
@@ -3,21 +3,15 @@
...
@@ -3,21 +3,15 @@
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
fastapi
import
APIRouter
,
Depends
,
Request
from
fastapi
import
APIRouter
,
Depends
,
Request
from
fastapi.responses
import
JSONResponse
from
typing_extensions
import
assert_never
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.pooling.score.protocol
import
(
RerankRequest
,
RerankResponse
,
ScoreRequest
,
ScoreResponse
,
)
from
vllm.entrypoints.pooling.score.serving
import
ServingScores
from
vllm.entrypoints.utils
import
load_aware_call
,
with_cancellation
from
vllm.entrypoints.utils
import
load_aware_call
,
with_cancellation
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.protocol
import
RerankRequest
,
ScoreRequest
from
.serving
import
ServingScores
router
=
APIRouter
()
router
=
APIRouter
()
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -46,16 +40,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
...
@@ -46,16 +40,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
if
handler
is
None
:
if
handler
is
None
:
raise
NotImplementedError
(
"The model does not support Score API"
)
raise
NotImplementedError
(
"The model does not support Score API"
)
generator
=
await
handler
.
create_score
(
request
,
raw_request
)
return
await
handler
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
ScoreResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
@
router
.
post
(
@
router
.
post
(
...
@@ -92,16 +77,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
...
@@ -92,16 +77,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
if
handler
is
None
:
if
handler
is
None
:
raise
NotImplementedError
(
"The model does not support Rerank (Score) API"
)
raise
NotImplementedError
(
"The model does not support Rerank (Score) API"
)
generator
=
await
handler
.
do_rerank
(
request
,
raw_request
)
return
await
handler
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
RerankResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
@
router
.
post
(
@
router
.
post
(
...
...
vllm/entrypoints/pooling/scoring/io_processor.py
0 → 100644
View file @
d9d21eb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
collections.abc
import
Sequence
from
typing
import
Any
,
TypeAlias
,
cast
import
torch.nn.functional
as
F
from
vllm
import
PoolingParams
,
PoolingRequestOutput
,
TokensPrompt
from
vllm.entrypoints.pooling.base.io_processor
import
PoolingIOProcessor
from
vllm.entrypoints.pooling.typing
import
(
OfflineInputsContext
,
OfflineOutputsContext
,
PoolingServeContext
,
)
from
vllm.inputs
import
EngineInput
from
vllm.renderers
import
TokenizeParams
from
vllm.renderers.hf
import
safe_apply_chat_template
from
vllm.tasks
import
PoolingTask
,
ScoreType
from
vllm.utils.mistral
import
is_mistral_tokenizer
from
...chat_utils
import
ChatTemplateResolutionError
from
.protocol
import
RerankRequest
,
ScoreRequest
,
ScoringRequest
from
.typing
import
ScoreData
,
ScoreInput
,
ScoringData
from
.utils
import
(
compress_token_type_ids
,
compute_maxsim_score
,
parse_score_data
,
score_data_to_prompts
,
validate_score_input
,
)
# Serving context specialised to scoring requests; used as the `ctx` type of
# the online pre/post-processing hooks in this module.
ScoringServeContext: TypeAlias = PoolingServeContext[ScoringRequest]
class ScoringIOProcessor(PoolingIOProcessor):
    """Common base of the scoring IO processors.

    Caches the tokenizer, its pad token id and the model-config facts needed
    to validate score inputs. Subclasses set ``name`` and ``pooling_task``.
    """

    name: ScoreType
    pooling_task: PoolingTask

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        tokenizer = self.renderer.get_tokenizer()
        self.tokenizer = tokenizer

        config = self.model_config
        self.architecture = config.architecture
        self.is_multimodal_model = config.is_multimodal_model

        self.pad_token_id = tokenizer.pad_token_id

    def create_pooling_params(self, request):
        """Build pooling params from the request for this processor's task."""
        task = self.pooling_task
        return request.to_pooling_params(task)

    def valid_inputs(
        self,
        data_1: ScoreInput | list[ScoreInput],
        data_2: ScoreInput | list[ScoreInput],
    ) -> ScoringData:
        """Validate both input sides and return them as a ScoringData."""
        return validate_score_input(
            data_1,
            data_2,
            is_multimodal_model=self.is_multimodal_model,
            architecture=self.architecture,
        )
class BiEncoderIOProcessor(ScoringIOProcessor):
    """Scoring via independent embeddings of each side ("embed" task).

    Queries and documents are submitted as one flat list of prompts (queries
    first); the number of queries is remembered so the engine outputs can be
    split again and scored pairwise with cosine similarity.
    """

    name: ScoreType = "bi-encoder"
    pooling_task: PoolingTask = "embed"

    #######################################
    # online APIs

    def pre_process_online(self, ctx: ScoringServeContext):
        """Fill ``ctx.engine_inputs`` and store the query count in
        ``ctx.intermediates``.

        Raises:
            ValueError: if ``ctx.request`` is neither a score nor a rerank
                request.
        """
        request = ctx.request

        if isinstance(request, ScoreRequest):
            data_1 = request.data_1
            data_2 = request.data_2
        elif isinstance(request, RerankRequest):
            data_1 = request.query
            data_2 = request.documents
        else:
            raise ValueError(f"Invalid {self.name} request type")

        scoring_data = self.valid_inputs(data_1, data_2)
        tok_params = request.build_tok_params(self.model_config)
        engine_inputs = self._pre_process(
            scoring_data,
            tok_params,
            # Forward only the extras that the request actually sets.
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

        ctx.engine_inputs = engine_inputs
        # Number of query prompts; used as the split offset in post-processing.
        ctx.intermediates = len(scoring_data.data_1)

    def post_process_online(
        self,
        ctx: ScoringServeContext,
    ):
        """Replace ``ctx.final_res_batch`` with the pairwise scores.

        Raises:
            ValueError: if the result batch or the stored query count is
                missing.
        """
        if ctx.final_res_batch is None:
            raise ValueError("Final response batch not available")

        if ctx.intermediates is None:
            raise ValueError("data_1 len not available")

        ctx.final_res_batch = self._post_process(
            outputs=ctx.final_res_batch, offset=cast(int, ctx.intermediates)
        )

    #######################################
    # offline APIs

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        """Build engine inputs from ``ctx.prompts`` (a ScoringData) using the
        renderer's default completion tokenization params, overridden by
        ``ctx.tokenization_kwargs`` when given."""
        assert isinstance(ctx.prompts, ScoringData)

        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(ctx.tokenization_kwargs or {})
        )
        return self._pre_process(ctx.prompts, tok_params)

    def post_process_offline(
        self,
        ctx: OfflineOutputsContext,
    ) -> list[PoolingRequestOutput]:
        """Score ``ctx.outputs``, split at ``ctx.offset``."""
        assert ctx.offset is not None
        return self._post_process(outputs=ctx.outputs, offset=ctx.offset)

    #######################################
    # helpers

    def _pre_process(
        self,
        scoring_data: ScoringData,
        tok_params: TokenizeParams,
        prompt_extras: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """Convert both sides to prompts and preprocess them as one list,
        query prompts first and document prompts after."""
        data_1 = score_data_to_prompts(scoring_data.data_1, "query", self.model_config)
        data_2 = score_data_to_prompts(
            scoring_data.data_2, "document", self.model_config
        )

        return self._preprocess_completion_offline(
            prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
        )

    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
        """Pair ``outputs[:offset]`` with ``outputs[offset:]`` and score each
        pair by cosine similarity of the outputs' data.

        A single first-side output is repeated to match the second side. Each
        result carries the concatenated prompt token ids of the pair,
        separated by the pad token when the tokenizer defines one.
        """
        emb_data_1 = outputs[:offset]
        emb_data_2 = outputs[offset:]

        # 1:N scoring: reuse the single query embedding for every document.
        if len(emb_data_1) == 1:
            emb_data_1 = emb_data_1 * len(emb_data_2)

        # The separator does not depend on the pair, so build it once.
        padding: list[int] = []
        if self.pad_token_id is not None:
            padding = [self.pad_token_id]

        final_res_batch: list[PoolingRequestOutput] = []
        for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
            pair_score = F.cosine_similarity(
                emb_1.outputs.data.float(), emb_2.outputs.data.float(), dim=0
            )

            tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

            final_res_batch.append(
                PoolingRequestOutput(
                    request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                    outputs=pair_score,
                    prompt_token_ids=tokens,
                    num_cached_tokens=emb_1.num_cached_tokens
                    + emb_2.num_cached_tokens,
                    finished=True,
                )
            )

        return final_res_batch
class LateInteractionIOProcessor(BiEncoderIOProcessor):
    """Bi-encoder variant that scores pairs with ``compute_maxsim_score`` over
    the "token_embed" outputs instead of cosine similarity."""

    name: ScoreType = "late-interaction"
    pooling_task: PoolingTask = "token_embed"

    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
        """Split ``outputs`` at ``offset`` into query / document outputs and
        return one MaxSim-scored PoolingRequestOutput per pair."""
        query_outputs = outputs[:offset]
        doc_outputs = outputs[offset:]

        # 1:N scoring: a lone query is matched against every document.
        if len(query_outputs) == 1:
            query_outputs = query_outputs * len(doc_outputs)

        pad_token_id = self.pad_token_id
        separator: list[int] = [] if pad_token_id is None else [pad_token_id]

        # Compute MaxSim scores
        # query.outputs.data: [query_len, dim]
        # doc.outputs.data: [doc_len, dim]
        return [
            PoolingRequestOutput(
                outputs=compute_maxsim_score(query.outputs.data, doc.outputs.data),
                prompt_token_ids=(
                    query.prompt_token_ids + separator + doc.prompt_token_ids
                ),
                request_id=f"{query.request_id}_{doc.request_id}",
                num_cached_tokens=query.num_cached_tokens + doc.num_cached_tokens,
                finished=True,
            )
            for query, doc in zip(query_outputs, doc_outputs)
        ]
class CrossEncoderIOProcessor(ScoringIOProcessor):
    """Scoring by encoding each (query, document) pair as a single prompt
    ("classify" task)."""

    name: ScoreType = "cross-encoder"
    pooling_task: PoolingTask = "classify"

    def __init__(self, *args, **kwargs):
        """
        Raises:
            ValueError: if the tokenizer is a Mistral tokenizer.
        """
        super().__init__(*args, **kwargs)

        if is_mistral_tokenizer(self.tokenizer):
            raise ValueError("MistralTokenizer not supported for cross-encoding")

        from vllm.model_executor.model_loader import get_model_cls
        from vllm.model_executor.models.interfaces import supports_score_template

        model = get_model_cls(self.model_config)
        self.supports_score_template = supports_score_template(model)
        # The model class is only kept when it provides a score template.
        self.model = model if self.supports_score_template else None
        self.use_sep_token = self.model_config.use_sep_token

    #######################################
    # online APIs

    def pre_process_online(self, ctx: ScoringServeContext):
        """Fill ``ctx.engine_inputs`` and the per-pair ``ctx.pooling_params``.

        Raises:
            ValueError: if ``ctx.request`` is neither a score nor a rerank
                request.
        """
        request = ctx.request

        if isinstance(request, ScoreRequest):
            data_1 = request.data_1
            data_2 = request.data_2
        elif isinstance(request, RerankRequest):
            data_1 = request.query
            data_2 = request.documents
        else:
            raise ValueError(f"Invalid {self.name} request type")

        scoring_data = self.valid_inputs(data_1, data_2)
        tok_params = request.build_tok_params(self.model_config)
        pooling_params = self.create_pooling_params(request)
        engine_inputs, pooling_params_list = self._pre_process(
            scoring_data,
            tok_params,
            pooling_params,
            chat_template=self.chat_template,
            # Forward only the extras that the request actually sets.
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

        ctx.engine_inputs = engine_inputs
        ctx.pooling_params = pooling_params_list

    #######################################
    # offline APIs

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        """Build engine inputs from ``ctx.prompts`` and replace
        ``ctx.pooling_params`` with the per-pair list."""
        assert isinstance(ctx.prompts, ScoringData)
        assert not isinstance(ctx.pooling_params, list)

        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(ctx.tokenization_kwargs or {})
        )
        engine_inputs, pooling_params_list = self._pre_process(
            ctx.prompts, tok_params, ctx.pooling_params, ctx.chat_template
        )
        ctx.pooling_params = pooling_params_list
        return engine_inputs

    #######################################
    # helpers

    def _pre_process(
        self,
        scoring_data: ScoringData,
        tok_params: TokenizeParams,
        pooling_params: PoolingParams | None,
        chat_template: str | None = None,
        prompt_extras: dict[str, Any] | None = None,
    ) -> tuple[Sequence[EngineInput], list[PoolingParams]]:
        """Create one engine input and one PoolingParams per (query, doc) pair.

        A single query is broadcast across all documents. When a prompt
        carries ``token_type_ids`` they are removed from the prompt and passed
        in compressed form through a cloned PoolingParams; otherwise the
        shared ``pooling_params`` instance is reused. ``prompt_extras`` (if
        any) are copied onto every prompt before it is handed to the renderer.
        """
        arrival_time = time.time()
        data_1 = scoring_data.data_1
        data_2 = scoring_data.data_2

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task=self.pooling_task)

        # The encode kwargs are the same for every pair; compute them once.
        encode_kwargs = tok_params.get_encode_kwargs()

        pooling_params_list = list[PoolingParams]()
        engine_inputs = list[EngineInput]()
        for q, d in zip(data_1, data_2):
            _, engine_prompt = self.get_score_prompt(
                data_1=q,
                data_2=d,
                encode_kwargs=encode_kwargs,
                chat_template=chat_template,
            )

            # Attach request-level extras so they are not silently dropped.
            if prompt_extras:
                engine_prompt.update(prompt_extras)  # type: ignore[typeddict-item]

            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            tok_params.apply_post_tokenization(self.tokenizer, engine_prompt)
            engine_inputs.append(
                self.renderer.process_for_engine(engine_prompt, arrival_time)
            )

        return engine_inputs, pooling_params_list

    def get_score_prompt(
        self,
        data_1: ScoreData,
        data_2: ScoreData,
        encode_kwargs: dict[str, Any],
        chat_template: str | None = None,
    ):
        """Build the prompt text and TokensPrompt for one (data_1, data_2) pair.

        Without ``chat_template`` the default encoding is used: the model's
        score template if supported, else a tokenizer text/text_pair call when
        ``use_sep_token`` is set, else plain concatenation. With a
        ``chat_template`` the template is applied first and the default
        encoding is the fallback on ChatTemplateResolutionError.

        Returns:
            ``(full_prompt, engine_prompt)``.

        Raises:
            ValueError: if the model's score template comes back as None.
        """
        model_config = self.model_config
        tokenizer = self.tokenizer

        prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data(
            data_1,
            data_2,
            model_config,
        )

        def default_tokenizer_encode():
            if self.supports_score_template:
                assert self.model is not None
                full_prompt = self.model.get_score_template(prompt_1, prompt_2)
                if full_prompt is None:
                    raise ValueError("Get empty score template from model")

                prompt_inputs = tokenizer(full_prompt, **encode_kwargs)
            else:
                if self.use_sep_token:
                    # cross_encoder models defaults to using separating token.
                    prompt_inputs = tokenizer(
                        text=prompt_1, text_pair=prompt_2, **encode_kwargs
                    )
                    full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
                else:
                    # `llm as reranker` defaults to not using separating token.
                    full_prompt = prompt_1 + prompt_2
                    prompt_inputs = tokenizer(text=full_prompt, **encode_kwargs)
            return full_prompt, prompt_inputs

        # FIXME: For now, we only apply a template when one is explicitly provided.
        # We cannot rely on the tokenizer's chat template because many models
        # inherit junk templates from their base LLM, which breaks both the models
        # and the tests that use them.
        if chat_template is None:
            full_prompt, prompt_inputs = default_tokenizer_encode()
        else:
            # FIXME:
            # Try applying a score template from the CLI arg or tokenizer_config.json
            # If that fails because there is no such template,
            # fall back to the default implementation.
            try:
                full_prompt = safe_apply_chat_template(
                    model_config,
                    tokenizer,
                    [
                        {"role": "query", "content": prompt_1},
                        {"role": "document", "content": prompt_2},
                    ],
                    chat_template=chat_template,
                    tools=None,
                    tokenize=False,
                )
                prompt_inputs = tokenizer(full_prompt, **encode_kwargs)
            except ChatTemplateResolutionError:
                full_prompt, prompt_inputs = default_tokenizer_encode()

        engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])

        if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
            engine_prompt["token_type_ids"] = token_type_ids

        if self.model is not None:
            self.model.post_process_tokens(engine_prompt)

        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            engine_prompt["multi_modal_uuids"] = mm_uuids

        return full_prompt, engine_prompt
# Registry mapping each score type to the IO processor class implementing it.
# Keys match the `name` attribute declared on the corresponding class.
ScoringIOProcessors: dict[ScoreType, type[ScoringIOProcessor]] = {
    "bi-encoder": BiEncoderIOProcessor,
    "late-interaction": LateInteractionIOProcessor,
    "cross-encoder": CrossEncoderIOProcessor,
}
vllm/entrypoints/pooling/scor
e
/protocol.py
→
vllm/entrypoints/pooling/scor
ing
/protocol.py
View file @
d9d21eb8
...
@@ -12,15 +12,12 @@ from vllm.entrypoints.pooling.base.protocol import (
...
@@ -12,15 +12,12 @@ from vllm.entrypoints.pooling.base.protocol import (
ClassifyRequestMixin
,
ClassifyRequestMixin
,
PoolingBasicRequestMixin
,
PoolingBasicRequestMixin
,
)
)
from
vllm.entrypoints.pooling.score.utils
import
(
ScoreContentPartParam
,
ScoreInput
,
ScoreInputs
,
)
from
vllm.renderers
import
TokenizeParams
from
vllm.renderers
import
TokenizeParams
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
from
.typing
import
ScoreContentPartParam
,
ScoreInput
class
ScoreRequestMixin
(
PoolingBasicRequestMixin
,
ClassifyRequestMixin
):
class
ScoreRequestMixin
(
PoolingBasicRequestMixin
,
ClassifyRequestMixin
):
def
build_tok_params
(
self
,
model_config
:
ModelConfig
)
->
TokenizeParams
:
def
build_tok_params
(
self
,
model_config
:
ModelConfig
)
->
TokenizeParams
:
...
@@ -43,13 +40,13 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
...
@@ -43,13 +40,13 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
class
ScoreDataRequest
(
ScoreRequestMixin
):
class
ScoreDataRequest
(
ScoreRequestMixin
):
data_1
:
ScoreInput
s
data_1
:
ScoreInput
|
list
[
ScoreInput
]
data_2
:
ScoreInput
s
data_2
:
ScoreInput
|
list
[
ScoreInput
]
class
ScoreQueriesDocumentsRequest
(
ScoreRequestMixin
):
class
ScoreQueriesDocumentsRequest
(
ScoreRequestMixin
):
queries
:
ScoreInput
s
queries
:
ScoreInput
|
list
[
ScoreInput
]
documents
:
ScoreInput
s
documents
:
ScoreInput
|
list
[
ScoreInput
]
@
property
@
property
def
data_1
(
self
):
def
data_1
(
self
):
...
@@ -61,8 +58,8 @@ class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
...
@@ -61,8 +58,8 @@ class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
class
ScoreQueriesItemsRequest
(
ScoreRequestMixin
):
class
ScoreQueriesItemsRequest
(
ScoreRequestMixin
):
queries
:
ScoreInput
s
queries
:
ScoreInput
|
list
[
ScoreInput
]
items
:
ScoreInput
s
items
:
ScoreInput
|
list
[
ScoreInput
]
@
property
@
property
def
data_1
(
self
):
def
data_1
(
self
):
...
@@ -74,8 +71,8 @@ class ScoreQueriesItemsRequest(ScoreRequestMixin):
...
@@ -74,8 +71,8 @@ class ScoreQueriesItemsRequest(ScoreRequestMixin):
class
ScoreTextRequest
(
ScoreRequestMixin
):
class
ScoreTextRequest
(
ScoreRequestMixin
):
text_1
:
ScoreInput
s
text_1
:
ScoreInput
|
list
[
ScoreInput
]
text_2
:
ScoreInput
s
text_2
:
ScoreInput
|
list
[
ScoreInput
]
@
property
@
property
def
data_1
(
self
):
def
data_1
(
self
):
...
@@ -96,7 +93,7 @@ ScoreRequest: TypeAlias = (
...
@@ -96,7 +93,7 @@ ScoreRequest: TypeAlias = (
class
RerankRequest
(
PoolingBasicRequestMixin
,
ClassifyRequestMixin
):
class
RerankRequest
(
PoolingBasicRequestMixin
,
ClassifyRequestMixin
):
query
:
ScoreInput
query
:
ScoreInput
documents
:
ScoreInput
s
documents
:
ScoreInput
|
list
[
ScoreInput
]
top_n
:
int
=
Field
(
default_factory
=
lambda
:
0
)
top_n
:
int
=
Field
(
default_factory
=
lambda
:
0
)
def
build_tok_params
(
self
,
model_config
:
ModelConfig
)
->
TokenizeParams
:
def
build_tok_params
(
self
,
model_config
:
ModelConfig
)
->
TokenizeParams
:
...
@@ -118,6 +115,9 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
...
@@ -118,6 +115,9 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
)
)
ScoringRequest
:
TypeAlias
=
ScoreRequest
|
RerankRequest
class
RerankDocument
(
BaseModel
):
class
RerankDocument
(
BaseModel
):
text
:
str
|
None
=
None
text
:
str
|
None
=
None
multi_modal
:
list
[
ScoreContentPartParam
]
|
None
=
None
multi_modal
:
list
[
ScoreContentPartParam
]
|
None
=
None
...
@@ -154,3 +154,6 @@ class ScoreResponse(OpenAIBaseModel):
...
@@ -154,3 +154,6 @@ class ScoreResponse(OpenAIBaseModel):
model
:
str
model
:
str
data
:
list
[
ScoreResponseData
]
data
:
list
[
ScoreResponseData
]
usage
:
UsageInfo
usage
:
UsageInfo
ScoringResponse
:
TypeAlias
=
RerankResponse
|
ScoreResponse
vllm/entrypoints/pooling/scoring/serving.py
0 → 100644
View file @
d9d21eb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
fastapi.responses
import
JSONResponse
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
ChatTemplateConfig
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.entrypoints.pooling.base.io_processor
import
PoolingIOProcessor
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.renderers
import
BaseRenderer
from
.io_processor
import
ScoringIOProcessors
,
ScoringServeContext
from
.protocol
import
(
RerankDocument
,
RerankRequest
,
RerankResponse
,
RerankResult
,
RerankUsage
,
ScoreRequest
,
ScoreResponse
,
ScoreResponseData
,
)
from
.typing
import
ScoreInput
logger
=
init_logger
(
__name__
)
class ServingScores(PoolingServing):
    """Serving handler that turns pooled outputs into score / rerank responses.

    The request type held by the serve context selects the response shape:
    a ``ScoreRequest`` yields a ``ScoreResponse`` and a ``RerankRequest``
    yields a ``RerankResponse``.
    """

    request_id_prefix = "score"

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
        """Instantiate the IO processor registered for the model's score type.

        The processor class is looked up in ``ScoringIOProcessors`` using
        ``model_config.score_type``.
        """
        score_type = model_config.score_type
        # Fail with a readable message instead of a bare AssertionError when
        # no processor is registered for this score type.
        assert score_type in ScoringIOProcessors, (
            f"No scoring IO processor registered for score_type={score_type!r}"
        )
        processor_cls = ScoringIOProcessors[score_type]
        return processor_cls(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )

    async def _build_response(
        self,
        ctx: ScoringServeContext,
    ) -> JSONResponse:
        """Build the JSON response matching the type of ``ctx.request``.

        Raises:
            NotImplementedError: if the request is neither a score request
                nor a rerank request.
        """
        final_res_batch = ctx.final_res_batch
        request_id = ctx.request_id
        created_time = ctx.created_time
        model_name = self.models.model_name()

        if isinstance(ctx.request, ScoreRequest):
            return self._request_output_to_score_response(
                final_res_batch,
                request_id,
                created_time,
                model_name,
            )
        elif isinstance(ctx.request, RerankRequest):
            return self._request_output_to_rerank_response(
                final_res_batch,
                request_id,
                model_name,
                ctx.request.documents,
                # A non-positive top_n means "return every result".
                ctx.request.top_n if ctx.request.top_n > 0 else len(final_res_batch),
            )
        else:
            # Name the offending type; an empty message is undiagnosable.
            raise NotImplementedError(
                f"Unsupported scoring request type: {type(ctx.request).__name__}"
            )

    def _request_output_to_score_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
    ) -> JSONResponse:
        """Convert pooled outputs into a ``ScoreResponse`` JSON payload.

        Each output becomes one ``ScoreResponseData`` indexed by its position
        in ``final_res_batch``; usage counts the prompt tokens of all outputs.
        """
        items: list[ScoreResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            # View the pooling output as a scoring output to read its score.
            classify_res = ScoringRequestOutput.from_base(final_res)
            items.append(
                ScoreResponseData(
                    index=idx,
                    score=classify_res.outputs.score,
                )
            )
            num_prompt_tokens += len(final_res.prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )
        response = ScoreResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
        return JSONResponse(content=response.model_dump())

    def _request_output_to_rerank_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        model_name: str,
        documents: ScoreInput | list[ScoreInput],
        top_n: int,
    ) -> JSONResponse:
        """Convert pooled outputs into a ``RerankResponse`` JSON payload.

        Results keep the index of the document they were computed for, are
        sorted by descending relevance score, and are truncated to ``top_n``
        entries when ``top_n`` is smaller than the number of documents.
        """
        # Accept a single document as well as a list of documents.
        if not isinstance(documents, list):
            documents = [documents]

        results: list[RerankResult] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            document = documents[idx]
            if isinstance(document, str):
                rerank_document = RerankDocument(text=document)
            else:
                # Multimodal input: echo back its content parts.
                rerank_document = RerankDocument(
                    multi_modal=document.get("content", [])
                )

            results.append(
                RerankResult(
                    index=idx,
                    document=rerank_document,
                    relevance_score=classify_res.outputs.score,
                )
            )
            num_prompt_tokens += len(final_res.prompt_token_ids)

        # sort by relevance, then return the top n if set
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        if top_n < len(documents):
            results = results[:top_n]

        response = RerankResponse(
            id=request_id,
            model=model_name,
            results=results,
            usage=RerankUsage(
                total_tokens=num_prompt_tokens,
                prompt_tokens=num_prompt_tokens,
            ),
        )
        return JSONResponse(content=response.model_dump())
vllm/entrypoints/pooling/scoring/typing.py
0 → 100644
View file @
d9d21eb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
TypeAlias
from
typing_extensions
import
Required
,
TypedDict
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionContentPartImageEmbedsParam
,
ChatCompletionContentPartImageParam
,
ChatCompletionContentPartTextParam
,
ChatCompletionContentPartVideoParam
,
)
# Content-part types accepted for scoring: image, image-embeds, text and
# video parts, reusing the chat-completion content-part definitions.
ScoreContentPartParam: TypeAlias = (
    ChatCompletionContentPartImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartTextParam
    | ChatCompletionContentPartVideoParam
)
class ScoreMultiModalParam(TypedDict, total=False):
    """
    A specialized parameter type for scoring multimodal content.

    Reasons for not reusing `CustomChatCompletionMessageParam` directly:

    1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions
    2. Including chat-specific fields would confuse users about their purpose in scoring
    3. This is a more focused interface that only exposes what's needed for scoring
    """  # noqa: E501

    # `content` is the only key and is mandatory despite `total=False`.
    content: Required[list[ScoreContentPartParam]]
    """The multimodal contents"""
# Raw input data: either plain text or a ScoreMultiModalParam dict that wraps
# its content parts under the "content" key.
ScoreInput = str | ScoreMultiModalParam
# Score data without the "content" key: plain text or the bare list of
# content parts.
ScoreData = str | list[ScoreContentPartParam]
@dataclass
class ScoringData:
    """Pair of ``ScoreData`` lists that together form one scoring input."""

    # First side of the pairs.
    # NOTE(review): presumably the queries — confirm against callers.
    data_1: list[ScoreData]
    # Second side of the pairs.
    # NOTE(review): presumably the documents — confirm against callers.
    data_2: list[ScoreData]
vllm/entrypoints/pooling/scor
e
/utils.py
→
vllm/entrypoints/pooling/scor
ing
/utils.py
View file @
d9d21eb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
typing
import
Any
,
TypeAlias
,
cast
from
typing
import
cast
import
torch
import
torch
from
torch.nn
import
CosineSimilarity
from
typing_extensions
import
Required
,
TypedDict
from
vllm
import
PromptType
,
TextPrompt
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
(
from
vllm.entrypoints.chat_utils
import
(
BaseMultiModalItemTracker
,
BaseMultiModalItemTracker
,
ChatCompletionContentPartImageEmbedsParam
,
ChatCompletionContentPartImageParam
,
ChatCompletionContentPartParam
,
ChatCompletionContentPartParam
,
ChatCompletionContentPartTextParam
,
ChatCompletionContentPartTextParam
,
ChatCompletionContentPartVideoParam
,
ChatTemplateResolutionError
,
ConversationMessage
,
ConversationMessage
,
MultiModalItemTracker
,
MultiModalItemTracker
,
_parse_chat_message_content_parts
,
_parse_chat_message_content_parts
,
)
)
from
vllm.inputs
import
(
from
vllm.inputs
import
MultiModalDataDict
,
MultiModalUUIDDict
MultiModalDataDict
,
MultiModalUUIDDict
,
from
.typing
import
(
PromptType
,
ScoreContentPartParam
,
TextPrompt
,
ScoreData
,
TokensPrompt
,
ScoreInput
,
)
ScoringData
,
from
vllm.model_executor.models.interfaces
import
supports_score_template
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.renderers.hf
import
safe_apply_chat_template
from
vllm.tokenizers
import
TokenizerLike
ScoreContentPartParam
:
TypeAlias
=
(
ChatCompletionContentPartImageParam
|
ChatCompletionContentPartImageEmbedsParam
|
ChatCompletionContentPartTextParam
|
ChatCompletionContentPartVideoParam
)
)
...
@@ -57,72 +42,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
...
@@ -57,72 +42,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
return
token_scores
.
amax
(
dim
=-
1
).
sum
()
return
token_scores
.
amax
(
dim
=-
1
).
sum
()
class ScoreMultiModalParam(TypedDict, total=False):
    """
    A specialized parameter type for scoring multimodal content.

    Reasons for not reusing `CustomChatCompletionMessageParam` directly:

    1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions
    2. Including chat-specific fields would confuse users about their purpose in scoring
    3. This is a more focused interface that only exposes what's needed for scoring
    """  # noqa: E501

    # `content` is the only key and is mandatory despite `total=False`.
    content: Required[list[ScoreContentPartParam]]
    """The multimodal contents"""
# Raw input data: either plain text or a ScoreMultiModalParam dict that wraps
# its content parts under the "content" key.
ScoreInput = str | ScoreMultiModalParam
# One raw input or a list of raw inputs.
ScoreInputs = ScoreInput | list[ScoreInput]
# Score data without the "content" key: plain text or the bare list of
# content parts.
ScoreData = str | list[ScoreContentPartParam]
def _cosine_similarity(
    tokenizer: TokenizerLike,
    embed_1: list[PoolingRequestOutput],
    embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
    """Score positionally paired embedding outputs with cosine similarity.

    Entries of ``embed_1`` and ``embed_2`` are paired with ``zip``. Each pair
    produces one finished ``PoolingRequestOutput`` whose ``outputs`` is the
    similarity taken along dimension 0 of the two ``outputs.data`` values,
    and whose prompt tokens are both prompts joined by the tokenizer's pad
    token when one is defined.
    """
    similarity = CosineSimilarity(0)

    def score_pair(
        first: PoolingRequestOutput, second: PoolingRequestOutput
    ) -> PoolingRequestOutput:
        measure = similarity(first.outputs.data, second.outputs.data)

        # The pad token, if any, separates the two prompts.
        pad_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_id is None else [pad_id]
        joined_tokens = first.prompt_token_ids + separator + second.prompt_token_ids

        return PoolingRequestOutput(
            request_id=f"{first.request_id}_{second.request_id}",
            outputs=measure,
            prompt_token_ids=joined_tokens,
            num_cached_tokens=first.num_cached_tokens + second.num_cached_tokens,
            finished=True,
        )

    return [score_pair(left, right) for left, right in zip(embed_1, embed_2)]
def _validate_score_input_lens(
    data_1: list[ScoreData],
    data_2: list[ScoreData],
):
    """Check that the two input lists have a supported length combination.

    Accepted shapes are 1:1, 1:N and N:N; neither list may be empty.

    Raises:
        ValueError: on a length mismatch or an empty list.
    """
    count_1, count_2 = len(data_1), len(data_2)

    # The mismatch check runs first, so it wins over the empty-list checks.
    if count_1 > 1 and count_1 != count_2:
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if not count_1:
        raise ValueError("At least one text element must be given")
    if not count_2:
        raise ValueError("At least one text_pair element must be given")
def
_validate_mm_score_input
(
def
_validate_mm_score_input
(
data
:
list
[
ScoreInput
],
data
:
list
[
ScoreInput
],
is_multimodal_model
:
bool
,
is_multimodal_model
:
bool
,
...
@@ -140,12 +59,27 @@ def _validate_mm_score_input(
...
@@ -140,12 +59,27 @@ def _validate_mm_score_input(
return
out
return
out
def _validate_score_input_lens(
    data_1: list[ScoreData],
    data_2: list[ScoreData],
):
    """Check that the two input lists have a supported length combination.

    Accepted shapes are 1:1, 1:N and N:N; neither list may be empty.

    Raises:
        ValueError: on a length mismatch or an empty list.
    """
    count_1, count_2 = len(data_1), len(data_2)

    # The mismatch check runs first, so it wins over the empty-list checks.
    if count_1 > 1 and count_1 != count_2:
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if not count_1:
        raise ValueError("At least one text element must be given")
    if not count_2:
        raise ValueError("At least one text_pair element must be given")
def
validate_score_input
(
def
validate_score_input
(
data_1
:
ScoreInput
s
,
data_1
:
ScoreInput
|
list
[
ScoreInput
]
,
data_2
:
ScoreInput
s
,
data_2
:
ScoreInput
|
list
[
ScoreInput
]
,
is_multimodal_model
:
bool
,
is_multimodal_model
:
bool
,
architecture
:
str
,
architecture
:
str
,
)
->
tuple
[
list
[
ScoreData
],
list
[
Score
Data
]]
:
)
->
Scoring
Data
:
if
not
isinstance
(
data_1
,
list
):
if
not
isinstance
(
data_1
,
list
):
data_1
=
[
data_1
]
data_1
=
[
data_1
]
...
@@ -155,62 +89,7 @@ def validate_score_input(
...
@@ -155,62 +89,7 @@ def validate_score_input(
score_input_1
=
_validate_mm_score_input
(
data_1
,
is_multimodal_model
,
architecture
)
score_input_1
=
_validate_mm_score_input
(
data_1
,
is_multimodal_model
,
architecture
)
score_input_2
=
_validate_mm_score_input
(
data_2
,
is_multimodal_model
,
architecture
)
score_input_2
=
_validate_mm_score_input
(
data_2
,
is_multimodal_model
,
architecture
)
_validate_score_input_lens
(
score_input_1
,
score_input_2
)
_validate_score_input_lens
(
score_input_1
,
score_input_2
)
return
score_input_1
,
score_input_2
return
ScoringData
(
data_1
=
score_input_1
,
data_2
=
score_input_2
)
def _ensure_str(content: list[ConversationMessage]) -> str:
    """Return the string content of a single parsed conversation message.

    Raises:
        ValueError: if the message content is not a string.
    """
    assert len(content) == 1
    text = content[0]["content"]

    # isinstance already rejects None, so one check covers both cases.
    if not isinstance(text, str):
        raise ValueError(f"Only string content is supported, but got {content}.")
    return cast(str, text)
def parse_score_data(
    data_1: ScoreData,
    data_2: ScoreData,
    model_config: ModelConfig,
) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
    """Parse a query-document pair into two text prompts plus shared
    multi-modal data.

    One :class:`MultiModalItemTracker` is shared by both inputs, so their
    multi-modal items end up merged in a single ``mm_data`` dict. That is
    what cross-encoder scoring needs, since query and document are
    concatenated into one model prompt.
    """
    tracker = MultiModalItemTracker(model_config)

    # Parse both sides before extracting text, keeping the original order of
    # tracker updates and of any raised errors.
    parsed = [
        _parse_score_content(role, item, tracker)
        for role, item in (("query", data_1), ("document", data_2))
    ]
    query_text, document_text = [_ensure_str(messages) for messages in parsed]

    items, uuids = tracker.resolve_items()
    return query_text, document_text, items, uuids
def parse_score_data_single(
    data: ScoreData,
    role: str,
    model_config: ModelConfig,
) -> tuple[str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
    """Parse **one** ScoreData into a text prompt with its own multi-modal
    data.

    In contrast to :func:`parse_score_data`, every call builds a fresh
    :class:`MultiModalItemTracker`, so multi-modal items of different inputs
    stay separate. Late-interaction scoring relies on this because query and
    document are encoded independently.
    """
    tracker = MultiModalItemTracker(model_config)
    text = _ensure_str(_parse_score_content(role, data, tracker))
    items, uuids = tracker.resolve_items()
    return text, items, uuids
def
score_data_to_prompts
(
def
score_data_to_prompts
(
...
@@ -243,6 +122,15 @@ def score_data_to_prompts(
...
@@ -243,6 +122,15 @@ def score_data_to_prompts(
return
prompts
return
prompts
def _ensure_str(content: list[ConversationMessage]) -> str:
    """Return the string content of a single parsed conversation message.

    Raises:
        ValueError: if the message content is not a string.
    """
    assert len(content) == 1
    text = content[0]["content"]

    # isinstance already rejects None, so one check covers both cases.
    if not isinstance(text, str):
        raise ValueError(f"Only string content is supported, but got {content}.")
    return cast(str, text)
def
_parse_score_content
(
def
_parse_score_content
(
role
:
str
,
role
:
str
,
data
:
ScoreData
,
data
:
ScoreData
,
...
@@ -278,113 +166,50 @@ def _parse_score_content(
...
@@ -278,113 +166,50 @@ def _parse_score_content(
return
next
(
iter
(
mm_placeholder_storage
.
values
()))[
0
]
return
next
(
iter
(
mm_placeholder_storage
.
values
()))[
0
]
def
_apply_model_score_template
(
def
parse_score_data_single
(
model_config
:
ModelConfig
,
prompt_1
:
str
,
prompt_2
:
str
data
:
ScoreData
,
)
->
str
:
role
:
str
,
# NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
from
vllm.model_executor.model_loader
import
get_model_cls
model
=
get_model_cls
(
model_config
)
if
supports_score_template
(
model
):
full_prompt
=
model
.
get_score_template
(
prompt_1
,
prompt_2
)
if
full_prompt
is
None
:
raise
ValueError
(
"Get empty score template from model"
)
return
full_prompt
raise
ValueError
(
f
"Unsupported model architecture:
{
model_config
.
architecture
}
"
)
def
post_process_tokens
(
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
prompt
:
TokensPrompt
,
)
->
tuple
[
str
,
MultiModalDataDict
|
None
,
MultiModalUUIDDict
|
None
]:
)
->
None
:
"""Parse **one** ScoreData into a text prompt and its own multi-modal
"""
data.
Perform architecture-specific manipulations on the input tokens.
Note:
Unlike :func:`parse_score_data`, each call creates an **independent**
This is an in-place operation.
:class:`MultiModalItemTracker` so multi-modal items are kept separate.
This is the correct behaviour for late-interaction scoring, where
query and document are encoded independently.
"""
"""
# NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf
)
mm_tracker
=
MultiModalItemTracker
(
model_config
)
from
vllm.model_executor.model_loader
import
get_model_cls
content
=
_parse_score_content
(
role
,
data
,
mm_tracker
)
model
=
get_model_cls
(
model_config
)
prompt
=
_ensure_str
(
content
)
if
supports_scor
e_tem
plate
(
model
):
mm_items
,
mm_uuids
=
mm_tracker
.
resolv
e_
i
tem
s
()
model
.
post_process_tokens
(
prompt
)
return
prompt
,
mm_items
,
mm_uuids
def
get_score_prompt
(
def
parse_score_data
(
model_config
:
ModelConfig
,
tokenizer
:
TokenizerLike
,
tokenization_kwargs
:
dict
[
str
,
Any
],
data_1
:
ScoreData
,
data_1
:
ScoreData
,
data_2
:
ScoreData
,
data_2
:
ScoreData
,
score_template
:
str
|
None
=
None
,
model_config
:
ModelConfig
,
)
->
tuple
[
str
,
TokensPrompt
]:
)
->
tuple
[
str
,
str
,
MultiModalDataDict
|
None
,
MultiModalUUIDDict
|
None
]:
prompt_1
,
prompt_2
,
mm_data
,
mm_uuids
=
parse_score_data
(
"""Parse a query-document pair into text prompts and shared multi-modal
data_1
,
data.
data_2
,
model_config
,
)
from
vllm.model_executor.model_loader
import
get_model_cls
model
=
get_model_cls
(
model_config
)
Uses a **single** :class:`MultiModalItemTracker` so that multi-modal
items from both inputs are merged into one ``mm_data`` dict. This is
the correct behaviour for cross-encoder scoring, where query and
document are concatenated into a single model prompt.
"""
mm_tracker
=
MultiModalItemTracker
(
model_config
)
def
default_tokenizer_encode
():
content_1
=
_parse_score_content
(
"query"
,
data_1
,
mm_tracker
)
if
supports_score_template
(
model
):
content_2
=
_parse_score_content
(
"document"
,
data_2
,
mm_tracker
)
full_prompt
=
_apply_model_score_template
(
model_config
,
prompt_1
,
prompt_2
)
prompt_inputs
=
tokenizer
(
full_prompt
,
**
tokenization_kwargs
)
prompt_1
=
_ensure_str
(
content_1
)
else
:
prompt_2
=
_ensure_str
(
content_2
)
if
model_config
.
use_sep_token
:
mm_items
,
mm_uuids
=
mm_tracker
.
resolve_items
()
# cross_encoder models defaults to using separating token.
prompt_inputs
=
tokenizer
(
return
prompt_1
,
prompt_2
,
mm_items
,
mm_uuids
text
=
prompt_1
,
text_pair
=
prompt_2
,
**
tokenization_kwargs
)
full_prompt
=
tokenizer
.
decode
(
prompt_inputs
[
"input_ids"
])
else
:
# `llm as reranker` defaults to not using separating token.
full_prompt
=
prompt_1
+
prompt_2
prompt_inputs
=
tokenizer
(
text
=
full_prompt
,
**
tokenization_kwargs
)
return
full_prompt
,
prompt_inputs
# FIXME: For now, we only apply a template when one is explicitly provided.
# We cannot rely on the tokenizer's chat template because many models
# inherit junk templates from their base LLM, which breaks both the models
# and the tests that use them.
if
score_template
is
None
:
full_prompt
,
prompt_inputs
=
default_tokenizer_encode
()
else
:
# FIXME: Try applying a score template from the CLI arg or tokenizer_config.json
# If that fails because there is no such template,
# fall back to the default implementation.
try
:
full_prompt
=
safe_apply_chat_template
(
model_config
,
tokenizer
,
[
{
"role"
:
"query"
,
"content"
:
prompt_1
},
{
"role"
:
"document"
,
"content"
:
prompt_2
},
],
chat_template
=
score_template
,
tools
=
None
,
tokenize
=
False
,
)
prompt_inputs
=
tokenizer
(
full_prompt
,
**
tokenization_kwargs
)
except
ChatTemplateResolutionError
:
full_prompt
,
prompt_inputs
=
default_tokenizer_encode
()
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
])
if
(
token_type_ids
:
=
prompt_inputs
.
get
(
"token_type_ids"
))
is
not
None
:
engine_prompt
[
"token_type_ids"
]
=
token_type_ids
post_process_tokens
(
model_config
,
engine_prompt
)
if
mm_data
is
not
None
:
engine_prompt
[
"multi_modal_data"
]
=
mm_data
if
mm_uuids
is
not
None
:
engine_prompt
[
"multi_modal_uuids"
]
=
mm_uuids
return
full_prompt
,
engine_prompt
def
compress_token_type_ids
(
token_type_ids
:
list
[
int
])
->
int
:
def
compress_token_type_ids
(
token_type_ids
:
list
[
int
])
->
int
:
...
...
vllm/entrypoints/pooling/typing.py
View file @
d9d21eb8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
import
time
from
collections.abc
import
AsyncGenerator
from
collections.abc
import
AsyncGenerator
,
Sequence
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Generic
,
TypeAlias
,
TypeVar
from
typing
import
Any
,
Generic
,
TypeAlias
,
TypeVar
from
fastapi
import
Request
from
fastapi
import
Request
from
pydantic
import
ConfigDict
from
pydantic
import
ConfigDict
from
vllm
import
PoolingRequestOutput
from
vllm
import
PoolingParams
,
PoolingRequestOutput
,
PromptType
from
vllm.entrypoints.pooling.classify.protocol
import
(
from
vllm.entrypoints.pooling.classify.protocol
import
(
ClassificationChatRequest
,
ClassificationChatRequest
,
ClassificationCompletionRequest
,
ClassificationCompletionRequest
,
...
@@ -23,15 +23,13 @@ from vllm.entrypoints.pooling.embed.protocol import (
...
@@ -23,15 +23,13 @@ from vllm.entrypoints.pooling.embed.protocol import (
)
)
from
vllm.entrypoints.pooling.pooling.protocol
import
(
from
vllm.entrypoints.pooling.pooling.protocol
import
(
IOProcessorRequest
,
IOProcessorRequest
,
PoolingBytesResponse
,
PoolingChatRequest
,
PoolingChatRequest
,
PoolingCompletionRequest
,
PoolingCompletionRequest
,
PoolingResponse
,
PoolingResponse
,
)
)
from
vllm.entrypoints.pooling.score.protocol
import
(
from
vllm.entrypoints.pooling.scoring.protocol
import
ScoringRequest
,
ScoringResponse
RerankRequest
,
from
vllm.entrypoints.pooling.scoring.typing
import
ScoringData
ScoreRequest
,
ScoreResponse
,
)
from
vllm.inputs
import
EngineInput
from
vllm.inputs
import
EngineInput
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -49,8 +47,7 @@ AnyPoolingRequest: TypeAlias = (
...
@@ -49,8 +47,7 @@ AnyPoolingRequest: TypeAlias = (
PoolingCompletionLikeRequest
PoolingCompletionLikeRequest
|
PoolingChatLikeRequest
|
PoolingChatLikeRequest
|
IOProcessorRequest
|
IOProcessorRequest
|
RerankRequest
|
ScoringRequest
|
ScoreRequest
|
CohereEmbedRequest
|
CohereEmbedRequest
)
)
...
@@ -59,7 +56,8 @@ AnyPoolingResponse: TypeAlias = (
...
@@ -59,7 +56,8 @@ AnyPoolingResponse: TypeAlias = (
|
EmbeddingResponse
|
EmbeddingResponse
|
EmbeddingBytesResponse
|
EmbeddingBytesResponse
|
PoolingResponse
|
PoolingResponse
|
ScoreResponse
|
PoolingBytesResponse
|
ScoringResponse
)
)
PoolingRequestT
=
TypeVar
(
"PoolingRequestT"
,
bound
=
AnyPoolingRequest
)
PoolingRequestT
=
TypeVar
(
"PoolingRequestT"
,
bound
=
AnyPoolingRequest
)
...
@@ -73,8 +71,8 @@ class PoolingServeContext(Generic[PoolingRequestT]):
...
@@ -73,8 +71,8 @@ class PoolingServeContext(Generic[PoolingRequestT]):
request_id
:
str
request_id
:
str
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
lora_request
:
LoRARequest
|
None
=
None
lora_request
:
LoRARequest
|
None
=
None
pooling_params
:
PoolingParams
|
list
[
PoolingParams
]
|
None
=
None
engine_inputs
:
list
[
EngineInput
]
|
None
=
None
engine_inputs
:
Sequence
[
EngineInput
]
|
None
=
None
prompt_request_ids
:
list
[
str
]
|
None
=
None
prompt_request_ids
:
list
[
str
]
|
None
=
None
intermediates
:
Any
|
None
=
None
intermediates
:
Any
|
None
=
None
...
@@ -84,3 +82,22 @@ class PoolingServeContext(Generic[PoolingRequestT]):
...
@@ -84,3 +82,22 @@ class PoolingServeContext(Generic[PoolingRequestT]):
final_res_batch
:
list
[
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
final_res_batch
:
list
[
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
@dataclass
class OfflineInputsContext:
    """Bundle of inputs and per-call options for a pooling request.

    NOTE(review): presumably used by the offline (non-server) entrypoints —
    confirm against the code that constructs this context.
    """

    # A single prompt, a sequence of prompts, or a ScoringData pair of lists.
    prompts: PromptType | Sequence[PromptType] | ScoringData
    # A single PoolingParams, a list of them, or None.
    pooling_params: PoolingParams | list[PoolingParams] | None = None
    # Optional keyword arguments for tokenization.
    tokenization_kwargs: dict[str, Any] | None = None
    # Optional chat template string.
    chat_template: str | None = None
    ## for bi-encoder & late-interaction
    # NOTE(review): meaning of the offset is not visible here — confirm.
    offset: int | None = None
@dataclass
class OfflineOutputsContext:
    """Bundle of pooling outputs, with an optional offset.

    NOTE(review): presumably the output-side counterpart of
    ``OfflineInputsContext`` — confirm against callers.
    """

    # Pooling outputs.
    outputs: list[PoolingRequestOutput]
    ## for bi-encoder & late-interaction
    # NOTE(review): meaning of the offset is not visible here — confirm.
    offset: int | None = None
vllm/entrypoints/pooling/utils.py
View file @
d9d21eb8
...
@@ -11,8 +11,10 @@ import pybase64
...
@@ -11,8 +11,10 @@ import pybase64
import
torch
import
torch
from
fastapi.responses
import
JSONResponse
from
fastapi.responses
import
JSONResponse
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.tasks
import
SupportedTask
from
vllm.utils.serial_utils
import
(
from
vllm.utils.serial_utils
import
(
EMBED_DTYPES
,
EMBED_DTYPES
,
EmbedDType
,
EmbedDType
,
...
@@ -133,3 +135,20 @@ def get_json_response_cls() -> type[JSONResponse]:
...
@@ -133,3 +135,20 @@ def get_json_response_cls() -> type[JSONResponse]:
"To make v1/embeddings API fast, please install orjson by `pip install orjson`"
"To make v1/embeddings API fast, please install orjson by `pip install orjson`"
)
)
return
JSONResponse
return
JSONResponse
def enable_scoring_api(
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> bool:
    """Decide whether the scoring API should be enabled.

    Returns ``True`` when ``"embed"`` or ``"token_embed"`` is supported.
    Otherwise returns ``True`` only when a model config is given,
    ``"classify"`` is supported and the HF config reports exactly one label.
    """
    # Embedding-style tasks enable scoring unconditionally.
    if "embed" in supported_tasks or "token_embed" in supported_tasks:
        return True

    # Without a config or the classify task there is nothing left to check.
    if model_config is None or "classify" not in supported_tasks:
        return False

    label_count = getattr(model_config.hf_config, "num_labels", 0)
    if label_count == 1:
        return True

    logger.debug_once("Scoring API is only enabled for num_labels == 1.")
    return False
vllm/entrypoints/sagemaker/api_router.py
View file @
d9d21eb8
...
@@ -14,8 +14,8 @@ from vllm.config import ModelConfig
...
@@ -14,8 +14,8 @@ from vllm.config import ModelConfig
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.pooling
import
enable_scoring_api
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.entrypoints.pooling.utils
import
enable_scoring_api
from
vllm.entrypoints.serve.instrumentator.basic
import
base
from
vllm.entrypoints.serve.instrumentator.basic
import
base
from
vllm.entrypoints.serve.instrumentator.health
import
health
from
vllm.entrypoints.serve.instrumentator.health
import
health
from
vllm.tasks
import
POOLING_TASKS
,
SupportedTask
from
vllm.tasks
import
POOLING_TASKS
,
SupportedTask
...
@@ -76,15 +76,15 @@ def get_invocation_types(
...
@@ -76,15 +76,15 @@ def get_invocation_types(
]
]
if
enable_scoring_api
(
supported_tasks
,
model_config
):
if
enable_scoring_api
(
supported_tasks
,
model_config
):
from
vllm.entrypoints.pooling.scor
e
.api_router
import
do_rerank
,
rerank
from
vllm.entrypoints.pooling.scor
ing
.api_router
import
do_rerank
,
rerank
from
vllm.entrypoints.pooling.scor
e
.protocol
import
RerankRequest
from
vllm.entrypoints.pooling.scor
ing
.protocol
import
RerankRequest
INVOCATION_TYPES
+=
[
INVOCATION_TYPES
+=
[
(
RerankRequest
,
(
rerank
,
do_rerank
)),
(
RerankRequest
,
(
rerank
,
do_rerank
)),
]
]
from
vllm.entrypoints.pooling.scor
e
.api_router
import
create_score
,
score
from
vllm.entrypoints.pooling.scor
ing
.api_router
import
create_score
,
score
from
vllm.entrypoints.pooling.scor
e
.protocol
import
ScoreRequest
from
vllm.entrypoints.pooling.scor
ing
.protocol
import
ScoreRequest
INVOCATION_TYPES
+=
[
INVOCATION_TYPES
+=
[
(
ScoreRequest
,
(
score
,
create_score
)),
(
ScoreRequest
,
(
score
,
create_score
)),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment