Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c831911b
Unverified
Commit c831911b, authored Jan 27, 2026 by Cyrus Leung
Committed by GitHub on Jan 27, 2026
Browse files
[Frontend] Reduce mixin usage in serving pooling (#33101)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent
157caf51
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing 3 changed files with 134 additions and 239 deletions
+134
-239
vllm/entrypoints/openai/engine/serving.py
vllm/entrypoints/openai/engine/serving.py
+13
-48
vllm/entrypoints/pooling/classify/serving.py
vllm/entrypoints/pooling/classify/serving.py
+57
-89
vllm/entrypoints/pooling/embed/serving.py
vllm/entrypoints/pooling/embed/serving.py
+64
-102
No files found.
vllm/entrypoints/openai/engine/serving.py
View file @
c831911b
...
...
@@ -64,13 +64,12 @@ from vllm.entrypoints.openai.translations.protocol import (
from
vllm.entrypoints.pooling.classify.protocol
import
(
ClassificationChatRequest
,
ClassificationCompletionRequest
,
ClassificationRequest
,
ClassificationResponse
,
)
from
vllm.entrypoints.pooling.embed.protocol
import
(
EmbeddingBytesResponse
,
EmbeddingChatRequest
,
EmbeddingCompletionRequest
,
EmbeddingRequest
,
EmbeddingResponse
,
)
from
vllm.entrypoints.pooling.pooling.protocol
import
(
...
...
@@ -170,6 +169,7 @@ AnyResponse: TypeAlias = (
CompletionResponse
|
ChatCompletionResponse
|
EmbeddingResponse
|
EmbeddingBytesResponse
|
TranscriptionResponse
|
TokenizeResponse
|
PoolingResponse
...
...
@@ -183,51 +183,21 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@
dataclass
(
kw_only
=
True
)
class
RequestProcessingMixin
:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts
:
list
[
TokensPrompt
]
|
None
=
field
(
default_factory
=
list
)
@
dataclass
(
kw_only
=
True
)
class
ResponseGenerationMixin
:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator
:
(
AsyncGenerator
[
tuple
[
int
,
RequestOutput
|
PoolingRequestOutput
],
None
]
|
None
)
=
None
final_res_batch
:
list
[
RequestOutput
|
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
@
dataclass
(
kw_only
=
True
)
class
ServeContext
(
RequestProcessingMixin
,
ResponseGenerationMixin
,
Generic
[
RequestT
]):
class
ServeContext
(
Generic
[
RequestT
]):
request
:
RequestT
raw_request
:
Request
|
None
=
None
model_name
:
str
request_id
:
str
created_time
:
int
=
field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
lora_request
:
LoRARequest
|
None
=
None
engine_prompts
:
list
[
TokensPrompt
]
|
None
=
None
result_generator
:
AsyncGenerator
[
tuple
[
int
,
PoolingRequestOutput
],
None
]
|
None
=
(
None
)
final_res_batch
:
list
[
PoolingRequestOutput
]
=
field
(
default_factory
=
list
)
@
dataclass
(
kw_only
=
True
)
class
ClassificationServeContext
(
ServeContext
[
ClassificationRequest
]):
pass
@
dataclass
(
kw_only
=
True
)
class
EmbeddingServeContext
(
ServeContext
[
EmbeddingRequest
]):
chat_template
:
str
|
None
=
None
chat_template_content_format
:
ChatTemplateContentFormatOption
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
class
OpenAIServing
:
...
...
@@ -605,10 +575,7 @@ class OpenAIServing:
self
,
ctx
:
ServeContext
,
)
->
AnyResponse
|
ErrorResponse
:
generation
:
AsyncGenerator
[
AnyResponse
|
ErrorResponse
,
None
]
generation
=
self
.
_pipeline
(
ctx
)
async
for
response
in
generation
:
async
for
response
in
self
.
_pipeline
(
ctx
):
return
response
return
self
.
create_error_response
(
"No response yielded from pipeline"
)
...
...
@@ -667,9 +634,7 @@ class OpenAIServing:
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
"""Schedule the request and get the result generator."""
generators
:
list
[
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]
]
=
[]
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
trace_headers
=
(
...
...
@@ -723,7 +688,7 @@ class OpenAIServing:
return
self
.
create_error_response
(
"Engine prompts not available"
)
num_prompts
=
len
(
ctx
.
engine_prompts
)
final_res_batch
:
list
[
RequestOutput
|
PoolingRequestOutput
|
None
]
final_res_batch
:
list
[
PoolingRequestOutput
|
None
]
final_res_batch
=
[
None
]
*
num_prompts
if
ctx
.
result_generator
is
None
:
...
...
@@ -1011,7 +976,7 @@ class OpenAIServing:
def
_validate_input
(
self
,
request
:
AnyRequest
,
request
:
object
,
input_ids
:
list
[
int
],
input_text
:
str
,
)
->
TokensPrompt
:
...
...
vllm/entrypoints/pooling/classify/serving.py
View file @
c831911b
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
http
import
HTTPStatus
from
typing
import
cast
from
typing
import
Final
,
cast
import
jinja2
import
numpy
as
np
...
...
@@ -11,18 +11,8 @@ from fastapi import Request
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorResponse
,
UsageInfo
,
)
from
vllm.entrypoints.openai.engine.serving
import
(
ClassificationServeContext
,
OpenAIServing
,
ServeContext
,
)
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
,
UsageInfo
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
,
ServeContext
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.classify.protocol
import
(
ClassificationChatRequest
,
...
...
@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams
logger
=
init_logger
(
__name__
)
class
ClassificationMixin
(
OpenAIServing
):
chat_template
:
str
|
None
chat_template_content_format
:
ChatTemplateContentFormatOption
trust_request_chat_template
:
bool
ClassificationServeContext
=
ServeContext
[
ClassificationRequest
]
class
ServingClassification
(
OpenAIServing
):
request_id_prefix
=
"classify"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
ctx
:
ClassificationServeContext
,
)
->
ErrorResponse
|
None
:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx
=
cast
(
ClassificationServeContext
,
ctx
)
try
:
request_obj
=
ctx
.
request
if
isinstance
(
request_obj
,
ClassificationChatRequest
):
chat_request
=
request_obj
messages
=
chat_request
.
messages
trust_request_chat_template
=
getattr
(
self
,
"trust_request_chat_template"
,
False
,
)
ret
=
self
.
_validate_chat_template
(
request_chat_template
=
chat_request
.
chat_template
,
chat_template_kwargs
=
chat_request
.
chat_template_kwargs
,
trust_request_chat_template
=
trust_request_chat_template
,
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
if
isinstance
(
ctx
.
request
,
ClassificationChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
ret
:
return
ret
if
error_check_ret
:
return
error_check_ret
_
,
engine_prompts
=
await
self
.
_preprocess_chat
(
c
ast
(
ChatCompletionRequest
,
chat_
request
)
,
c
tx
.
request
,
self
.
renderer
,
messages
,
chat_template
=
(
chat_request
.
chat_template
or
getattr
(
self
,
"chat_template"
,
None
)
),
chat_template_content_format
=
cast
(
ChatTemplateContentFormatOption
,
getattr
(
self
,
"chat_template_content_format"
,
"auto"
),
),
add_generation_prompt
=
chat_request
.
add_generation_prompt
,
continue_final_message
=
chat_request
.
continue_final_message
,
add_special_tokens
=
chat_request
.
add_special_tokens
,
ctx
.
request
.
messages
,
chat_template
=
ctx
.
request
.
chat_template
or
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
add_generation_prompt
=
ctx
.
request
.
add_generation_prompt
,
continue_final_message
=
ctx
.
request
.
continue_final_message
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
)
ctx
.
engine_prompts
=
engine_prompts
elif
isinstance
(
request_obj
,
ClassificationCompletionRequest
):
completion_request
=
request_obj
input_data
=
completion_request
.
input
elif
isinstance
(
ctx
.
request
,
ClassificationCompletionRequest
):
input_data
=
ctx
.
request
.
input
if
input_data
in
(
None
,
""
):
return
self
.
create_error_response
(
"Input or messages must be provided"
,
...
...
@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing):
prompt_input
=
cast
(
str
|
list
[
str
],
input_data
)
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
prompt_input
,
config
=
self
.
_build_render_config
(
c
ompletion_
request
),
config
=
self
.
_build_render_config
(
c
tx
.
request
),
)
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
,
)
return
self
.
create_error_response
(
"Invalid classification request type"
)
return
None
...
...
@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing):
def
_build_response
(
self
,
ctx
:
ServeContext
,
ctx
:
Classification
ServeContext
,
)
->
ClassificationResponse
|
ErrorResponse
:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
ctx
=
cast
(
ClassificationServeContext
,
ctx
)
id2label
=
getattr
(
self
.
model_config
.
hf_config
,
"id2label"
,
{})
items
:
list
[
ClassificationData
]
=
[]
num_prompt_tokens
=
0
...
...
@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing):
probs
=
classify_res
.
probs
predicted_index
=
int
(
np
.
argmax
(
probs
))
label
=
getattr
(
self
.
model_config
.
hf_config
,
"id2label"
,
{}).
get
(
predicted_index
)
label
=
id2label
.
get
(
predicted_index
)
item
=
ClassificationData
(
index
=
idx
,
...
...
@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
)
class
ServingClassification
(
ClassificationMixin
):
request_id_prefix
=
"classify"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
create_classify
(
self
,
request
:
ClassificationRequest
,
...
...
@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin):
request_id
=
request_id
,
)
return
await
super
()
.
handle
(
ctx
)
# type: ignore
return
await
self
.
handle
(
ctx
)
# type: ignore
[return-value]
def
_create_pooling_params
(
self
,
ctx
:
ServeContext
[
Classification
Request
]
,
ctx
:
Classification
ServeContext
,
)
->
PoolingParams
|
ErrorResponse
:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
...
...
vllm/entrypoints/pooling/embed/serving.py
View file @
c831911b
...
...
@@ -6,21 +6,13 @@ from typing import Any, Final, cast
import
torch
from
fastapi
import
Request
from
fastapi.responses
import
Response
from
typing_extensions
import
assert_never
,
override
from
typing_extensions
import
assert_never
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorResponse
,
UsageInfo
,
)
from
vllm.entrypoints.openai.engine.serving
import
(
EmbeddingServeContext
,
OpenAIServing
,
ServeContext
,
)
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
,
UsageInfo
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
,
ServeContext
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.embed.protocol
import
(
EmbeddingBytesResponse
,
...
...
@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import (
from
vllm.entrypoints.renderer
import
RenderConfig
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
PoolingOutput
,
PoolingRequestOutput
,
RequestOutput
,
)
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.collection_utils
import
chunk_list
from
vllm.utils.serial_utils
import
(
EmbedDType
,
EncodingFormat
,
Endianness
,
encode_pooling_bytes
,
encode_pooling_output
,
)
...
...
@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import (
logger
=
init_logger
(
__name__
)
class
EmbeddingMixin
(
OpenAIServing
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
EmbeddingServeContext
=
ServeContext
[
EmbeddingRequest
]
class
OpenAIServingEmbedding
(
OpenAIServing
):
request_id_prefix
=
"embd"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
pooler_config
=
self
.
model_config
.
pooler_config
...
...
@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing):
else
None
)
@
override
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
ctx
:
Embedding
ServeContext
,
)
->
ErrorResponse
|
None
:
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
ctx
.
engine_prompts
=
await
self
.
_preprocess_chat
(
ctx
.
request
,
self
.
renderer
,
ctx
.
request
.
messages
,
chat_template
=
ctx
.
request
.
chat_template
or
ctx
.
chat_template
,
chat_template_content_format
=
ctx
.
chat_template_content_format
,
chat_template
=
ctx
.
request
.
chat_template
or
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
add_generation_prompt
=
ctx
.
request
.
add_generation_prompt
,
continue_final_message
=
ctx
.
request
.
continue_final_message
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
)
else
:
elif
isinstance
(
ctx
.
request
,
EmbeddingCompletionRequest
)
:
renderer
=
self
.
_get_completion_renderer
()
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
ctx
.
request
.
input
,
config
=
self
.
_build_render_config
(
ctx
.
request
),
)
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
)
return
None
except
(
ValueError
,
TypeError
)
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
...
...
@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
)
@
override
def
_build_response
(
self
,
ctx
:
ServeContext
,
)
->
EmbeddingResponse
|
Response
|
ErrorResponse
:
final_res_batch_checked
=
cast
(
list
[
PoolingRequestOutput
],
ctx
.
final_res_batch
)
ctx
:
Embedding
ServeContext
,
)
->
EmbeddingResponse
|
EmbeddingBytesResponse
|
ErrorResponse
:
final_res_batch_checked
=
ctx
.
final_res_batch
encoding_format
:
EncodingFormat
=
ctx
.
request
.
encoding_format
embed_dtype
:
EmbedDType
=
ctx
.
request
.
embed_dtype
endianness
:
Endianness
=
ctx
.
request
.
endianness
encoding_format
=
ctx
.
request
.
encoding_format
embed_dtype
=
ctx
.
request
.
embed_dtype
endianness
=
ctx
.
request
.
endianness
def
encode_float_base64
():
items
:
list
[
EmbeddingResponseData
]
=
[]
...
...
@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing):
self
,
ctx
:
EmbeddingServeContext
,
token_ids
:
list
[
int
],
pooling_params
,
trace_headers
,
pooling_params
:
PoolingParams
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
prompt_idx
:
int
,
)
->
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]:
"""Process a single prompt using chunked processing."""
...
...
@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing):
def
_validate_input
(
self
,
request
,
request
:
object
,
input_ids
:
list
[
int
],
input_text
:
str
,
)
->
TokensPrompt
:
...
...
@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing):
pooling_params
:
PoolingParams
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
prompt_index
:
int
,
)
->
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
"""Create a generator for a single prompt using standard processing."""
request_id_item
=
f
"
{
ctx
.
request_id
}
-
{
prompt_index
}
"
...
...
@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing):
priority
=
getattr
(
ctx
.
request
,
"priority"
,
0
),
)
@
override
async
def
_prepare_generators
(
self
,
ctx
:
ServeContext
,
...
...
@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing):
return
await
super
().
_prepare_generators
(
ctx
)
# Custom logic for chunked processing
generators
:
list
[
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]
]
=
[]
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
trace_headers
=
(
...
...
@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
@
override
async
def
_collect_batch
(
self
,
ctx
:
ServeContext
,
ctx
:
Embedding
ServeContext
,
)
->
ErrorResponse
|
None
:
"""Collect and aggregate batch results
with support for chunked processing.
...
...
@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing):
minimize memory usage.
For regular requests, collects results normally.
"""
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
if
ctx
.
engine_prompts
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
...
...
@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing):
except
(
ValueError
,
IndexError
):
prompt_idx
=
result_idx
# Fallback to result_idx
short_prompts_results
[
prompt_idx
]
=
cast
(
PoolingRequestOutput
,
result
)
short_prompts_results
[
prompt_idx
]
=
result
# Finalize aggregated results
final_res_batch
:
list
[
PoolingRequestOutput
|
EmbeddingRequestOutput
]
=
[]
final_res_batch
:
list
[
PoolingRequestOutput
]
=
[]
num_prompts
=
len
(
ctx
.
engine_prompts
)
for
prompt_idx
in
range
(
num_prompts
):
...
...
@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing):
f
"Failed to aggregate chunks for prompt
{
prompt_idx
}
"
)
elif
prompt_idx
in
short_prompts_results
:
final_res_batch
.
append
(
cast
(
PoolingRequestOutput
,
short_prompts_results
[
prompt_idx
])
)
final_res_batch
.
append
(
short_prompts_results
[
prompt_idx
])
else
:
return
self
.
create_error_response
(
f
"Result not found for prompt
{
prompt_idx
}
"
)
ctx
.
final_res_batch
=
cast
(
list
[
RequestOutput
|
PoolingRequestOutput
],
final_res_batch
)
ctx
.
final_res_batch
=
final_res_batch
return
None
except
Exception
as
e
:
return
self
.
create_error_response
(
str
(
e
))
class
OpenAIServingEmbedding
(
EmbeddingMixin
):
request_id_prefix
=
"embd"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
create_embedding
(
self
,
request
:
EmbeddingRequest
,
...
...
@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin):
raw_request
=
raw_request
,
model_name
=
model_name
,
request_id
=
request_id
,
chat_template
=
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
)
return
await
s
uper
()
.
handle
(
ctx
)
# type: ignore
return
await
s
elf
.
handle
(
ctx
)
# type: ignore
[return-value]
@
override
def
_create_pooling_params
(
self
,
ctx
:
ServeContext
[
EmbeddingRequest
]
,
ctx
:
Embedding
ServeContext
,
)
->
PoolingParams
|
ErrorResponse
:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
...
...
@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin):
return
self
.
create_error_response
(
str
(
e
))
return
pooling_params
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
return
await
super
().
_preprocess
(
ctx
)
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.