Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
66c079ae
Unverified
Commit
66c079ae
authored
Apr 09, 2026
by
wang.yuqi
Committed by
GitHub
Apr 09, 2026
Browse files
[Frontend][4/n] Improve pooling entrypoints | pooling. (#39153)
Signed-off-by:
wang.yuqi
<
yuqi.wang@daocloud.io
>
parent
b6c9be50
Changes
43
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
482 additions
and
662 deletions
+482
-662
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+59
-118
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+1
-7
vllm/entrypoints/openai/engine/serving.py
vllm/entrypoints/openai/engine/serving.py
+5
-180
vllm/entrypoints/openai/models/serving.py
vllm/entrypoints/openai/models/serving.py
+0
-1
vllm/entrypoints/pooling/__init__.py
vllm/entrypoints/pooling/__init__.py
+2
-4
vllm/entrypoints/pooling/base/io_processor.py
vllm/entrypoints/pooling/base/io_processor.py
+30
-9
vllm/entrypoints/pooling/base/serving.py
vllm/entrypoints/pooling/base/serving.py
+47
-26
vllm/entrypoints/pooling/classify/io_processor.py
vllm/entrypoints/pooling/classify/io_processor.py
+5
-1
vllm/entrypoints/pooling/classify/serving.py
vllm/entrypoints/pooling/classify/serving.py
+2
-14
vllm/entrypoints/pooling/embed/io_processor.py
vllm/entrypoints/pooling/embed/io_processor.py
+5
-1
vllm/entrypoints/pooling/embed/serving.py
vllm/entrypoints/pooling/embed/serving.py
+10
-18
vllm/entrypoints/pooling/io_processor_factories.py
vllm/entrypoints/pooling/io_processor_factories.py
+41
-14
vllm/entrypoints/pooling/pooling/api_router.py
vllm/entrypoints/pooling/pooling/api_router.py
+4
-26
vllm/entrypoints/pooling/pooling/io_processor.py
vllm/entrypoints/pooling/pooling/io_processor.py
+156
-0
vllm/entrypoints/pooling/pooling/serving.py
vllm/entrypoints/pooling/pooling/serving.py
+101
-223
vllm/entrypoints/pooling/scoring/io_processor.py
vllm/entrypoints/pooling/scoring/io_processor.py
+1
-1
vllm/entrypoints/pooling/scoring/serving.py
vllm/entrypoints/pooling/scoring/serving.py
+5
-12
vllm/entrypoints/pooling/typing.py
vllm/entrypoints/pooling/typing.py
+6
-3
vllm/entrypoints/sagemaker/api_router.py
vllm/entrypoints/sagemaker/api_router.py
+2
-2
vllm/entrypoints/serve/render/serving.py
vllm/entrypoints/serve/render/serving.py
+0
-2
No files found.
vllm/entrypoints/llm.py
View file @
66c079ae
...
@@ -49,9 +49,7 @@ from vllm.entrypoints.chat_utils import (
...
@@ -49,9 +49,7 @@ from vllm.entrypoints.chat_utils import (
load_chat_template
,
load_chat_template
,
)
)
from
vllm.entrypoints.pooling.io_processor_factories
import
init_pooling_io_processors
from
vllm.entrypoints.pooling.io_processor_factories
import
init_pooling_io_processors
from
vllm.entrypoints.pooling.scoring.io_processor
import
(
from
vllm.entrypoints.pooling.scoring.io_processor
import
ScoringIOProcessor
ScoringIOProcessor
,
)
from
vllm.entrypoints.pooling.scoring.typing
import
ScoreInput
from
vllm.entrypoints.pooling.scoring.typing
import
ScoreInput
from
vllm.entrypoints.pooling.typing
import
OfflineInputsContext
,
OfflineOutputsContext
from
vllm.entrypoints.pooling.typing
import
OfflineInputsContext
,
OfflineOutputsContext
from
vllm.entrypoints.utils
import
log_non_default_args
from
vllm.entrypoints.utils
import
log_non_default_args
...
@@ -398,12 +396,11 @@ class LLM:
...
@@ -398,12 +396,11 @@ class LLM:
self
.
runner_type
=
self
.
model_config
.
runner_type
self
.
runner_type
=
self
.
model_config
.
runner_type
self
.
renderer
=
self
.
llm_engine
.
renderer
self
.
renderer
=
self
.
llm_engine
.
renderer
self
.
chat_template
=
load_chat_template
(
chat_template
)
self
.
chat_template
=
load_chat_template
(
chat_template
)
self
.
io_processor
=
self
.
llm_engine
.
io_processor
self
.
input_processor
=
self
.
llm_engine
.
input_processor
self
.
input_processor
=
self
.
llm_engine
.
input_processor
self
.
chat_template_config
=
ChatTemplateConfig
(
chat_template
=
self
.
chat_template
)
self
.
chat_template_config
=
ChatTemplateConfig
(
chat_template
=
self
.
chat_template
)
self
.
pooling_io_processors
=
init_pooling_io_processors
(
self
.
pooling_io_processors
=
init_pooling_io_processors
(
supported_tasks
=
supported_tasks
,
supported_tasks
=
supported_tasks
,
model
_config
=
self
.
model
_config
,
vllm
_config
=
self
.
llm_engine
.
vllm
_config
,
renderer
=
self
.
renderer
,
renderer
=
self
.
renderer
,
chat_template_config
=
self
.
chat_template_config
,
chat_template_config
=
self
.
chat_template_config
,
)
)
...
@@ -1081,118 +1078,55 @@ class LLM:
...
@@ -1081,118 +1078,55 @@ class LLM:
pooled hidden states in the same order as the input prompts.
pooled hidden states in the same order as the input prompts.
"""
"""
if
isinstance
(
prompts
,
dict
)
and
"data"
in
prompts
and
pooling_task
!=
"plugin"
:
raise
ValueError
(
"The 'data' field is only supported for the 'plugin' pooling task."
)
self
.
_verify_pooling_task
(
pooling_task
)
self
.
_verify_pooling_task
(
pooling_task
)
assert
pooling_task
is
not
None
and
pooling_task
in
self
.
pooling_io_processors
if
isinstance
(
prompts
,
dict
)
and
"data"
in
prompts
:
io_processor
=
self
.
pooling_io_processors
[
pooling_task
]
if
self
.
io_processor
is
None
:
raise
ValueError
(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
# Validate the request data is valid for the loaded plugin
if
pooling_params
is
None
:
prompt_data
=
prompts
.
get
(
"data"
)
pooling_params
=
PoolingParams
()
if
prompt_data
is
None
:
raise
ValueError
(
"The 'data' field of the prompt is expected to contain "
"the prompt data and it cannot be None. "
"Refer to the documentation of the IOProcessor "
"in use for more details."
)
validated_prompt
=
self
.
io_processor
.
parse_data
(
prompt_data
)
# obtain the actual model prompts from the pre-processor
ctx
=
OfflineInputsContext
(
prompts
=
self
.
io_processor
.
pre_process
(
prompt
=
validated_prompt
)
prompts
=
prompts
,
prompts_seq
=
prompt_to_seq
(
prompts
)
pooling_params
=
pooling_params
,
tokenization_kwargs
=
tokenization_kwargs
,
)
params_seq
:
Sequence
[
PoolingParams
]
=
[
engine_inputs
=
io_processor
.
pre_process_offline
(
ctx
)
self
.
io_processor
.
merge_pooling_params
(
param
)
n_inputs
=
len
(
engine_inputs
)
for
param
in
self
.
_params_to_seq
(
assert
ctx
.
pooling_params
is
not
None
pooling_params
,
len
(
prompts_seq
),
)
]
for
p
in
params_seq
:
if
p
.
task
is
None
:
p
.
task
=
"plugin"
outputs
=
self
.
_run_completion
(
prompts
=
prompts_seq
,
params
=
params_seq
,
output_type
=
PoolingRequestOutput
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
)
# get the post-processed model outputs
params_seq
=
self
.
_params_to_seq
(
ctx
.
pooling_params
,
n_inputs
)
assert
self
.
io_processor
is
not
None
processed_outputs
=
self
.
io_processor
.
post_process
(
outputs
)
return
[
for
param
in
params_seq
:
PoolingRequestOutput
[
Any
](
if
param
.
task
is
None
:
request_id
=
""
,
param
.
task
=
pooling_task
outputs
=
processed_outputs
,
elif
pooling_task
==
"plugin"
:
num_cached_tokens
=
getattr
(
# `plugin` task uses io_processor.parse_request to verify inputs.
processed_outputs
,
"num_cached_tokens"
,
0
# We actually allow plugin to overwrite pooling_task.
),
pass
prompt_token_ids
=
[],
elif
param
.
task
!=
pooling_task
:
finished
=
True
,
msg
=
f
"You cannot overwrite
{
param
.
task
=
!
r
}
with
{
pooling_task
=
!
r
}
!"
)
raise
ValueError
(
msg
)
]
else
:
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
prompts_seq
=
prompt_to_seq
(
prompts
)
params_seq
=
self
.
_params_to_seq
(
pooling_params
,
len
(
prompts_seq
))
for
param
in
params_seq
:
if
param
.
task
is
None
:
param
.
task
=
pooling_task
elif
param
.
task
!=
pooling_task
:
msg
=
(
f
"You cannot overwrite
{
param
.
task
=
!
r
}
with
{
pooling_task
=
!
r
}
!"
)
raise
ValueError
(
msg
)
if
pooling_task
in
self
.
pooling_io_processors
:
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
n_inputs
)
io_processor
=
self
.
pooling_io_processors
[
pooling_task
]
seq_priority
=
self
.
_priority_to_seq
(
None
,
n_inputs
)
processor_inputs
=
io_processor
.
pre_process_offline
(
ctx
=
OfflineInputsContext
(
prompts
=
prompts_seq
,
tokenization_kwargs
=
tokenization_kwargs
)
)
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
len
(
prompts_seq
)
)
seq_priority
=
self
.
_priority_to_seq
(
None
,
len
(
prompts
))
self
.
_render_and_add_requests
(
self
.
_render_and_add_requests
(
prompts
=
processor
_inputs
,
prompts
=
engine
_inputs
,
params
=
params_seq
,
params
=
params_seq
,
lora_requests
=
seq_lora_requests
,
lora_requests
=
seq_lora_requests
,
priorities
=
seq_priority
,
priorities
=
seq_priority
,
)
)
outputs
=
self
.
_run_engine
(
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
,
output_type
=
PoolingRequestOutput
)
use_tqdm
=
use_tqdm
,
output_type
=
PoolingRequestOutput
outputs
=
io_processor
.
post_process_offline
(
)
ctx
=
OfflineOutputsContext
(
outputs
=
outputs
)
outputs
=
io_processor
.
post_process_offline
(
)
ctx
=
OfflineOutputsContext
(
outputs
=
outputs
)
)
else
:
outputs
=
self
.
_run_completion
(
prompts
=
prompts_seq
,
params
=
params_seq
,
output_type
=
PoolingRequestOutput
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
)
return
outputs
return
outputs
def
_verify_pooling_task
(
self
,
pooling_task
:
PoolingTask
|
None
):
def
_verify_pooling_task
(
self
,
pooling_task
:
PoolingTask
|
None
):
...
@@ -1254,6 +1188,14 @@ class LLM:
...
@@ -1254,6 +1188,14 @@ class LLM:
pooling_task
,
pooling_task
,
)
)
if
pooling_task
==
"plugin"
and
"plugin"
not
in
self
.
pooling_io_processors
:
raise
ValueError
(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
def
embed
(
def
embed
(
self
,
self
,
prompts
:
PromptType
|
Sequence
[
PromptType
],
prompts
:
PromptType
|
Sequence
[
PromptType
],
...
@@ -1458,6 +1400,9 @@ class LLM:
...
@@ -1458,6 +1400,9 @@ class LLM:
scoring_data
=
io_processor
.
valid_inputs
(
data_1
,
data_2
)
scoring_data
=
io_processor
.
valid_inputs
(
data_1
,
data_2
)
n_queries
=
len
(
scoring_data
.
data_1
)
n_queries
=
len
(
scoring_data
.
data_1
)
if
pooling_params
is
None
:
pooling_params
=
PoolingParams
()
ctx
=
OfflineInputsContext
(
ctx
=
OfflineInputsContext
(
prompts
=
scoring_data
,
prompts
=
scoring_data
,
pooling_params
=
pooling_params
,
pooling_params
=
pooling_params
,
...
@@ -1466,15 +1411,11 @@ class LLM:
...
@@ -1466,15 +1411,11 @@ class LLM:
n_queries
=
n_queries
,
n_queries
=
n_queries
,
)
)
processor_inputs
=
io_processor
.
pre_process_offline
(
ctx
)
engine_inputs
=
io_processor
.
pre_process_offline
(
ctx
)
n_inputs
=
len
(
engine_inputs
)
seq_lora_requests
=
self
.
_lora_request_to_seq
(
seq_lora_requests
=
self
.
_lora_request_to_seq
(
lora_request
,
n_inputs
)
lora_request
,
len
(
processor_inputs
)
params_seq
=
self
.
_params_to_seq
(
ctx
.
pooling_params
,
n_inputs
)
)
if
ctx
.
pooling_params
is
None
:
ctx
.
pooling_params
=
PoolingParams
()
params_seq
=
self
.
_params_to_seq
(
ctx
.
pooling_params
,
len
(
processor_inputs
))
for
param
in
params_seq
:
for
param
in
params_seq
:
if
param
.
task
is
None
:
if
param
.
task
is
None
:
...
@@ -1483,10 +1424,10 @@ class LLM:
...
@@ -1483,10 +1424,10 @@ class LLM:
msg
=
f
"You cannot overwrite
{
param
.
task
=
!
r
}
with
{
pooling_task
=
!
r
}
!"
msg
=
f
"You cannot overwrite
{
param
.
task
=
!
r
}
with
{
pooling_task
=
!
r
}
!"
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
seq_priority
=
self
.
_priority_to_seq
(
None
,
len
(
processor
_inputs
)
)
seq_priority
=
self
.
_priority_to_seq
(
None
,
n
_inputs
)
self
.
_render_and_add_requests
(
self
.
_render_and_add_requests
(
prompts
=
processor
_inputs
,
prompts
=
engine
_inputs
,
params
=
params_seq
,
params
=
params_seq
,
lora_requests
=
seq_lora_requests
,
lora_requests
=
seq_lora_requests
,
priorities
=
seq_priority
,
priorities
=
seq_priority
,
...
@@ -1579,7 +1520,7 @@ class LLM:
...
@@ -1579,7 +1520,7 @@ class LLM:
if
isinstance
(
params
,
Sequence
):
if
isinstance
(
params
,
Sequence
):
if
len
(
params
)
!=
num_requests
:
if
len
(
params
)
!=
num_requests
:
raise
ValueError
(
raise
ValueError
(
f
"The lengths of prompts (
{
param
s
}
) "
f
"The lengths of prompts (
{
num_request
s
}
) "
f
"and params (
{
len
(
params
)
}
) must be the same."
f
"and params (
{
len
(
params
)
}
) must be the same."
)
)
...
...
vllm/entrypoints/openai/api_server.py
View file @
66c079ae
...
@@ -370,7 +370,6 @@ async def init_app_state(
...
@@ -370,7 +370,6 @@ async def init_app_state(
state
.
openai_serving_render
=
OpenAIServingRender
(
state
.
openai_serving_render
=
OpenAIServingRender
(
model_config
=
engine_client
.
model_config
,
model_config
=
engine_client
.
model_config
,
renderer
=
engine_client
.
renderer
,
renderer
=
engine_client
.
renderer
,
io_processor
=
engine_client
.
io_processor
,
model_registry
=
state
.
openai_serving_models
.
registry
,
model_registry
=
state
.
openai_serving_models
.
registry
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
...
@@ -441,13 +440,12 @@ async def init_render_app_state(
...
@@ -441,13 +440,12 @@ async def init_render_app_state(
Unlike :func:`init_app_state` this function does not require an
Unlike :func:`init_app_state` this function does not require an
:class:`~vllm.engine.protocol.EngineClient`; it bootstraps the
:class:`~vllm.engine.protocol.EngineClient`; it bootstraps the
preprocessing pipeline (renderer,
io_processor,
input_processor)
preprocessing pipeline (renderer, input_processor)
directly from the :class:`~vllm.config.VllmConfig`.
directly from the :class:`~vllm.config.VllmConfig`.
"""
"""
from
vllm.entrypoints.chat_utils
import
load_chat_template
from
vllm.entrypoints.chat_utils
import
load_chat_template
from
vllm.entrypoints.openai.models.serving
import
OpenAIModelRegistry
from
vllm.entrypoints.openai.models.serving
import
OpenAIModelRegistry
from
vllm.entrypoints.serve.render.serving
import
OpenAIServingRender
from
vllm.entrypoints.serve.render.serving
import
OpenAIServingRender
from
vllm.plugins.io_processors
import
get_io_processor
from
vllm.renderers
import
renderer_from_config
from
vllm.renderers
import
renderer_from_config
served_model_names
=
args
.
served_model_name
or
[
args
.
model
]
served_model_names
=
args
.
served_model_name
or
[
args
.
model
]
...
@@ -465,15 +463,11 @@ async def init_render_app_state(
...
@@ -465,15 +463,11 @@ async def init_render_app_state(
request_logger
=
None
request_logger
=
None
renderer
=
renderer_from_config
(
vllm_config
)
renderer
=
renderer_from_config
(
vllm_config
)
io_processor
=
get_io_processor
(
vllm_config
,
renderer
,
vllm_config
.
model_config
.
io_processor_plugin
)
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
state
.
openai_serving_render
=
OpenAIServingRender
(
state
.
openai_serving_render
=
OpenAIServingRender
(
model_config
=
vllm_config
.
model_config
,
model_config
=
vllm_config
.
model_config
,
renderer
=
renderer
,
renderer
=
renderer
,
io_processor
=
io_processor
,
model_registry
=
model_registry
,
model_registry
=
model_registry
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
...
...
vllm/entrypoints/openai/engine/serving.py
View file @
66c079ae
...
@@ -44,12 +44,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
...
@@ -44,12 +44,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse
,
TranscriptionResponse
,
TranslationRequest
,
TranslationRequest
,
)
)
from
vllm.entrypoints.pooling.pooling.protocol
import
(
IOProcessorRequest
,
PoolingChatRequest
,
PoolingCompletionRequest
,
PoolingResponse
,
)
from
vllm.entrypoints.serve.disagg.protocol
import
GenerateRequest
,
GenerateResponse
from
vllm.entrypoints.serve.disagg.protocol
import
GenerateRequest
,
GenerateResponse
from
vllm.entrypoints.serve.tokenize.protocol
import
(
from
vllm.entrypoints.serve.tokenize.protocol
import
(
DetokenizeRequest
,
DetokenizeRequest
,
...
@@ -62,8 +56,7 @@ from vllm.inputs import EngineInput, PromptType
...
@@ -62,8 +56,7 @@ from vllm.inputs import EngineInput, PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
,
PromptLogprobs
from
vllm.logprobs
import
Logprob
,
PromptLogprobs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
CompletionOutput
,
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers
import
ChatParams
,
TokenizeParams
from
vllm.renderers
import
ChatParams
,
TokenizeParams
from
vllm.renderers.inputs.preprocess
import
(
from
vllm.renderers.inputs.preprocess
import
(
extract_prompt_components
,
extract_prompt_components
,
...
@@ -78,10 +71,7 @@ from vllm.tracing import (
...
@@ -78,10 +71,7 @@ from vllm.tracing import (
log_tracing_disabled_warning
,
log_tracing_disabled_warning
,
)
)
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
from
vllm.utils.async_utils
import
(
from
vllm.utils.async_utils
import
collect_from_async_generator
collect_from_async_generator
,
merge_async_iterators
,
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -101,17 +91,11 @@ class RendererChatRequest(RendererRequest, Protocol):
...
@@ -101,17 +91,11 @@ class RendererChatRequest(RendererRequest, Protocol):
CompletionLikeRequest
:
TypeAlias
=
(
CompletionLikeRequest
:
TypeAlias
=
(
CompletionRequest
CompletionRequest
|
TokenizeCompletionRequest
|
DetokenizeRequest
|
TokenizeCompletionRequest
|
DetokenizeRequest
|
PoolingCompletionRequest
)
)
ChatLikeRequest
:
TypeAlias
=
(
ChatLikeRequest
:
TypeAlias
=
(
ChatCompletionRequest
ChatCompletionRequest
|
BatchChatCompletionRequest
|
TokenizeChatRequest
|
BatchChatCompletionRequest
|
TokenizeChatRequest
|
PoolingChatRequest
)
)
SpeechToTextRequest
:
TypeAlias
=
TranscriptionRequest
|
TranslationRequest
SpeechToTextRequest
:
TypeAlias
=
TranscriptionRequest
|
TranslationRequest
...
@@ -121,7 +105,6 @@ AnyRequest: TypeAlias = (
...
@@ -121,7 +105,6 @@ AnyRequest: TypeAlias = (
|
ChatLikeRequest
|
ChatLikeRequest
|
SpeechToTextRequest
|
SpeechToTextRequest
|
ResponsesRequest
|
ResponsesRequest
|
IOProcessorRequest
|
GenerateRequest
|
GenerateRequest
)
)
...
@@ -130,7 +113,6 @@ AnyResponse: TypeAlias = (
...
@@ -130,7 +113,6 @@ AnyResponse: TypeAlias = (
|
ChatCompletionResponse
|
ChatCompletionResponse
|
TranscriptionResponse
|
TranscriptionResponse
|
TokenizeResponse
|
TokenizeResponse
|
PoolingResponse
|
GenerateResponse
|
GenerateResponse
)
)
...
@@ -146,12 +128,6 @@ class ServeContext(Generic[RequestT]):
...
@@ -146,12 +128,6 @@ class ServeContext(Generic[RequestT]):
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
lora_request
:
LoRARequest
|
None
=
None
lora_request
:
LoRARequest
|
None
=
None
engine_inputs
:
list
[
EngineInput
]
|
None
=
None
engine_inputs
:
list
[
EngineInput
]
|
None
=
None
result_generator
:
AsyncGenerator
[
tuple
[
int
,
PoolingRequestOutput
],
None
]
|
None
=
(
None
)
final_res_batch
:
list
[
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
...
@@ -171,7 +147,6 @@ class OpenAIServing:
...
@@ -171,7 +147,6 @@ class OpenAIServing:
super
().
__init__
()
super
().
__init__
()
self
.
engine_client
=
engine_client
self
.
engine_client
=
engine_client
self
.
models
=
models
self
.
models
=
models
self
.
request_logger
=
request_logger
self
.
request_logger
=
request_logger
...
@@ -179,7 +154,6 @@ class OpenAIServing:
...
@@ -179,7 +154,6 @@ class OpenAIServing:
self
.
model_config
=
engine_client
.
model_config
self
.
model_config
=
engine_client
.
model_config
self
.
renderer
=
engine_client
.
renderer
self
.
renderer
=
engine_client
.
renderer
self
.
io_processor
=
engine_client
.
io_processor
self
.
input_processor
=
engine_client
.
input_processor
self
.
input_processor
=
engine_client
.
input_processor
async
def
beam_search
(
async
def
beam_search
(
...
@@ -381,155 +355,6 @@ class OpenAIServing:
...
@@ -381,155 +355,6 @@ class OpenAIServing:
prompt_logprobs
=
None
,
prompt_logprobs
=
None
,
)
)
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
"""
Default preprocessing hook. Subclasses may override to prepare `ctx`.
"""
return
None
def
_build_response
(
self
,
ctx
:
ServeContext
,
)
->
AnyResponse
|
ErrorResponse
:
"""
Default response builder. Subclass may override this method
to return the appropriate response object.
"""
return
self
.
create_error_response
(
"unimplemented endpoint"
)
async
def
handle
(
self
,
ctx
:
ServeContext
,
)
->
AnyResponse
|
ErrorResponse
:
async
for
response
in
self
.
_pipeline
(
ctx
):
return
response
return
self
.
create_error_response
(
"No response yielded from pipeline"
)
async
def
_pipeline
(
self
,
ctx
:
ServeContext
,
)
->
AsyncGenerator
[
AnyResponse
|
ErrorResponse
,
None
]:
"""Execute the request processing pipeline yielding responses."""
if
error
:
=
await
self
.
_check_model
(
ctx
.
request
):
yield
error
if
error
:
=
self
.
_validate_request
(
ctx
):
yield
error
preprocess_ret
=
await
self
.
_preprocess
(
ctx
)
if
isinstance
(
preprocess_ret
,
ErrorResponse
):
yield
preprocess_ret
generators_ret
=
await
self
.
_prepare_generators
(
ctx
)
if
isinstance
(
generators_ret
,
ErrorResponse
):
yield
generators_ret
collect_ret
=
await
self
.
_collect_batch
(
ctx
)
if
isinstance
(
collect_ret
,
ErrorResponse
):
yield
collect_ret
yield
self
.
_build_response
(
ctx
)
def
_validate_request
(
self
,
ctx
:
ServeContext
)
->
ErrorResponse
|
None
:
truncate_prompt_tokens
=
getattr
(
ctx
.
request
,
"truncate_prompt_tokens"
,
None
)
if
(
truncate_prompt_tokens
is
not
None
and
truncate_prompt_tokens
>
self
.
model_config
.
max_model_len
):
return
self
.
create_error_response
(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please request a smaller truncation size."
)
return
None
def
_create_pooling_params
(
self
,
ctx
:
ServeContext
,
)
->
PoolingParams
|
ErrorResponse
:
if
not
hasattr
(
ctx
.
request
,
"to_pooling_params"
):
return
self
.
create_error_response
(
"Request type does not support pooling parameters"
)
return
ctx
.
request
.
to_pooling_params
()
async
def
_prepare_generators
(
self
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
"""Schedule the request and get the result generator."""
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
trace_headers
=
(
None
if
ctx
.
raw_request
is
None
else
await
self
.
_get_trace_headers
(
ctx
.
raw_request
.
headers
)
)
pooling_params
=
self
.
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
return
pooling_params
if
ctx
.
engine_inputs
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
for
i
,
engine_input
in
enumerate
(
ctx
.
engine_inputs
):
request_id_item
=
f
"
{
ctx
.
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
engine_input
,
params
=
pooling_params
,
lora_request
=
ctx
.
lora_request
,
)
generator
=
self
.
engine_client
.
encode
(
engine_input
,
pooling_params
,
request_id_item
,
lora_request
=
ctx
.
lora_request
,
trace_headers
=
trace_headers
,
priority
=
getattr
(
ctx
.
request
,
"priority"
,
0
),
)
generators
.
append
(
generator
)
ctx
.
result_generator
=
merge_async_iterators
(
*
generators
)
return
None
async
def
_collect_batch
(
self
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
"""Collect batch results from the result generator."""
if
ctx
.
engine_inputs
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
num_prompts
=
len
(
ctx
.
engine_inputs
)
final_res_batch
:
list
[
PoolingRequestOutput
|
None
]
final_res_batch
=
[
None
]
*
num_prompts
if
ctx
.
result_generator
is
None
:
return
self
.
create_error_response
(
"Result generator not available"
)
async
for
i
,
res
in
ctx
.
result_generator
:
final_res_batch
[
i
]
=
res
if
None
in
final_res_batch
:
return
self
.
create_error_response
(
"Failed to generate results for all prompts"
)
ctx
.
final_res_batch
=
[
res
for
res
in
final_res_batch
if
res
is
not
None
]
return
None
@
staticmethod
@
staticmethod
def
create_error_response
(
def
create_error_response
(
message
:
str
|
Exception
,
message
:
str
|
Exception
,
...
@@ -719,7 +544,7 @@ class OpenAIServing:
...
@@ -719,7 +544,7 @@ class OpenAIServing:
self
,
self
,
request_id
:
str
,
request_id
:
str
,
inputs
:
PromptType
|
EngineInput
,
inputs
:
PromptType
|
EngineInput
,
params
:
SamplingParams
|
PoolingParams
|
BeamSearchParams
|
None
,
params
:
SamplingParams
|
BeamSearchParams
|
None
,
lora_request
:
LoRARequest
|
None
,
lora_request
:
LoRARequest
|
None
,
)
->
None
:
)
->
None
:
if
self
.
request_logger
is
None
:
if
self
.
request_logger
is
None
:
...
...
vllm/entrypoints/openai/models/serving.py
View file @
66c079ae
...
@@ -112,7 +112,6 @@ class OpenAIServingModels:
...
@@ -112,7 +112,6 @@ class OpenAIServingModels:
self
.
model_config
=
self
.
engine_client
.
model_config
self
.
model_config
=
self
.
engine_client
.
model_config
self
.
renderer
=
self
.
engine_client
.
renderer
self
.
renderer
=
self
.
engine_client
.
renderer
self
.
io_processor
=
self
.
engine_client
.
io_processor
self
.
input_processor
=
self
.
engine_client
.
input_processor
self
.
input_processor
=
self
.
engine_client
.
input_processor
async
def
init_static_loras
(
self
):
async
def
init_static_loras
(
self
):
...
...
vllm/entrypoints/pooling/__init__.py
View file @
66c079ae
...
@@ -67,20 +67,18 @@ def init_pooling_state(
...
@@ -67,20 +67,18 @@ def init_pooling_state(
from
vllm.entrypoints.chat_utils
import
load_chat_template
from
vllm.entrypoints.chat_utils
import
load_chat_template
from
vllm.entrypoints.pooling.classify.serving
import
ServingClassification
from
vllm.entrypoints.pooling.classify.serving
import
ServingClassification
from
vllm.entrypoints.pooling.embed.serving
import
ServingEmbedding
from
vllm.entrypoints.pooling.embed.serving
import
ServingEmbedding
from
vllm.entrypoints.pooling.pooling.serving
import
OpenAI
ServingPooling
from
vllm.entrypoints.pooling.pooling.serving
import
ServingPooling
from
vllm.entrypoints.pooling.scoring.serving
import
ServingScores
from
vllm.entrypoints.pooling.scoring.serving
import
ServingScores
from
vllm.tasks
import
POOLING_TASKS
from
vllm.tasks
import
POOLING_TASKS
model_config
=
engine_client
.
model_config
model_config
=
engine_client
.
model_config
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
state
.
serving_pooling
=
(
state
.
serving_pooling
=
(
(
(
OpenAI
ServingPooling
(
ServingPooling
(
engine_client
,
engine_client
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
state
.
openai_serving_render
,
supported_tasks
=
supported_tasks
,
supported_tasks
=
supported_tasks
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
...
...
vllm/entrypoints/pooling/base/io_processor.py
View file @
66c079ae
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
Any
,
Final
from
typing
import
Any
,
Final
from
vllm
import
PoolingRequestOutput
,
PromptType
from
vllm
import
PoolingParams
,
PoolingRequestOutput
,
PromptType
from
vllm.config
import
Model
Config
from
vllm.config
import
Vllm
Config
from
vllm.entrypoints.chat_utils
import
(
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
ChatCompletionMessageParam
,
ChatTemplateConfig
,
ChatTemplateConfig
,
...
@@ -33,11 +33,12 @@ class PoolingIOProcessor:
...
@@ -33,11 +33,12 @@ class PoolingIOProcessor:
def
__init__
(
def
__init__
(
self
,
self
,
model
_config
:
Model
Config
,
vllm
_config
:
Vllm
Config
,
renderer
:
BaseRenderer
,
renderer
:
BaseRenderer
,
chat_template_config
:
ChatTemplateConfig
,
chat_template_config
:
ChatTemplateConfig
,
):
):
self
.
model_config
=
model_config
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
renderer
=
renderer
self
.
renderer
=
renderer
self
.
chat_template
=
chat_template_config
.
chat_template
self
.
chat_template
=
chat_template_config
.
chat_template
...
@@ -48,12 +49,12 @@ class PoolingIOProcessor:
...
@@ -48,12 +49,12 @@ class PoolingIOProcessor:
chat_template_config
.
trust_request_chat_template
chat_template_config
.
trust_request_chat_template
)
)
def
create_pooling_params
(
self
,
request
):
return
request
.
to_pooling_params
()
#######################################
#######################################
# online APIs
# online APIs
def
create_pooling_params
(
self
,
request
):
return
request
.
to_pooling_params
()
def
pre_process_online
(
self
,
ctx
:
PoolingServeContext
):
def
pre_process_online
(
self
,
ctx
:
PoolingServeContext
):
request
=
ctx
.
request
request
=
ctx
.
request
...
@@ -100,12 +101,16 @@ class PoolingIOProcessor:
...
@@ -100,12 +101,16 @@ class PoolingIOProcessor:
# offline APIs
# offline APIs
def
pre_process_offline
(
self
,
ctx
:
OfflineInputsContext
)
->
Sequence
[
EngineInput
]:
def
pre_process_offline
(
self
,
ctx
:
OfflineInputsContext
)
->
Sequence
[
EngineInput
]:
assert
not
isinstance
(
ctx
.
prompts
,
ScoringData
)
assert
not
isinstance
(
ctx
.
prompts
,
ScoringData
)
and
not
(
isinstance
(
ctx
.
prompts
,
dict
)
and
"data"
in
ctx
.
prompts
)
prompts_seq
=
prompt_to_seq
(
ctx
.
prompts
)
tok_params
=
self
.
renderer
.
default_cmpl_tok_params
.
with_kwargs
(
tok_params
=
self
.
renderer
.
default_cmpl_tok_params
.
with_kwargs
(
**
(
ctx
.
tokenization_kwargs
or
{})
**
(
ctx
.
tokenization_kwargs
or
{})
)
)
return
self
.
_preprocess_completion_offline
(
return
self
.
_preprocess_completion_offline
(
prompts
=
ctx
.
prompts
,
tok_params
=
tok_params
prompts
=
prompts
_seq
,
tok_params
=
tok_params
)
)
async
def
pre_process_offline_async
(
self
,
ctx
:
OfflineInputsContext
):
async
def
pre_process_offline_async
(
self
,
ctx
:
OfflineInputsContext
):
...
@@ -243,3 +248,19 @@ class PoolingIOProcessor:
...
@@ -243,3 +248,19 @@ class PoolingIOProcessor:
"Refused request with untrusted chat template."
"Refused request with untrusted chat template."
)
)
return
None
return
None
def
_params_to_seq
(
self
,
params
:
PoolingParams
|
Sequence
[
PoolingParams
],
num_requests
:
int
,
)
->
Sequence
[
PoolingParams
]:
if
isinstance
(
params
,
Sequence
):
if
len
(
params
)
!=
num_requests
:
raise
ValueError
(
f
"The lengths of prompts (
{
num_requests
}
) "
f
"and params (
{
len
(
params
)
}
) must be the same."
)
return
params
return
[
params
]
*
num_requests
vllm/entrypoints/pooling/base/serving.py
View file @
66c079ae
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
AsyncGenerator
,
Mapping
from
collections.abc
import
AsyncGenerator
,
Mapping
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
ClassVar
from
typing
import
ClassVar
...
@@ -9,7 +11,7 @@ from fastapi.responses import Response
...
@@ -9,7 +11,7 @@ from fastapi.responses import Response
from
starlette.datastructures
import
Headers
from
starlette.datastructures
import
Headers
from
vllm
import
PoolingParams
,
PoolingRequestOutput
,
envs
from
vllm
import
PoolingParams
,
PoolingRequestOutput
,
envs
from
vllm.config
import
Model
Config
from
vllm.config
import
Vllm
Config
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
from
vllm.entrypoints.chat_utils
import
(
ChatTemplateConfig
,
ChatTemplateConfig
,
...
@@ -35,7 +37,7 @@ from vllm.utils.async_utils import merge_async_iterators
...
@@ -35,7 +37,7 @@ from vllm.utils.async_utils import merge_async_iterators
from
.io_processor
import
PoolingIOProcessor
from
.io_processor
import
PoolingIOProcessor
class
PoolingServing
:
class
PoolingServing
Base
(
ABC
)
:
request_id_prefix
:
ClassVar
[
str
]
request_id_prefix
:
ClassVar
[
str
]
def
__init__
(
def
__init__
(
...
@@ -50,10 +52,11 @@ class PoolingServing:
...
@@ -50,10 +52,11 @@ class PoolingServing:
return_tokens_as_token_ids
:
bool
=
False
,
return_tokens_as_token_ids
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
):
):
super
().
__init__
()
self
.
engine_client
=
engine_client
self
.
engine_client
=
engine_client
self
.
models
=
models
self
.
models
=
models
self
.
model_config
=
models
.
model_config
self
.
model_config
=
models
.
model_config
self
.
renderer
=
models
.
renderer
self
.
vllm_config
=
engine_client
.
vllm_config
self
.
max_model_len
=
self
.
model_config
.
max_model_len
self
.
max_model_len
=
self
.
model_config
.
max_model_len
self
.
request_logger
=
request_logger
self
.
request_logger
=
request_logger
self
.
return_tokens_as_token_ids
=
return_tokens_as_token_ids
self
.
return_tokens_as_token_ids
=
return_tokens_as_token_ids
...
@@ -63,31 +66,14 @@ class PoolingServing:
...
@@ -63,31 +66,14 @@ class PoolingServing:
chat_template_content_format
=
chat_template_content_format
,
chat_template_content_format
=
chat_template_content_format
,
trust_request_chat_template
=
trust_request_chat_template
,
trust_request_chat_template
=
trust_request_chat_template
,
)
)
self
.
io_processor
=
self
.
init_io_processor
(
model_config
=
models
.
model_config
,
renderer
=
models
.
renderer
,
chat_template_config
=
self
.
chat_template_config
,
)
def
init_io_processor
(
self
,
model_config
:
ModelConfig
,
renderer
:
BaseRenderer
,
chat_template_config
:
ChatTemplateConfig
,
)
->
PoolingIOProcessor
:
raise
NotImplementedError
@
abstractmethod
async
def
__call__
(
async
def
__call__
(
self
,
self
,
request
:
AnyPoolingRequest
,
request
:
AnyPoolingRequest
,
raw_request
:
Request
|
None
=
None
,
raw_request
:
Request
|
None
=
None
,
)
->
Response
:
)
->
Response
:
ctx
=
await
self
.
_init_ctx
(
request
,
raw_request
)
raise
NotImplementedError
await
self
.
io_processor
.
pre_process_online_async
(
ctx
)
await
self
.
_prepare_generators
(
ctx
)
await
self
.
_collect_batch
(
ctx
)
await
self
.
io_processor
.
post_process_online_async
(
ctx
)
return
await
self
.
_build_response
(
ctx
)
async
def
_init_ctx
(
async
def
_init_ctx
(
self
,
self
,
...
@@ -124,10 +110,8 @@ class PoolingServing:
...
@@ -124,10 +110,8 @@ class PoolingServing:
else
await
self
.
_get_trace_headers
(
ctx
.
raw_request
.
headers
)
else
await
self
.
_get_trace_headers
(
ctx
.
raw_request
.
headers
)
)
)
if
ctx
.
pooling_params
is
None
:
assert
ctx
.
pooling_params
is
not
None
pooling_params
=
self
.
io_processor
.
create_pooling_params
(
ctx
.
request
)
pooling_params
=
ctx
.
pooling_params
else
:
pooling_params
=
ctx
.
pooling_params
if
isinstance
(
pooling_params
,
list
):
if
isinstance
(
pooling_params
,
list
):
for
params
in
pooling_params
:
for
params
in
pooling_params
:
...
@@ -190,6 +174,7 @@ class PoolingServing:
...
@@ -190,6 +174,7 @@ class PoolingServing:
ctx
.
final_res_batch
=
[
res
for
res
in
final_res_batch
if
res
is
not
None
]
ctx
.
final_res_batch
=
[
res
for
res
in
final_res_batch
if
res
is
not
None
]
@
abstractmethod
async
def
_build_response
(
async
def
_build_response
(
self
,
self
,
ctx
:
PoolingServeContext
,
ctx
:
PoolingServeContext
,
...
@@ -355,3 +340,39 @@ class PoolingServing:
...
@@ -355,3 +340,39 @@ class PoolingServing:
params
=
params
,
params
=
params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
class PoolingServing(PoolingServingBase, ABC):
    """Pooling serving handler that owns exactly one IO processor.

    Subclasses supply the processor through ``init_io_processor``; the
    request pipeline itself is implemented once in ``__call__``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialise the base class, then build the IO processor.

        All arguments are forwarded untouched to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Must run after the parent constructor: the three values passed
        # here are read from attributes on ``self``.
        self.io_processor = self.init_io_processor(
            vllm_config=self.vllm_config,
            renderer=self.renderer,
            chat_template_config=self.chat_template_config,
        )

    @abstractmethod
    def init_io_processor(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
        """Create the IO processor used by this handler.

        Called once from ``__init__`` with keyword arguments only.
        """
        raise NotImplementedError

    async def __call__(
        self,
        request: AnyPoolingRequest,
        raw_request: Request | None = None,
    ) -> Response:
        """Serve one pooling request from start to finish.

        Steps, in order: build the serve context, run the IO processor's
        online pre-processing, fill in pooling params if pre-processing did
        not set them, prepare the generators, collect the batch, run the
        IO processor's online post-processing, and build the response.
        """
        ctx = await self._init_ctx(request, raw_request)

        await self.io_processor.pre_process_online_async(ctx)
        # Pre-processing may already have attached pooling params to the
        # context; only derive them from the request when it has not.
        if ctx.pooling_params is None:
            ctx.pooling_params = self.io_processor.create_pooling_params(request)
        await self._prepare_generators(ctx)
        await self._collect_batch(ctx)
        await self.io_processor.post_process_online_async(ctx)
        return await self._build_response(ctx)
vllm/entrypoints/pooling/classify/io_processor.py
View file @
66c079ae
...
@@ -5,4 +5,8 @@ from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
...
@@ -5,4 +5,8 @@ from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
class
ClassifyIOProcessor
(
PoolingIOProcessor
):
class
ClassifyIOProcessor
(
PoolingIOProcessor
):
name
=
"classification"
name
=
"classify"
class TokenClassifyIOProcessor(PoolingIOProcessor):
    """IO processor for the ``token_classify`` task.

    Only the name is overridden; everything else is inherited from
    ``PoolingIOProcessor``.
    """

    name = "token_classify"
vllm/entrypoints/pooling/classify/serving.py
View file @
66c079ae
...
@@ -6,14 +6,11 @@ from typing import TypeAlias
...
@@ -6,14 +6,11 @@ from typing import TypeAlias
import
numpy
as
np
import
numpy
as
np
from
fastapi.responses
import
JSONResponse
from
fastapi.responses
import
JSONResponse
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
ChatTemplateConfig
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.entrypoints.pooling.typing
import
PoolingServeContext
from
vllm.entrypoints.pooling.typing
import
PoolingServeContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
ClassificationOutput
from
vllm.outputs
import
ClassificationOutput
from
vllm.renderers
import
BaseRenderer
from
.io_processor
import
ClassifyIOProcessor
from
.io_processor
import
ClassifyIOProcessor
from
.protocol
import
(
from
.protocol
import
(
...
@@ -31,17 +28,8 @@ ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationReques
...
@@ -31,17 +28,8 @@ ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationReques
class
ServingClassification
(
PoolingServing
):
class
ServingClassification
(
PoolingServing
):
request_id_prefix
=
"classify"
request_id_prefix
=
"classify"
def
init_io_processor
(
def
init_io_processor
(
self
,
*
args
,
**
kwargs
)
->
ClassifyIOProcessor
:
self
,
return
ClassifyIOProcessor
(
*
args
,
**
kwargs
)
model_config
:
ModelConfig
,
renderer
:
BaseRenderer
,
chat_template_config
:
ChatTemplateConfig
,
)
->
ClassifyIOProcessor
:
return
ClassifyIOProcessor
(
model_config
=
model_config
,
renderer
=
renderer
,
chat_template_config
=
chat_template_config
,
)
async
def
_build_response
(
async
def
_build_response
(
self
,
self
,
...
...
vllm/entrypoints/pooling/embed/io_processor.py
View file @
66c079ae
...
@@ -37,7 +37,7 @@ logger = init_logger(__name__)
...
@@ -37,7 +37,7 @@ logger = init_logger(__name__)
class
EmbedIOProcessor
(
PoolingIOProcessor
):
class
EmbedIOProcessor
(
PoolingIOProcessor
):
name
=
"embed
ding
"
name
=
"embed"
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
...
@@ -549,3 +549,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
...
@@ -549,3 +549,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
request
=
ctx
.
request
request
=
ctx
.
request
if
request
.
truncate
==
"NONE"
and
request
.
max_tokens
is
not
None
:
if
request
.
truncate
==
"NONE"
and
request
.
max_tokens
is
not
None
:
self
.
_check_cohere_max_tokens
(
ctx
.
final_res_batch
,
request
.
max_tokens
)
self
.
_check_cohere_max_tokens
(
ctx
.
final_res_batch
,
request
.
max_tokens
)
class TokenEmbedIOProcessor(PoolingIOProcessor):
    """IO processor for the ``token_embed`` task.

    Only the name is overridden; everything else is inherited from
    ``PoolingIOProcessor``.
    """

    name = "token_embed"
vllm/entrypoints/pooling/embed/serving.py
View file @
66c079ae
...
@@ -8,8 +8,6 @@ from typing import Literal, TypeAlias, cast
...
@@ -8,8 +8,6 @@ from typing import Literal, TypeAlias, cast
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
typing_extensions
import
assert_never
from
typing_extensions
import
assert_never
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
ChatTemplateConfig
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.entrypoints.pooling.embed.io_processor
import
EmbedIOProcessor
from
vllm.entrypoints.pooling.embed.io_processor
import
EmbedIOProcessor
...
@@ -33,12 +31,10 @@ from vllm.entrypoints.pooling.utils import (
...
@@ -33,12 +31,10 @@ from vllm.entrypoints.pooling.utils import (
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.renderers
import
BaseRenderer
from
vllm.utils.serial_utils
import
EmbedDType
,
Endianness
from
vllm.utils.serial_utils
import
EmbedDType
,
Endianness
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
JSONResponseCLS
=
get_json_response_cls
()
EmbeddingServeContext
:
TypeAlias
=
PoolingServeContext
[
EmbeddingRequest
]
EmbeddingServeContext
:
TypeAlias
=
PoolingServeContext
[
EmbeddingRequest
]
...
@@ -49,17 +45,13 @@ class ServingEmbedding(PoolingServing):
...
@@ -49,17 +45,13 @@ class ServingEmbedding(PoolingServing):
request_id_prefix
=
"embd"
request_id_prefix
=
"embd"
io_processor
:
EmbedIOProcessor
io_processor
:
EmbedIOProcessor
def
init_io_processor
(
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
,
super
().
__init__
(
*
args
,
**
kwargs
)
model_config
:
ModelConfig
,
renderer
:
BaseRenderer
,
self
.
json_response_cls
=
get_json_response_cls
()
chat_template_config
:
ChatTemplateConfig
,
)
->
EmbedIOProcessor
:
def
init_io_processor
(
self
,
*
args
,
**
kwargs
)
->
EmbedIOProcessor
:
return
EmbedIOProcessor
(
return
EmbedIOProcessor
(
*
args
,
**
kwargs
)
model_config
=
model_config
,
renderer
=
renderer
,
chat_template_config
=
chat_template_config
,
)
async
def
_build_response
(
async
def
_build_response
(
self
,
self
,
...
@@ -149,7 +141,7 @@ class ServingEmbedding(PoolingServing):
...
@@ -149,7 +141,7 @@ class ServingEmbedding(PoolingServing):
data
=
items
,
data
=
items
,
usage
=
usage
,
usage
=
usage
,
)
)
return
JSONR
esponse
CLS
(
content
=
response
.
model_dump
())
return
self
.
json_r
esponse
_cls
(
content
=
response
.
model_dump
())
def
_openai_bytes_response
(
def
_openai_bytes_response
(
self
,
self
,
...
@@ -190,8 +182,8 @@ class ServingEmbedding(PoolingServing):
...
@@ -190,8 +182,8 @@ class ServingEmbedding(PoolingServing):
media_type
=
response
.
media_type
,
media_type
=
response
.
media_type
,
)
)
@
staticmethod
def
_build_cohere_response_from_ctx
(
def
_build_cohere_response_from_ctx
(
self
,
ctx
:
PoolingServeContext
,
ctx
:
PoolingServeContext
,
)
->
JSONResponse
:
)
->
JSONResponse
:
request
=
ctx
.
request
request
=
ctx
.
request
...
@@ -218,4 +210,4 @@ class ServingEmbedding(PoolingServing):
...
@@ -218,4 +210,4 @@ class ServingEmbedding(PoolingServing):
),
),
),
),
)
)
return
JSONR
esponse
(
content
=
response
.
model_dump
(
exclude_none
=
True
))
return
self
.
json_r
esponse
_cls
(
content
=
response
.
model_dump
(
exclude_none
=
True
))
vllm/entrypoints/pooling/io_processor_factories.py
View file @
66c079ae
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
ChatTemplateConfig
from
vllm.entrypoints.chat_utils
import
ChatTemplateConfig
from
vllm.entrypoints.pooling.base.io_processor
import
PoolingIOProcessor
from
vllm.plugins.io_processors
import
has_io_processor
from
vllm.entrypoints.pooling.scoring.io_processor
import
ScoringIOProcessors
from
vllm.entrypoints.pooling.utils
import
enable_scoring_api
from
vllm.renderers
import
BaseRenderer
from
vllm.renderers
import
BaseRenderer
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
.base.io_processor
import
PoolingIOProcessor
from
.utils
import
enable_scoring_api
def
init_pooling_io_processors
(
def
init_pooling_io_processors
(
supported_tasks
:
tuple
[
SupportedTask
,
...],
supported_tasks
:
tuple
[
SupportedTask
,
...],
model
_config
:
Model
Config
,
vllm
_config
:
Vllm
Config
,
renderer
:
BaseRenderer
,
renderer
:
BaseRenderer
,
chat_template_config
:
ChatTemplateConfig
,
chat_template_config
:
ChatTemplateConfig
,
)
->
dict
[
str
,
PoolingIOProcessor
]:
)
->
dict
[
str
,
PoolingIOProcessor
]:
processors
:
list
[
tuple
[
str
,
type
[
PoolingIOProcessor
]]]
=
[]
model_config
=
vllm_config
.
model_config
processors
:
dict
[
str
,
type
[
PoolingIOProcessor
]]
=
{}
if
"classify"
in
supported_tasks
:
if
"classify"
in
supported_tasks
:
from
vllm.entrypoints.pooling.classify.io_processor
import
ClassifyIOProcessor
from
.classify.io_processor
import
ClassifyIOProcessor
processors
[
"classify"
]
=
ClassifyIOProcessor
if
"token_classify"
in
supported_tasks
:
from
.classify.io_processor
import
TokenClassifyIOProcessor
processors
[
"token_classify"
]
=
TokenClassifyIOProcessor
processors
.
append
((
"classify"
,
ClassifyIOProcessor
))
if
"embed"
in
supported_tasks
:
if
"embed"
in
supported_tasks
:
from
vllm.entrypoints.pooling
.embed.io_processor
import
EmbedIOProcessor
from
.embed.io_processor
import
EmbedIOProcessor
processors
.
append
((
"embed"
,
EmbedIOProcessor
))
processors
[
"embed"
]
=
EmbedIOProcessor
if
"token_embed"
in
supported_tasks
:
from
.embed.io_processor
import
TokenEmbedIOProcessor
processors
[
"token_embed"
]
=
TokenEmbedIOProcessor
if
has_io_processor
(
vllm_config
,
model_config
.
io_processor_plugin
,
):
from
.pooling.io_processor
import
PluginWithIOProcessorPlugins
processors
[
"plugin"
]
=
PluginWithIOProcessorPlugins
elif
"plugin"
in
supported_tasks
:
from
.pooling.io_processor
import
PluginWithoutIOProcessorPlugins
processors
[
"plugin"
]
=
PluginWithoutIOProcessorPlugins
if
enable_scoring_api
(
supported_tasks
,
model_config
):
if
enable_scoring_api
(
supported_tasks
,
model_config
):
score_type
=
model_config
.
score_type
score_type
=
model_config
.
score_type
from
.scoring.io_processor
import
ScoringIOProcessors
if
score_type
is
not
None
and
score_type
in
ScoringIOProcessors
:
if
score_type
is
not
None
and
score_type
in
ScoringIOProcessors
:
processors
.
append
((
score_type
,
ScoringIOProcessors
[
score_type
]
))
processors
[
score_type
]
=
ScoringIOProcessors
[
score_type
]
return
{
return
{
task
:
processor_cls
(
task
:
processor_cls
(
model
_config
=
model
_config
,
vllm
_config
=
vllm
_config
,
renderer
=
renderer
,
renderer
=
renderer
,
chat_template_config
=
chat_template_config
,
chat_template_config
=
chat_template_config
,
)
)
for
task
,
processor_cls
in
processors
for
task
,
processor_cls
in
processors
.
items
()
}
}
vllm/entrypoints/pooling/pooling/api_router.py
View file @
66c079ae
...
@@ -3,24 +3,17 @@
...
@@ -3,24 +3,17 @@
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
fastapi
import
APIRouter
,
Depends
,
Request
from
fastapi
import
APIRouter
,
Depends
,
Request
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
typing_extensions
import
assert_never
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.pooling.pooling.protocol
import
(
from
vllm.entrypoints.pooling.pooling.protocol
import
PoolingRequest
IOProcessorResponse
,
from
vllm.entrypoints.pooling.pooling.serving
import
ServingPooling
PoolingBytesResponse
,
PoolingRequest
,
PoolingResponse
,
)
from
vllm.entrypoints.pooling.pooling.serving
import
OpenAIServingPooling
from
vllm.entrypoints.utils
import
load_aware_call
,
with_cancellation
from
vllm.entrypoints.utils
import
load_aware_call
,
with_cancellation
router
=
APIRouter
()
router
=
APIRouter
()
def
pooling
(
request
:
Request
)
->
OpenAI
ServingPooling
|
None
:
def
pooling
(
request
:
Request
)
->
ServingPooling
|
None
:
return
request
.
app
.
state
.
serving_pooling
return
request
.
app
.
state
.
serving_pooling
...
@@ -39,19 +32,4 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
...
@@ -39,19 +32,4 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
if
handler
is
None
:
if
handler
is
None
:
raise
NotImplementedError
(
"The model does not support Pooling API"
)
raise
NotImplementedError
(
"The model does not support Pooling API"
)
generator
=
await
handler
.
create_pooling
(
request
,
raw_request
)
return
await
handler
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
)
elif
isinstance
(
generator
,
(
PoolingResponse
,
IOProcessorResponse
)):
return
JSONResponse
(
content
=
generator
.
model_dump
())
elif
isinstance
(
generator
,
PoolingBytesResponse
):
return
StreamingResponse
(
content
=
generator
.
content
,
headers
=
generator
.
headers
,
media_type
=
generator
.
media_type
,
)
assert_never
(
generator
)
vllm/entrypoints/pooling/pooling/io_processor.py
0 → 100644
View file @
66c079ae
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
typing
import
Any
from
vllm
import
PoolingParams
,
PoolingRequestOutput
from
vllm.entrypoints.pooling.base.io_processor
import
PoolingIOProcessor
from
vllm.inputs
import
EngineInput
from
vllm.logger
import
init_logger
from
vllm.plugins.io_processors
import
get_io_processor
from
vllm.renderers.inputs.preprocess
import
parse_model_prompt
,
prompt_to_seq
from
..typing
import
OfflineInputsContext
,
OfflineOutputsContext
,
PoolingServeContext
from
.protocol
import
IOProcessorRequest
,
IOProcessorResponse
logger
=
init_logger
(
__name__
)
class PluginWithoutIOProcessorPlugins(PoolingIOProcessor):
    """IO processor for the ``plugin`` task without an IO processor plugin.

    Only the name is overridden; everything else is inherited from
    ``PoolingIOProcessor``.
    """

    name = "plugin"
class PluginWithIOProcessorPlugins(PoolingIOProcessor):
    """IO processor for the ``plugin`` task backed by an IO processor plugin.

    IO Processor plugins are a feature that allows pre- and post-processing
    of the model input and output for pooling models. This class forwards
    the online and offline hooks to the loaded plugin.
    """

    name = "plugin"

    def __init__(self, *args, **kwargs) -> None:
        """Initialise the base processor and load the plugin.

        All arguments are forwarded to the parent constructor. The plugin is
        looked up from ``model_config.io_processor_plugin`` and must resolve
        to a non-None object.
        """
        super().__init__(*args, **kwargs)

        io_processor = get_io_processor(
            self.vllm_config,
            self.renderer,
            self.model_config.io_processor_plugin,
        )
        # NOTE(review): this is an assert, so the check disappears under
        # ``python -O``; callers are presumably expected to have verified
        # that a plugin is available before instantiating this class.
        assert io_processor is not None
        self.io_processor = io_processor

    #######################################
    # online APIs

    def pre_process_online(self, ctx: PoolingServeContext) -> None:
        """Turn the request's plugin data into engine inputs.

        Writes ``ctx.engine_inputs`` and ``ctx.pooling_params``.
        """
        assert isinstance(ctx.request, IOProcessorRequest)

        # Let the plugin validate the payload, then convert it into the
        # prompt(s) that should be sent to the model.
        validated_prompt = self.io_processor.parse_data(ctx.request.data)
        raw_prompts = self.io_processor.pre_process(
            prompt=validated_prompt, request_id=ctx.request_id
        )

        # ``bytes`` prompts are passed through unchanged; anything else is
        # parsed against the model config.
        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(self.model_config, prompt)
            )
            for prompt in prompt_to_seq(raw_prompts)
        ]
        tok_params = ctx.request.build_tok_params(self.model_config)

        # Only forward the optional extras that are present on the request
        # and not None.
        ctx.engine_inputs = self.renderer.render_cmpl(
            parsed_prompts,
            tok_params,
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(ctx.request, k, None)) is not None
            },
        )

        # The plugin supplies the pooling params; fall back to the "plugin"
        # task when it leaves the task unset.
        pooling_params = self.io_processor.merge_pooling_params()
        if pooling_params.task is None:
            pooling_params.task = "plugin"
        ctx.pooling_params = pooling_params

    def post_process_online(
        self,
        ctx: PoolingServeContext,
    ) -> None:
        """Run the plugin's post-processing and store the result.

        Writes ``ctx.response``.
        """
        output = self.io_processor.post_process(
            ctx.final_res_batch,
            request_id=ctx.request_id,
        )

        # Deprecated path: a plugin that still defines a callable
        # ``output_to_response`` builds the response itself.
        if callable(
            output_to_response := getattr(
                self.io_processor, "output_to_response", None
            )
        ):
            logger.warning_once(
                "`IOProcessor.output_to_response` is deprecated. To ensure "
                "consistency between offline and online APIs, "
                "`IOProcessorResponse` will become a transparent wrapper "
                "around output data from v0.19 onwards.",
            )

            # Fill in the request id only when the output has the attribute
            # and the plugin left it empty.
            if hasattr(output, "request_id") and output.request_id is None:
                output.request_id = ctx.request_id  # type: ignore

            ctx.response = output_to_response(output)  # type: ignore
        else:
            ctx.response = IOProcessorResponse(request_id=ctx.request_id, data=output)

    #######################################
    # offline APIs

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        """Convert an offline plugin prompt into engine inputs.

        ``ctx.prompts`` must be a dict with a ``"data"`` key and
        ``ctx.pooling_params`` must be set. Both ``ctx.prompts`` and
        ``ctx.pooling_params`` are overwritten before delegating to the
        parent implementation.

        Raises:
            ValueError: If the ``"data"`` entry is None.
        """
        assert isinstance(ctx.prompts, dict) and "data" in ctx.prompts
        assert ctx.pooling_params is not None

        # Validate the request data is valid for the loaded plugin
        prompt_data = ctx.prompts.get("data")
        if prompt_data is None:
            raise ValueError(
                "The 'data' field of the prompt is expected to contain "
                "the prompt data and it cannot be None. "
                "Refer to the documentation of the IOProcessor "
                "in use for more details."
            )
        validated_prompt = self.io_processor.parse_data(prompt_data)

        # obtain the actual model prompts from the pre-processor
        prompts = self.io_processor.pre_process(prompt=validated_prompt)
        prompts_seq = prompt_to_seq(prompts)

        # One params entry per generated prompt, each passed through the
        # plugin's merge hook.
        params_seq: list[PoolingParams] = [
            self.io_processor.merge_pooling_params(param)
            for param in self._params_to_seq(
                ctx.pooling_params,
                len(prompts_seq),
            )
        ]
        # Default the task to "plugin" wherever it was left unset.
        for p in params_seq:
            if p.task is None:
                p.task = "plugin"

        # Hand the plugin-generated prompts and params to the parent
        # implementation through the context.
        ctx.pooling_params = params_seq
        ctx.prompts = prompts_seq
        return super().pre_process_offline(ctx)

    def post_process_offline(
        self,
        ctx: OfflineOutputsContext,
    ) -> list[PoolingRequestOutput]:
        """Post-process all offline outputs into a single result.

        Returns:
            A one-element list. The element wraps the plugin's output with
            an empty request id, no prompt token ids and ``finished=True``.
        """
        processed_outputs = self.io_processor.post_process(ctx.outputs)

        return [
            PoolingRequestOutput[Any](
                request_id="",
                outputs=processed_outputs,
                # Falls back to 0 when the plugin output has no such attribute.
                num_cached_tokens=getattr(processed_outputs, "num_cached_tokens", 0),
                prompt_token_ids=[],
                finished=True,
            )
        ]
vllm/entrypoints/pooling/pooling/serving.py
View file @
66c079ae
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
json
import
json
import
time
from
collections.abc
import
Callable
from
collections.abc
import
AsyncGenerator
,
Callable
,
Sequence
from
functools
import
partial
from
functools
import
partial
from
typing
import
Final
,
Literal
,
cast
from
typing
import
Literal
,
cast
from
fastapi
import
Request
from
fastapi
import
Request
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
typing_extensions
import
assert_never
from
typing_extensions
import
assert_never
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.pooling.base.serving
import
PoolingServingBase
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.pooling.io_processor_factories
import
init_pooling_io_processors
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
,
UsageInfo
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.pooling.protocol
import
(
from
vllm.entrypoints.pooling.pooling.protocol
import
(
IOProcessorRequest
,
IOProcessorRequest
,
IOProcessorResponse
,
PoolingBytesResponse
,
PoolingBytesResponse
,
PoolingChatRequest
,
PoolingCompletionRequest
,
PoolingRequest
,
PoolingRequest
,
PoolingResponse
,
PoolingResponse
,
PoolingResponseData
,
PoolingResponseData
,
)
)
from
vllm.entrypoints.pooling.typing
import
AnyPoolingRequest
,
PoolingServeContext
from
vllm.entrypoints.pooling.utils
import
(
from
vllm.entrypoints.pooling.utils
import
(
encode_pooling_bytes
,
encode_pooling_bytes
,
encode_pooling_output_base64
,
encode_pooling_output_base64
,
encode_pooling_output_float
,
encode_pooling_output_float
,
get_json_response_cls
,
)
)
from
vllm.entrypoints.serve.render.serving
import
OpenAIServingRender
from
vllm.inputs
import
EngineInput
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.renderers.inputs.preprocess
import
prompt_to_seq
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.serial_utils
import
EmbedDType
,
Endianness
from
vllm.utils.serial_utils
import
EmbedDType
,
EncodingFormat
,
Endianness
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
OpenAIServingPooling
(
OpenAIServing
):
class
ServingPooling
(
PoolingServingBase
):
request_id_prefix
=
"pooling"
def
__init__
(
def
__init__
(
self
,
self
,
engine_client
:
EngineClient
,
*
args
,
models
:
OpenAIServingModels
,
openai_serving_render
:
OpenAIServingRender
,
supported_tasks
:
tuple
[
SupportedTask
,
...],
supported_tasks
:
tuple
[
SupportedTask
,
...],
*
,
**
kwargs
,
request_logger
:
RequestLogger
|
None
,
):
chat_template
:
str
|
None
,
super
().
__init__
(
*
args
,
**
kwargs
)
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
)
self
.
supported_tasks
=
supported_tasks
self
.
supported_tasks
=
supported_tasks
self
.
pooling_task
=
self
.
model_config
.
get_pooling_task
(
supported_tasks
)
self
.
pooling_task
=
self
.
model_config
.
get_pooling_task
(
supported_tasks
)
self
.
openai_serving_render
=
openai_serving_render
self
.
io_processors
=
init_pooling_io_processors
(
self
.
chat_template
=
chat_template
supported_tasks
=
supported_tasks
,
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
vllm_config
=
self
.
vllm_config
,
self
.
trust_request_chat_template
=
trust_request_chat_template
renderer
=
self
.
renderer
,
chat_template_config
=
self
.
chat_template_config
,
)
self
.
json_response_cls
=
get_json_response_cls
()
async
def
create_pooling
(
async
def
__call__
(
self
,
self
,
request
:
PoolingRequest
,
request
:
Any
PoolingRequest
,
raw_request
:
Request
|
None
=
None
,
raw_request
:
Request
|
None
=
None
,
)
->
PoolingResponse
|
IOProcessorResponse
|
PoolingBytesResponse
|
ErrorResponse
:
)
->
Response
:
"""
assert
isinstance
(
request
,
PoolingRequest
)
See https://platform.openai.com/docs/api-reference/embeddings/create
pooling_task
=
self
.
_verify_pooling_task
(
request
)
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
model_name
=
self
.
models
.
model_name
()
io_processor
=
self
.
io_processors
[
pooling_task
]
ctx
=
await
self
.
_init_ctx
(
request
,
raw_request
)
request_id
=
f
"pool-
{
self
.
_base_request_id
(
raw_request
)
}
"
await
io_processor
.
pre_process_online_async
(
ctx
)
created_time
=
int
(
time
.
time
())
lora_request
=
self
.
_maybe_get_adapters
(
request
)
if
ctx
.
pooling_params
is
None
:
ctx
.
pooling_params
=
io_processor
.
create_pooling_params
(
request
)
await
self
.
_prepare_generators
(
ctx
)
await
self
.
_collect_batch
(
ctx
)
await
io_processor
.
post_process_online_async
(
ctx
)
return
await
self
.
_build_response
(
ctx
)
def
_verify_pooling_task
(
self
,
request
:
PoolingRequest
)
->
str
:
if
getattr
(
request
,
"dimensions"
,
None
)
is
not
None
:
raise
ValueError
(
"dimensions is currently not supported"
)
if
request
.
task
is
None
:
if
request
.
task
is
None
:
request
.
task
=
self
.
pooling_task
request
.
task
=
self
.
pooling_task
if
getattr
(
request
,
"dimensions"
,
None
)
is
not
None
:
if
isinstance
(
request
,
IOProcessorRequest
):
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
request
.
task
=
"plugin"
assert
request
.
task
is
not
None
pooling_task
=
request
.
task
# plugin task uses io_processor.parse_request to verify inputs
# plugin task uses io_processor.parse_request to verify inputs
if
request
.
task
!=
"plugin"
and
request
.
task
!=
self
.
pooling_task
:
if
pooling_
task
!=
"plugin"
and
pooling_
task
!=
self
.
pooling_task
:
if
request
.
task
not
in
self
.
supported_task
s
:
if
pooling_
task
not
in
self
.
io_processor
s
:
raise
ValueError
(
raise
ValueError
(
f
"Unsupported task:
{
request
.
task
!
r
}
"
f
"Unsupported task:
{
pooling_
task
!
r
}
"
f
"Supported tasks:
{
self
.
supported_tasks
}
"
f
"Supported tasks:
{
self
.
supported_tasks
}
"
)
)
else
:
else
:
logger
.
warning_once
(
logger
.
warning_once
(
"Pooling multitask support is deprecated and will be removed "
"Pooling multitask support is deprecated and will be removed "
"in v0.20. When the default pooling task is not what you want, you "
"in v0.20. When the default pooling task is not what you want, you "
'need to manually specify it via --pooler-config.task "%s". '
,
"need to manually specify it via --pooler-config.task %s. "
,
request
.
task
,
pooling_task
,
)
engine_inputs
:
Sequence
[
EngineInput
]
if
use_io_processor
:
=
isinstance
(
request
,
IOProcessorRequest
):
if
self
.
io_processor
is
None
:
raise
ValueError
(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
)
validated_prompt
=
self
.
io_processor
.
parse_data
(
request
.
data
)
if
pooling_task
==
"plugin"
and
"plugin"
not
in
self
.
io_processors
:
raise
ValueError
(
raw_prompts
=
await
self
.
io_processor
.
pre_process_async
(
"No IOProcessor plugin installed. Please refer "
prompt
=
validated_prompt
,
request_id
=
request_id
"to the documentation and to the "
)
"'prithvi_geospatial_mae_io_processor' "
engine_inputs
=
await
self
.
openai_serving_render
.
preprocess_cmpl
(
"offline inference example for more details."
request
,
prompt_to_seq
(
raw_prompts
),
)
elif
isinstance
(
request
,
PoolingChatRequest
):
error_check_ret
=
self
.
openai_serving_render
.
validate_chat_template
(
request_chat_template
=
request
.
chat_template
,
chat_template_kwargs
=
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
engine_inputs
=
await
self
.
openai_serving_render
.
preprocess_chat
(
request
,
request
.
messages
,
default_template
=
self
.
chat_template
,
default_template_content_format
=
self
.
chat_template_content_format
,
default_template_kwargs
=
None
,
)
elif
isinstance
(
request
,
PoolingCompletionRequest
):
engine_inputs
=
await
self
.
openai_serving_render
.
preprocess_completion
(
request
,
prompt_input
=
request
.
input
,
prompt_embeds
=
None
,
)
else
:
raise
ValueError
(
f
"Unsupported request of type
{
type
(
request
)
}
"
)
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
if
use_io_processor
:
assert
self
.
io_processor
is
not
None
pooling_params
=
self
.
io_processor
.
merge_pooling_params
()
if
pooling_params
.
task
is
None
:
pooling_params
.
task
=
"plugin"
else
:
pooling_params
=
request
.
to_pooling_params
()
# type: ignore
for
i
,
engine_input
in
enumerate
(
engine_inputs
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
engine_input
,
params
=
pooling_params
,
lora_request
=
lora_request
,
)
)
trace_headers
=
(
return
pooling_task
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
generator
=
self
.
engine_client
.
encode
(
engine_input
,
pooling_params
,
request_id_item
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
generators
.
append
(
generator
)
async
def
_build_response
(
self
,
ctx
:
PoolingServeContext
,
)
->
Response
:
if
ctx
.
response
is
not
None
:
# for IOProcessorResponse
return
self
.
json_response_cls
(
content
=
ctx
.
response
.
model_dump
())
result_generator
=
merge_async_iterators
(
*
generators
)
encoding_format
=
ctx
.
request
.
encoding_format
embed_dtype
=
ctx
.
request
.
embed_dtype
endianness
=
ctx
.
request
.
endianness
if
use_io_processor
:
if
encoding_format
==
"float"
or
encoding_format
==
"base64"
:
assert
self
.
io_processor
is
not
None
return
self
.
request_output_to_pooling_json_response
(
output
=
await
self
.
io_processor
.
post_process_async
(
ctx
.
final_res_batch
,
result_generator
,
ctx
.
request_id
,
request_id
=
request_id
,
ctx
.
created_time
,
ctx
.
model_name
,
encoding_format
,
embed_dtype
,
endianness
,
)
)
if
callable
(
if
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
output_to_response
:
=
getattr
(
return
self
.
request_output_to_pooling_bytes_response
(
self
.
io_processor
,
"output_to_response"
,
None
ctx
.
final_res_batch
,
)
ctx
.
request_id
,
):
ctx
.
created_time
,
logger
.
warning_once
(
ctx
.
model_name
,
"`IOProcessor.output_to_response` is deprecated. To ensure "
encoding_format
,
"consistency between offline and online APIs, "
embed_dtype
,
"`IOProcessorResponse` will become a transparent wrapper "
endianness
,
"around output data from v0.19 onwards."
,
)
if
hasattr
(
output
,
"request_id"
)
and
output
.
request_id
is
None
:
output
.
request_id
=
request_id
# type: ignore
return
output_to_response
(
output
)
# type: ignore
return
IOProcessorResponse
(
request_id
=
request_id
,
data
=
output
)
assert
isinstance
(
request
,
(
PoolingCompletionRequest
,
PoolingChatRequest
))
num_prompts
=
len
(
engine_inputs
)
# Non-streaming response
final_res_batch
:
list
[
PoolingRequestOutput
|
None
]
final_res_batch
=
[
None
]
*
num_prompts
try
:
async
for
i
,
res
in
result_generator
:
final_res_batch
[
i
]
=
res
assert
all
(
final_res
is
not
None
for
final_res
in
final_res_batch
)
final_res_batch_checked
=
cast
(
list
[
PoolingRequestOutput
],
final_res_batch
)
response
=
self
.
request_output_to_pooling_response
(
final_res_batch_checked
,
request_id
,
created_time
,
model_name
,
request
.
encoding_format
,
request
.
embed_dtype
,
request
.
endianness
,
)
)
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
return
response
assert_never
(
encoding_format
)
def
request_output_to_pooling_json_response
(
def
request_output_to_pooling_json_response
(
self
,
self
,
...
@@ -257,7 +162,7 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -257,7 +162,7 @@ class OpenAIServingPooling(OpenAIServing):
encoding_format
:
Literal
[
"float"
,
"base64"
],
encoding_format
:
Literal
[
"float"
,
"base64"
],
embed_dtype
:
EmbedDType
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
endianness
:
Endianness
,
)
->
Pooling
Response
:
)
->
JSON
Response
:
encode_fn
=
cast
(
encode_fn
=
cast
(
Callable
[[
PoolingRequestOutput
],
list
[
float
]
|
str
],
Callable
[[
PoolingRequestOutput
],
list
[
float
]
|
str
],
(
(
...
@@ -289,13 +194,14 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -289,13 +194,14 @@ class OpenAIServingPooling(OpenAIServing):
total_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
)
re
turn
PoolingResponse
(
re
sponse
=
PoolingResponse
(
id
=
request_id
,
id
=
request_id
,
created
=
created_time
,
created
=
created_time
,
model
=
model_name
,
model
=
model_name
,
data
=
items
,
data
=
items
,
usage
=
usage
,
usage
=
usage
,
)
)
return
self
.
json_response_cls
(
content
=
response
.
model_dump
())
def
request_output_to_pooling_bytes_response
(
def
request_output_to_pooling_bytes_response
(
self
,
self
,
...
@@ -306,7 +212,7 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -306,7 +212,7 @@ class OpenAIServingPooling(OpenAIServing):
encoding_format
:
Literal
[
"bytes"
,
"bytes_only"
],
encoding_format
:
Literal
[
"bytes"
,
"bytes_only"
],
embed_dtype
:
EmbedDType
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
endianness
:
Endianness
,
)
->
PoolingBytes
Response
:
)
->
Streaming
Response
:
content
,
items
,
usage
=
encode_pooling_bytes
(
content
,
items
,
usage
=
encode_pooling_bytes
(
pooling_outputs
=
final_res_batch
,
pooling_outputs
=
final_res_batch
,
embed_dtype
=
embed_dtype
,
embed_dtype
=
embed_dtype
,
...
@@ -329,38 +235,10 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -329,38 +235,10 @@ class OpenAIServingPooling(OpenAIServing):
}
}
)
)
return
PoolingBytesResponse
(
content
=
content
,
headers
=
headers
)
response
=
PoolingBytesResponse
(
content
=
content
,
headers
=
headers
)
def
request_output_to_pooling_response
(
self
,
final_res_batch
:
list
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
EncodingFormat
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
PoolingResponse
|
PoolingBytesResponse
:
if
encoding_format
==
"float"
or
encoding_format
==
"base64"
:
return
self
.
request_output_to_pooling_json_response
(
final_res_batch
,
request_id
,
created_time
,
model_name
,
encoding_format
,
embed_dtype
,
endianness
,
)
if
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
return
self
.
request_output_to_pooling_bytes_response
(
final_res_batch
,
request_id
,
created_time
,
model_name
,
encoding_format
,
embed_dtype
,
endianness
,
)
assert_never
(
encoding_format
)
return
StreamingResponse
(
content
=
response
.
content
,
headers
=
response
.
headers
,
media_type
=
response
.
media_type
,
)
vllm/entrypoints/pooling/scoring/io_processor.py
View file @
66c079ae
...
@@ -278,7 +278,7 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
...
@@ -278,7 +278,7 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
def
pre_process_offline
(
self
,
ctx
:
OfflineInputsContext
)
->
Sequence
[
EngineInput
]:
def
pre_process_offline
(
self
,
ctx
:
OfflineInputsContext
)
->
Sequence
[
EngineInput
]:
assert
isinstance
(
ctx
.
prompts
,
ScoringData
)
assert
isinstance
(
ctx
.
prompts
,
ScoringData
)
assert
not
isinstance
(
ctx
.
pooling_params
,
list
)
assert
not
isinstance
(
ctx
.
pooling_params
,
Sequence
)
tok_params
=
self
.
renderer
.
default_cmpl_tok_params
.
with_kwargs
(
tok_params
=
self
.
renderer
.
default_cmpl_tok_params
.
with_kwargs
(
**
(
ctx
.
tokenization_kwargs
or
{})
**
(
ctx
.
tokenization_kwargs
or
{})
...
...
vllm/entrypoints/pooling/scoring/serving.py
View file @
66c079ae
...
@@ -4,15 +4,13 @@
...
@@ -4,15 +4,13 @@
from
fastapi.responses
import
JSONResponse
,
Response
from
fastapi.responses
import
JSONResponse
,
Response
from
vllm
import
PoolingParams
from
vllm
import
PoolingParams
from
vllm.config
import
Model
Config
from
vllm.config
import
Vllm
Config
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateConfig
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.entrypoints.pooling.base.io_processor
import
PoolingIOProcessor
from
vllm.entrypoints.pooling.base.io_processor
import
PoolingIOProcessor
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.renderers
import
BaseRenderer
from
vllm.v1.pool.late_interaction
import
(
from
vllm.v1.pool.late_interaction
import
(
build_late_interaction_doc_params
,
build_late_interaction_doc_params
,
build_late_interaction_query_params
,
build_late_interaction_query_params
,
...
@@ -52,22 +50,17 @@ class ServingScores(PoolingServing):
...
@@ -52,22 +50,17 @@ class ServingScores(PoolingServing):
super
().
__init__
(
engine_client
,
*
args
,
**
kwargs
)
super
().
__init__
(
engine_client
,
*
args
,
**
kwargs
)
def
init_io_processor
(
def
init_io_processor
(
self
,
self
,
vllm_config
:
VllmConfig
,
*
args
,
**
kwargs
model_config
:
ModelConfig
,
renderer
:
BaseRenderer
,
chat_template_config
:
ChatTemplateConfig
,
)
->
PoolingIOProcessor
:
)
->
PoolingIOProcessor
:
model_config
=
vllm_config
.
model_config
score_type
:
str
=
model_config
.
score_type
score_type
:
str
=
model_config
.
score_type
if
self
.
enable_flash_late_interaction
:
if
self
.
enable_flash_late_interaction
:
score_type
=
"flash-late-interaction"
score_type
=
"flash-late-interaction"
assert
score_type
in
ScoringIOProcessors
assert
score_type
in
ScoringIOProcessors
processor_cls
=
ScoringIOProcessors
[
score_type
]
processor_cls
=
ScoringIOProcessors
[
score_type
]
return
processor_cls
(
return
processor_cls
(
vllm_config
,
*
args
,
**
kwargs
)
model_config
=
model_config
,
renderer
=
renderer
,
chat_template_config
=
chat_template_config
,
)
async
def
__call__
(
self
,
*
args
,
**
kwargs
)
->
Response
:
async
def
__call__
(
self
,
*
args
,
**
kwargs
)
->
Response
:
if
not
self
.
enable_flash_late_interaction
:
if
not
self
.
enable_flash_late_interaction
:
...
...
vllm/entrypoints/pooling/typing.py
View file @
66c079ae
...
@@ -30,7 +30,7 @@ from vllm.entrypoints.pooling.pooling.protocol import (
...
@@ -30,7 +30,7 @@ from vllm.entrypoints.pooling.pooling.protocol import (
)
)
from
vllm.entrypoints.pooling.scoring.protocol
import
ScoringRequest
,
ScoringResponse
from
vllm.entrypoints.pooling.scoring.protocol
import
ScoringRequest
,
ScoringResponse
from
vllm.entrypoints.pooling.scoring.typing
import
ScoringData
from
vllm.entrypoints.pooling.scoring.typing
import
ScoringData
from
vllm.inputs
import
EngineInput
from
vllm.inputs
import
DataPrompt
,
EngineInput
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
PoolingCompletionLikeRequest
:
TypeAlias
=
(
PoolingCompletionLikeRequest
:
TypeAlias
=
(
...
@@ -86,11 +86,14 @@ class PoolingServeContext(Generic[PoolingRequestT]):
...
@@ -86,11 +86,14 @@ class PoolingServeContext(Generic[PoolingRequestT]):
## for bi-encoder & late-interaction
## for bi-encoder & late-interaction
n_queries
:
int
|
None
=
None
n_queries
:
int
|
None
=
None
## for IOProcessorResponse
response
:
Any
|
None
=
None
@
dataclass
@
dataclass
class
OfflineInputsContext
:
class
OfflineInputsContext
:
prompts
:
PromptType
|
Sequence
[
PromptType
]
|
ScoringData
prompts
:
PromptType
|
Sequence
[
PromptType
]
|
DataPrompt
|
ScoringData
pooling_params
:
PoolingParams
|
list
[
PoolingParams
]
|
None
=
None
pooling_params
:
PoolingParams
|
Sequence
[
PoolingParams
]
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
chat_template
:
str
|
None
=
None
chat_template
:
str
|
None
=
None
...
...
vllm/entrypoints/sagemaker/api_router.py
View file @
66c079ae
...
@@ -14,7 +14,7 @@ from vllm.config import ModelConfig
...
@@ -14,7 +14,7 @@ from vllm.config import ModelConfig
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
from
vllm.entrypoints.pooling.base.serving
import
PoolingServing
Base
from
vllm.entrypoints.pooling.utils
import
enable_scoring_api
from
vllm.entrypoints.pooling.utils
import
enable_scoring_api
from
vllm.entrypoints.serve.instrumentator.basic
import
base
from
vllm.entrypoints.serve.instrumentator.basic
import
base
from
vllm.entrypoints.serve.instrumentator.health
import
health
from
vllm.entrypoints.serve.instrumentator.health
import
health
...
@@ -23,7 +23,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
...
@@ -23,7 +23,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
# (requires typing_extensions >= 4.13)
RequestType
=
Any
RequestType
=
Any
GetHandlerFn
=
Callable
[[
Request
],
OpenAIServing
|
PoolingServing
|
None
]
GetHandlerFn
=
Callable
[[
Request
],
OpenAIServing
|
PoolingServing
Base
|
None
]
EndpointFn
=
Callable
[[
RequestType
,
Request
],
Awaitable
[
Any
]]
EndpointFn
=
Callable
[[
RequestType
,
Request
],
Awaitable
[
Any
]]
...
...
vllm/entrypoints/serve/render/serving.py
View file @
66c079ae
...
@@ -65,7 +65,6 @@ class OpenAIServingRender:
...
@@ -65,7 +65,6 @@ class OpenAIServingRender:
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
renderer
:
BaseRenderer
,
renderer
:
BaseRenderer
,
io_processor
:
Any
,
model_registry
:
OpenAIModelRegistry
,
model_registry
:
OpenAIModelRegistry
,
*
,
*
,
request_logger
:
RequestLogger
|
None
,
request_logger
:
RequestLogger
|
None
,
...
@@ -81,7 +80,6 @@ class OpenAIServingRender:
...
@@ -81,7 +80,6 @@ class OpenAIServingRender:
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
renderer
=
renderer
self
.
renderer
=
renderer
self
.
io_processor
=
io_processor
self
.
model_registry
=
model_registry
self
.
model_registry
=
model_registry
self
.
request_logger
=
request_logger
self
.
request_logger
=
request_logger
self
.
chat_template
=
chat_template
self
.
chat_template
=
chat_template
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment