Unverified commit 66c079ae, authored by wang.yuqi and committed by GitHub
Browse files

[Frontend][4/n] Improve pooling entrypoints | pooling. (#39153)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent b6c9be50
...@@ -49,9 +49,7 @@ from vllm.entrypoints.chat_utils import ( ...@@ -49,9 +49,7 @@ from vllm.entrypoints.chat_utils import (
load_chat_template, load_chat_template,
) )
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.scoring.io_processor import ( from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessor
ScoringIOProcessor,
)
from vllm.entrypoints.pooling.scoring.typing import ScoreInput from vllm.entrypoints.pooling.scoring.typing import ScoreInput
from vllm.entrypoints.pooling.typing import OfflineInputsContext, OfflineOutputsContext from vllm.entrypoints.pooling.typing import OfflineInputsContext, OfflineOutputsContext
from vllm.entrypoints.utils import log_non_default_args from vllm.entrypoints.utils import log_non_default_args
...@@ -398,12 +396,11 @@ class LLM: ...@@ -398,12 +396,11 @@ class LLM:
self.runner_type = self.model_config.runner_type self.runner_type = self.model_config.runner_type
self.renderer = self.llm_engine.renderer self.renderer = self.llm_engine.renderer
self.chat_template = load_chat_template(chat_template) self.chat_template = load_chat_template(chat_template)
self.io_processor = self.llm_engine.io_processor
self.input_processor = self.llm_engine.input_processor self.input_processor = self.llm_engine.input_processor
self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template) self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
self.pooling_io_processors = init_pooling_io_processors( self.pooling_io_processors = init_pooling_io_processors(
supported_tasks=supported_tasks, supported_tasks=supported_tasks,
model_config=self.model_config, vllm_config=self.llm_engine.vllm_config,
renderer=self.renderer, renderer=self.renderer,
chat_template_config=self.chat_template_config, chat_template_config=self.chat_template_config,
) )
...@@ -1081,118 +1078,55 @@ class LLM: ...@@ -1081,118 +1078,55 @@ class LLM:
pooled hidden states in the same order as the input prompts. pooled hidden states in the same order as the input prompts.
""" """
if isinstance(prompts, dict) and "data" in prompts and pooling_task != "plugin":
raise ValueError(
"The 'data' field is only supported for the 'plugin' pooling task."
)
self._verify_pooling_task(pooling_task) self._verify_pooling_task(pooling_task)
assert pooling_task is not None and pooling_task in self.pooling_io_processors
if isinstance(prompts, dict) and "data" in prompts: io_processor = self.pooling_io_processors[pooling_task]
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
# Validate the request data is valid for the loaded plugin if pooling_params is None:
prompt_data = prompts.get("data") pooling_params = PoolingParams()
if prompt_data is None:
raise ValueError(
"The 'data' field of the prompt is expected to contain "
"the prompt data and it cannot be None. "
"Refer to the documentation of the IOProcessor "
"in use for more details."
)
validated_prompt = self.io_processor.parse_data(prompt_data)
# obtain the actual model prompts from the pre-processor ctx = OfflineInputsContext(
prompts = self.io_processor.pre_process(prompt=validated_prompt) prompts=prompts,
prompts_seq = prompt_to_seq(prompts) pooling_params=pooling_params,
tokenization_kwargs=tokenization_kwargs,
)
params_seq: Sequence[PoolingParams] = [ engine_inputs = io_processor.pre_process_offline(ctx)
self.io_processor.merge_pooling_params(param) n_inputs = len(engine_inputs)
for param in self._params_to_seq( assert ctx.pooling_params is not None
pooling_params,
len(prompts_seq),
)
]
for p in params_seq:
if p.task is None:
p.task = "plugin"
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
# get the post-processed model outputs params_seq = self._params_to_seq(ctx.pooling_params, n_inputs)
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
return [ for param in params_seq:
PoolingRequestOutput[Any]( if param.task is None:
request_id="", param.task = pooling_task
outputs=processed_outputs, elif pooling_task == "plugin":
num_cached_tokens=getattr( # `plugin` task uses io_processor.parse_request to verify inputs.
processed_outputs, "num_cached_tokens", 0 # We actually allow plugin to overwrite pooling_task.
), pass
prompt_token_ids=[], elif param.task != pooling_task:
finished=True, msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
) raise ValueError(msg)
]
else:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
prompts_seq = prompt_to_seq(prompts)
params_seq = self._params_to_seq(pooling_params, len(prompts_seq))
for param in params_seq:
if param.task is None:
param.task = pooling_task
elif param.task != pooling_task:
msg = (
f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
)
raise ValueError(msg)
if pooling_task in self.pooling_io_processors: seq_lora_requests = self._lora_request_to_seq(lora_request, n_inputs)
io_processor = self.pooling_io_processors[pooling_task] seq_priority = self._priority_to_seq(None, n_inputs)
processor_inputs = io_processor.pre_process_offline(
ctx=OfflineInputsContext(
prompts=prompts_seq, tokenization_kwargs=tokenization_kwargs
)
)
seq_lora_requests = self._lora_request_to_seq(
lora_request, len(prompts_seq)
)
seq_priority = self._priority_to_seq(None, len(prompts))
self._render_and_add_requests( self._render_and_add_requests(
prompts=processor_inputs, prompts=engine_inputs,
params=params_seq, params=params_seq,
lora_requests=seq_lora_requests, lora_requests=seq_lora_requests,
priorities=seq_priority, priorities=seq_priority,
) )
outputs = self._run_engine( outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput)
use_tqdm=use_tqdm, output_type=PoolingRequestOutput outputs = io_processor.post_process_offline(
) ctx=OfflineOutputsContext(outputs=outputs)
outputs = io_processor.post_process_offline( )
ctx=OfflineOutputsContext(outputs=outputs)
)
else:
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return outputs return outputs
def _verify_pooling_task(self, pooling_task: PoolingTask | None): def _verify_pooling_task(self, pooling_task: PoolingTask | None):
...@@ -1254,6 +1188,14 @@ class LLM: ...@@ -1254,6 +1188,14 @@ class LLM:
pooling_task, pooling_task,
) )
if pooling_task == "plugin" and "plugin" not in self.pooling_io_processors:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
def embed( def embed(
self, self,
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
...@@ -1458,6 +1400,9 @@ class LLM: ...@@ -1458,6 +1400,9 @@ class LLM:
scoring_data = io_processor.valid_inputs(data_1, data_2) scoring_data = io_processor.valid_inputs(data_1, data_2)
n_queries = len(scoring_data.data_1) n_queries = len(scoring_data.data_1)
if pooling_params is None:
pooling_params = PoolingParams()
ctx = OfflineInputsContext( ctx = OfflineInputsContext(
prompts=scoring_data, prompts=scoring_data,
pooling_params=pooling_params, pooling_params=pooling_params,
...@@ -1466,15 +1411,11 @@ class LLM: ...@@ -1466,15 +1411,11 @@ class LLM:
n_queries=n_queries, n_queries=n_queries,
) )
processor_inputs = io_processor.pre_process_offline(ctx) engine_inputs = io_processor.pre_process_offline(ctx)
n_inputs = len(engine_inputs)
seq_lora_requests = self._lora_request_to_seq( seq_lora_requests = self._lora_request_to_seq(lora_request, n_inputs)
lora_request, len(processor_inputs) params_seq = self._params_to_seq(ctx.pooling_params, n_inputs)
)
if ctx.pooling_params is None:
ctx.pooling_params = PoolingParams()
params_seq = self._params_to_seq(ctx.pooling_params, len(processor_inputs))
for param in params_seq: for param in params_seq:
if param.task is None: if param.task is None:
...@@ -1483,10 +1424,10 @@ class LLM: ...@@ -1483,10 +1424,10 @@ class LLM:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!" msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
raise ValueError(msg) raise ValueError(msg)
seq_priority = self._priority_to_seq(None, len(processor_inputs)) seq_priority = self._priority_to_seq(None, n_inputs)
self._render_and_add_requests( self._render_and_add_requests(
prompts=processor_inputs, prompts=engine_inputs,
params=params_seq, params=params_seq,
lora_requests=seq_lora_requests, lora_requests=seq_lora_requests,
priorities=seq_priority, priorities=seq_priority,
...@@ -1579,7 +1520,7 @@ class LLM: ...@@ -1579,7 +1520,7 @@ class LLM:
if isinstance(params, Sequence): if isinstance(params, Sequence):
if len(params) != num_requests: if len(params) != num_requests:
raise ValueError( raise ValueError(
f"The lengths of prompts ({params}) " f"The lengths of prompts ({num_requests}) "
f"and params ({len(params)}) must be the same." f"and params ({len(params)}) must be the same."
) )
......
...@@ -370,7 +370,6 @@ async def init_app_state( ...@@ -370,7 +370,6 @@ async def init_app_state(
state.openai_serving_render = OpenAIServingRender( state.openai_serving_render = OpenAIServingRender(
model_config=engine_client.model_config, model_config=engine_client.model_config,
renderer=engine_client.renderer, renderer=engine_client.renderer,
io_processor=engine_client.io_processor,
model_registry=state.openai_serving_models.registry, model_registry=state.openai_serving_models.registry,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
...@@ -441,13 +440,12 @@ async def init_render_app_state( ...@@ -441,13 +440,12 @@ async def init_render_app_state(
Unlike :func:`init_app_state` this function does not require an Unlike :func:`init_app_state` this function does not require an
:class:`~vllm.engine.protocol.EngineClient`; it bootstraps the :class:`~vllm.engine.protocol.EngineClient`; it bootstraps the
preprocessing pipeline (renderer, io_processor, input_processor) preprocessing pipeline (renderer, input_processor)
directly from the :class:`~vllm.config.VllmConfig`. directly from the :class:`~vllm.config.VllmConfig`.
""" """
from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.plugins.io_processors import get_io_processor
from vllm.renderers import renderer_from_config from vllm.renderers import renderer_from_config
served_model_names = args.served_model_name or [args.model] served_model_names = args.served_model_name or [args.model]
...@@ -465,15 +463,11 @@ async def init_render_app_state( ...@@ -465,15 +463,11 @@ async def init_render_app_state(
request_logger = None request_logger = None
renderer = renderer_from_config(vllm_config) renderer = renderer_from_config(vllm_config)
io_processor = get_io_processor(
vllm_config, renderer, vllm_config.model_config.io_processor_plugin
)
resolved_chat_template = load_chat_template(args.chat_template) resolved_chat_template = load_chat_template(args.chat_template)
state.openai_serving_render = OpenAIServingRender( state.openai_serving_render = OpenAIServingRender(
model_config=vllm_config.model_config, model_config=vllm_config.model_config,
renderer=renderer, renderer=renderer,
io_processor=io_processor,
model_registry=model_registry, model_registry=model_registry,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
......
...@@ -44,12 +44,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( ...@@ -44,12 +44,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse, TranscriptionResponse,
TranslationRequest, TranslationRequest,
) )
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingResponse,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import ( from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest, DetokenizeRequest,
...@@ -62,8 +56,7 @@ from vllm.inputs import EngineInput, PromptType ...@@ -62,8 +56,7 @@ from vllm.inputs import EngineInput, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams from vllm.renderers import ChatParams, TokenizeParams
from vllm.renderers.inputs.preprocess import ( from vllm.renderers.inputs.preprocess import (
extract_prompt_components, extract_prompt_components,
...@@ -78,10 +71,7 @@ from vllm.tracing import ( ...@@ -78,10 +71,7 @@ from vllm.tracing import (
log_tracing_disabled_warning, log_tracing_disabled_warning,
) )
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.async_utils import ( from vllm.utils.async_utils import collect_from_async_generator
collect_from_async_generator,
merge_async_iterators,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -101,17 +91,11 @@ class RendererChatRequest(RendererRequest, Protocol): ...@@ -101,17 +91,11 @@ class RendererChatRequest(RendererRequest, Protocol):
CompletionLikeRequest: TypeAlias = ( CompletionLikeRequest: TypeAlias = (
CompletionRequest CompletionRequest | TokenizeCompletionRequest | DetokenizeRequest
| TokenizeCompletionRequest
| DetokenizeRequest
| PoolingCompletionRequest
) )
ChatLikeRequest: TypeAlias = ( ChatLikeRequest: TypeAlias = (
ChatCompletionRequest ChatCompletionRequest | BatchChatCompletionRequest | TokenizeChatRequest
| BatchChatCompletionRequest
| TokenizeChatRequest
| PoolingChatRequest
) )
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
...@@ -121,7 +105,6 @@ AnyRequest: TypeAlias = ( ...@@ -121,7 +105,6 @@ AnyRequest: TypeAlias = (
| ChatLikeRequest | ChatLikeRequest
| SpeechToTextRequest | SpeechToTextRequest
| ResponsesRequest | ResponsesRequest
| IOProcessorRequest
| GenerateRequest | GenerateRequest
) )
...@@ -130,7 +113,6 @@ AnyResponse: TypeAlias = ( ...@@ -130,7 +113,6 @@ AnyResponse: TypeAlias = (
| ChatCompletionResponse | ChatCompletionResponse
| TranscriptionResponse | TranscriptionResponse
| TokenizeResponse | TokenizeResponse
| PoolingResponse
| GenerateResponse | GenerateResponse
) )
...@@ -146,12 +128,6 @@ class ServeContext(Generic[RequestT]): ...@@ -146,12 +128,6 @@ class ServeContext(Generic[RequestT]):
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_inputs: list[EngineInput] | None = None engine_inputs: list[EngineInput] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
...@@ -171,7 +147,6 @@ class OpenAIServing: ...@@ -171,7 +147,6 @@ class OpenAIServing:
super().__init__() super().__init__()
self.engine_client = engine_client self.engine_client = engine_client
self.models = models self.models = models
self.request_logger = request_logger self.request_logger = request_logger
...@@ -179,7 +154,6 @@ class OpenAIServing: ...@@ -179,7 +154,6 @@ class OpenAIServing:
self.model_config = engine_client.model_config self.model_config = engine_client.model_config
self.renderer = engine_client.renderer self.renderer = engine_client.renderer
self.io_processor = engine_client.io_processor
self.input_processor = engine_client.input_processor self.input_processor = engine_client.input_processor
async def beam_search( async def beam_search(
...@@ -381,155 +355,6 @@ class OpenAIServing: ...@@ -381,155 +355,6 @@ class OpenAIServing:
prompt_logprobs=None, prompt_logprobs=None,
) )
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""
Default preprocessing hook. Subclasses may override to prepare `ctx`.
"""
return None
def _build_response(
self,
ctx: ServeContext,
) -> AnyResponse | ErrorResponse:
"""
Default response builder. Subclass may override this method
to return the appropriate response object.
"""
return self.create_error_response("unimplemented endpoint")
async def handle(
    self,
    ctx: ServeContext,
) -> AnyResponse | ErrorResponse:
    """Run the pipeline for ``ctx`` and return the first response it yields.

    If the pipeline finishes without yielding anything, an error response is
    returned instead.
    """
    responses = self._pipeline(ctx)
    async for first_response in responses:
        # Only the first yielded item is used; the generator is not driven
        # any further once it has produced a response.
        return first_response

    return self.create_error_response("No response yielded from pipeline")
async def _pipeline(
self,
ctx: ServeContext,
) -> AsyncGenerator[AnyResponse | ErrorResponse, None]:
"""Execute the request processing pipeline yielding responses."""
if error := await self._check_model(ctx.request):
yield error
if error := self._validate_request(ctx):
yield error
preprocess_ret = await self._preprocess(ctx)
if isinstance(preprocess_ret, ErrorResponse):
yield preprocess_ret
generators_ret = await self._prepare_generators(ctx)
if isinstance(generators_ret, ErrorResponse):
yield generators_ret
collect_ret = await self._collect_batch(ctx)
if isinstance(collect_ret, ErrorResponse):
yield collect_ret
yield self._build_response(ctx)
def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
if (
truncate_prompt_tokens is not None
and truncate_prompt_tokens > self.model_config.max_model_len
):
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please request a smaller truncation size."
)
return None
def _create_pooling_params(
self,
ctx: ServeContext,
) -> PoolingParams | ErrorResponse:
if not hasattr(ctx.request, "to_pooling_params"):
return self.create_error_response(
"Request type does not support pooling parameters"
)
return ctx.request.to_pooling_params()
async def _prepare_generators(
    self,
    ctx: ServeContext,
) -> ErrorResponse | None:
    """Schedule the request on the engine and store the merged generator.

    One ``engine_client.encode`` call is issued per entry of
    ``ctx.engine_inputs``, each under the request id ``"<request_id>-<index>"``.
    The per-input generators are combined with ``merge_async_iterators`` and
    the result is stored on ``ctx.result_generator``.

    Returns:
        ``None`` on success, or an error response when pooling parameters
        cannot be created or ``ctx.engine_inputs`` is ``None``.
    """
    # Trace headers are only extracted when there is an underlying raw request.
    if ctx.raw_request is None:
        trace_headers = None
    else:
        trace_headers = await self._get_trace_headers(ctx.raw_request.headers)

    pooling_params = self._create_pooling_params(ctx)
    if isinstance(pooling_params, ErrorResponse):
        return pooling_params

    if ctx.engine_inputs is None:
        return self.create_error_response("Engine prompts not available")

    generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    for index, engine_input in enumerate(ctx.engine_inputs):
        item_request_id = f"{ctx.request_id}-{index}"

        self._log_inputs(
            item_request_id,
            engine_input,
            params=pooling_params,
            lora_request=ctx.lora_request,
        )

        generators.append(
            self.engine_client.encode(
                engine_input,
                pooling_params,
                item_request_id,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                # Requests without a ``priority`` attribute fall back to 0.
                priority=getattr(ctx.request, "priority", 0),
            )
        )

    ctx.result_generator = merge_async_iterators(*generators)
    return None
async def _collect_batch(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""Collect batch results from the result generator."""
if ctx.engine_inputs is None:
return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_inputs)
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
if ctx.result_generator is None:
return self.create_error_response("Result generator not available")
async for i, res in ctx.result_generator:
final_res_batch[i] = res
if None in final_res_batch:
return self.create_error_response(
"Failed to generate results for all prompts"
)
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
return None
@staticmethod @staticmethod
def create_error_response( def create_error_response(
message: str | Exception, message: str | Exception,
...@@ -719,7 +544,7 @@ class OpenAIServing: ...@@ -719,7 +544,7 @@ class OpenAIServing:
self, self,
request_id: str, request_id: str,
inputs: PromptType | EngineInput, inputs: PromptType | EngineInput,
params: SamplingParams | PoolingParams | BeamSearchParams | None, params: SamplingParams | BeamSearchParams | None,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
) -> None: ) -> None:
if self.request_logger is None: if self.request_logger is None:
......
...@@ -112,7 +112,6 @@ class OpenAIServingModels: ...@@ -112,7 +112,6 @@ class OpenAIServingModels:
self.model_config = self.engine_client.model_config self.model_config = self.engine_client.model_config
self.renderer = self.engine_client.renderer self.renderer = self.engine_client.renderer
self.io_processor = self.engine_client.io_processor
self.input_processor = self.engine_client.input_processor self.input_processor = self.engine_client.input_processor
async def init_static_loras(self): async def init_static_loras(self):
......
...@@ -67,20 +67,18 @@ def init_pooling_state( ...@@ -67,20 +67,18 @@ def init_pooling_state(
from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling from vllm.entrypoints.pooling.pooling.serving import ServingPooling
from vllm.entrypoints.pooling.scoring.serving import ServingScores from vllm.entrypoints.pooling.scoring.serving import ServingScores
from vllm.tasks import POOLING_TASKS from vllm.tasks import POOLING_TASKS
model_config = engine_client.model_config model_config = engine_client.model_config
resolved_chat_template = load_chat_template(args.chat_template) resolved_chat_template = load_chat_template(args.chat_template)
state.serving_pooling = ( state.serving_pooling = (
( (
OpenAIServingPooling( ServingPooling(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
state.openai_serving_render,
supported_tasks=supported_tasks, supported_tasks=supported_tasks,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Final from typing import Any, Final
from vllm import PoolingRequestOutput, PromptType from vllm import PoolingParams, PoolingRequestOutput, PromptType
from vllm.config import ModelConfig from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateConfig, ChatTemplateConfig,
...@@ -33,11 +33,12 @@ class PoolingIOProcessor: ...@@ -33,11 +33,12 @@ class PoolingIOProcessor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
renderer: BaseRenderer, renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig, chat_template_config: ChatTemplateConfig,
): ):
self.model_config = model_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.renderer = renderer self.renderer = renderer
self.chat_template = chat_template_config.chat_template self.chat_template = chat_template_config.chat_template
...@@ -48,12 +49,12 @@ class PoolingIOProcessor: ...@@ -48,12 +49,12 @@ class PoolingIOProcessor:
chat_template_config.trust_request_chat_template chat_template_config.trust_request_chat_template
) )
def create_pooling_params(self, request):
return request.to_pooling_params()
####################################### #######################################
# online APIs # online APIs
def create_pooling_params(self, request):
return request.to_pooling_params()
def pre_process_online(self, ctx: PoolingServeContext): def pre_process_online(self, ctx: PoolingServeContext):
request = ctx.request request = ctx.request
...@@ -100,12 +101,16 @@ class PoolingIOProcessor: ...@@ -100,12 +101,16 @@ class PoolingIOProcessor:
# offline APIs # offline APIs
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]: def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
assert not isinstance(ctx.prompts, ScoringData) assert not isinstance(ctx.prompts, ScoringData) and not (
isinstance(ctx.prompts, dict) and "data" in ctx.prompts
)
prompts_seq = prompt_to_seq(ctx.prompts)
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs( tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {}) **(ctx.tokenization_kwargs or {})
) )
return self._preprocess_completion_offline( return self._preprocess_completion_offline(
prompts=ctx.prompts, tok_params=tok_params prompts=prompts_seq, tok_params=tok_params
) )
async def pre_process_offline_async(self, ctx: OfflineInputsContext): async def pre_process_offline_async(self, ctx: OfflineInputsContext):
...@@ -243,3 +248,19 @@ class PoolingIOProcessor: ...@@ -243,3 +248,19 @@ class PoolingIOProcessor:
"Refused request with untrusted chat template." "Refused request with untrusted chat template."
) )
return None return None
def _params_to_seq(
self,
params: PoolingParams | Sequence[PoolingParams],
num_requests: int,
) -> Sequence[PoolingParams]:
if isinstance(params, Sequence):
if len(params) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and params ({len(params)}) must be the same."
)
return params
return [params] * num_requests
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Mapping
from http import HTTPStatus from http import HTTPStatus
from typing import ClassVar from typing import ClassVar
...@@ -9,7 +11,7 @@ from fastapi.responses import Response ...@@ -9,7 +11,7 @@ from fastapi.responses import Response
from starlette.datastructures import Headers from starlette.datastructures import Headers
from vllm import PoolingParams, PoolingRequestOutput, envs from vllm import PoolingParams, PoolingRequestOutput, envs
from vllm.config import ModelConfig from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatTemplateConfig, ChatTemplateConfig,
...@@ -35,7 +37,7 @@ from vllm.utils.async_utils import merge_async_iterators ...@@ -35,7 +37,7 @@ from vllm.utils.async_utils import merge_async_iterators
from .io_processor import PoolingIOProcessor from .io_processor import PoolingIOProcessor
class PoolingServing: class PoolingServingBase(ABC):
request_id_prefix: ClassVar[str] request_id_prefix: ClassVar[str]
def __init__( def __init__(
...@@ -50,10 +52,11 @@ class PoolingServing: ...@@ -50,10 +52,11 @@ class PoolingServing:
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False, log_error_stack: bool = False,
): ):
super().__init__()
self.engine_client = engine_client self.engine_client = engine_client
self.models = models self.models = models
self.model_config = models.model_config self.model_config = models.model_config
self.renderer = models.renderer
self.vllm_config = engine_client.vllm_config
self.max_model_len = self.model_config.max_model_len self.max_model_len = self.model_config.max_model_len
self.request_logger = request_logger self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids self.return_tokens_as_token_ids = return_tokens_as_token_ids
...@@ -63,31 +66,14 @@ class PoolingServing: ...@@ -63,31 +66,14 @@ class PoolingServing:
chat_template_content_format=chat_template_content_format, chat_template_content_format=chat_template_content_format,
trust_request_chat_template=trust_request_chat_template, trust_request_chat_template=trust_request_chat_template,
) )
self.io_processor = self.init_io_processor(
model_config=models.model_config,
renderer=models.renderer,
chat_template_config=self.chat_template_config,
)
def init_io_processor(
self,
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> PoolingIOProcessor:
raise NotImplementedError
@abstractmethod
async def __call__( async def __call__(
self, self,
request: AnyPoolingRequest, request: AnyPoolingRequest,
raw_request: Request | None = None, raw_request: Request | None = None,
) -> Response: ) -> Response:
ctx = await self._init_ctx(request, raw_request) raise NotImplementedError
await self.io_processor.pre_process_online_async(ctx)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await self.io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
async def _init_ctx( async def _init_ctx(
self, self,
...@@ -124,10 +110,8 @@ class PoolingServing: ...@@ -124,10 +110,8 @@ class PoolingServing:
else await self._get_trace_headers(ctx.raw_request.headers) else await self._get_trace_headers(ctx.raw_request.headers)
) )
if ctx.pooling_params is None: assert ctx.pooling_params is not None
pooling_params = self.io_processor.create_pooling_params(ctx.request) pooling_params = ctx.pooling_params
else:
pooling_params = ctx.pooling_params
if isinstance(pooling_params, list): if isinstance(pooling_params, list):
for params in pooling_params: for params in pooling_params:
...@@ -190,6 +174,7 @@ class PoolingServing: ...@@ -190,6 +174,7 @@ class PoolingServing:
ctx.final_res_batch = [res for res in final_res_batch if res is not None] ctx.final_res_batch = [res for res in final_res_batch if res is not None]
@abstractmethod
async def _build_response( async def _build_response(
self, self,
ctx: PoolingServeContext, ctx: PoolingServeContext,
...@@ -355,3 +340,39 @@ class PoolingServing: ...@@ -355,3 +340,39 @@ class PoolingServing:
params=params, params=params,
lora_request=lora_request, lora_request=lora_request,
) )
class PoolingServing(PoolingServingBase, ABC):
    """Pooling serving handler backed by a ``PoolingIOProcessor``.

    Concrete subclasses supply the processor through ``init_io_processor``.
    Calling the instance runs the online pre-processing, engine scheduling,
    batch collection, online post-processing and response building steps.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base class, then build the IO processor."""
        super().__init__(*args, **kwargs)
        io_processor = self.init_io_processor(
            vllm_config=self.vllm_config,
            renderer=self.renderer,
            chat_template_config=self.chat_template_config,
        )
        self.io_processor = io_processor

    @abstractmethod
    def init_io_processor(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
        """Create the IO processor used by this serving class."""
        raise NotImplementedError

    async def __call__(
        self,
        request: AnyPoolingRequest,
        raw_request: Request | None = None,
    ) -> Response:
        """Serve one pooling request and return its response."""
        ctx = await self._init_ctx(request, raw_request)
        await self.io_processor.pre_process_online_async(ctx)

        # Pre-processing may already have set pooling params on the context;
        # only derive them from the request when it has not.
        if ctx.pooling_params is None:
            derived_params = self.io_processor.create_pooling_params(request)
            ctx.pooling_params = derived_params

        await self._prepare_generators(ctx)
        await self._collect_batch(ctx)
        await self.io_processor.post_process_online_async(ctx)

        response = await self._build_response(ctx)
        return response
...@@ -5,4 +5,8 @@ from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor ...@@ -5,4 +5,8 @@ from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
class ClassifyIOProcessor(PoolingIOProcessor):
    """IO processor for the ``classify`` pooling task.

    It adds nothing beyond ``PoolingIOProcessor`` except its task name.
    """

    # Task identifier; matches the "classify" key used when the pooling IO
    # processors are registered by task.
    name = "classify"
class TokenClassifyIOProcessor(PoolingIOProcessor):
    """IO processor for the ``token_classify`` pooling task.

    It adds nothing beyond ``PoolingIOProcessor`` except its task name.
    """

    # Task identifier for this processor.
    name = "token_classify"
...@@ -6,14 +6,11 @@ from typing import TypeAlias ...@@ -6,14 +6,11 @@ from typing import TypeAlias
import numpy as np import numpy as np
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.typing import PoolingServeContext from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput from vllm.outputs import ClassificationOutput
from vllm.renderers import BaseRenderer
from .io_processor import ClassifyIOProcessor from .io_processor import ClassifyIOProcessor
from .protocol import ( from .protocol import (
...@@ -31,17 +28,8 @@ ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationReques ...@@ -31,17 +28,8 @@ ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationReques
class ServingClassification(PoolingServing): class ServingClassification(PoolingServing):
request_id_prefix = "classify" request_id_prefix = "classify"
def init_io_processor(self, *args, **kwargs) -> ClassifyIOProcessor:
    """Build the IO processor used by the classification serving class.

    Every positional and keyword argument is forwarded unchanged to the
    ``ClassifyIOProcessor`` constructor.

    Returns:
        A new ``ClassifyIOProcessor`` instance.
    """
    return ClassifyIOProcessor(*args, **kwargs)
async def _build_response( async def _build_response(
self, self,
......
...@@ -37,7 +37,7 @@ logger = init_logger(__name__) ...@@ -37,7 +37,7 @@ logger = init_logger(__name__)
class EmbedIOProcessor(PoolingIOProcessor): class EmbedIOProcessor(PoolingIOProcessor):
name = "embedding" name = "embed"
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -549,3 +549,7 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -549,3 +549,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
request = ctx.request request = ctx.request
if request.truncate == "NONE" and request.max_tokens is not None: if request.truncate == "NONE" and request.max_tokens is not None:
self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens) self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens)
class TokenEmbedIOProcessor(PoolingIOProcessor):
    """IO processor for the ``token_embed`` pooling task.

    It adds nothing beyond ``PoolingIOProcessor`` except its task name.
    """

    # Task identifier for this processor.
    name = "token_embed"
...@@ -8,8 +8,6 @@ from typing import Literal, TypeAlias, cast ...@@ -8,8 +8,6 @@ from typing import Literal, TypeAlias, cast
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
...@@ -33,12 +31,10 @@ from vllm.entrypoints.pooling.utils import ( ...@@ -33,12 +31,10 @@ from vllm.entrypoints.pooling.utils import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.renderers import BaseRenderer
from vllm.utils.serial_utils import EmbedDType, Endianness from vllm.utils.serial_utils import EmbedDType, Endianness
logger = init_logger(__name__) logger = init_logger(__name__)
JSONResponseCLS = get_json_response_cls()
EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest] EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest]
...@@ -49,17 +45,13 @@ class ServingEmbedding(PoolingServing): ...@@ -49,17 +45,13 @@ class ServingEmbedding(PoolingServing):
request_id_prefix = "embd" request_id_prefix = "embd"
io_processor: EmbedIOProcessor io_processor: EmbedIOProcessor
def __init__(self, *args, **kwargs):
    """Initialise the embedding serving class.

    All arguments are forwarded to the parent constructor. The JSON
    response class is then resolved once and kept on the instance so the
    response builders can reuse it.
    """
    super().__init__(*args, **kwargs)

    self.json_response_cls = get_json_response_cls()

def init_io_processor(self, *args, **kwargs) -> EmbedIOProcessor:
    """Build the IO processor used by the embedding serving class.

    Every positional and keyword argument is forwarded unchanged to the
    ``EmbedIOProcessor`` constructor.

    Returns:
        A new ``EmbedIOProcessor`` instance.
    """
    return EmbedIOProcessor(*args, **kwargs)
async def _build_response( async def _build_response(
self, self,
...@@ -149,7 +141,7 @@ class ServingEmbedding(PoolingServing): ...@@ -149,7 +141,7 @@ class ServingEmbedding(PoolingServing):
data=items, data=items,
usage=usage, usage=usage,
) )
return JSONResponseCLS(content=response.model_dump()) return self.json_response_cls(content=response.model_dump())
def _openai_bytes_response( def _openai_bytes_response(
self, self,
...@@ -190,8 +182,8 @@ class ServingEmbedding(PoolingServing): ...@@ -190,8 +182,8 @@ class ServingEmbedding(PoolingServing):
media_type=response.media_type, media_type=response.media_type,
) )
@staticmethod
def _build_cohere_response_from_ctx( def _build_cohere_response_from_ctx(
self,
ctx: PoolingServeContext, ctx: PoolingServeContext,
) -> JSONResponse: ) -> JSONResponse:
request = ctx.request request = ctx.request
...@@ -218,4 +210,4 @@ class ServingEmbedding(PoolingServing): ...@@ -218,4 +210,4 @@ class ServingEmbedding(PoolingServing):
), ),
), ),
) )
return JSONResponse(content=response.model_dump(exclude_none=True)) return self.json_response_cls(content=response.model_dump(exclude_none=True))
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor from vllm.plugins.io_processors import has_io_processor
from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessors
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from .base.io_processor import PoolingIOProcessor
from .utils import enable_scoring_api
def init_pooling_io_processors(
    supported_tasks: tuple[SupportedTask, ...],
    vllm_config: VllmConfig,
    renderer: BaseRenderer,
    chat_template_config: ChatTemplateConfig,
) -> dict[str, PoolingIOProcessor]:
    """Instantiate the pooling IO processors that apply to this model.

    A processor class is selected for each of ``classify``,
    ``token_classify``, ``embed`` and ``token_embed`` that appears in
    ``supported_tasks``. A ``plugin`` processor is selected when an IO
    processor plugin is available, or otherwise when ``plugin`` is a
    supported task. A scoring processor is selected, keyed by
    ``model_config.score_type``, when the scoring API is enabled and that
    score type is registered in ``ScoringIOProcessors``.

    Args:
        supported_tasks: Tasks the model supports.
        vllm_config: Passed to every processor; its ``model_config`` is
            also used for the plugin and scoring decisions.
        renderer: Passed to every processor.
        chat_template_config: Passed to every processor.

    Returns:
        A mapping from task name to a constructed processor instance.
    """
    model_config = vllm_config.model_config
    processors: dict[str, type[PoolingIOProcessor]] = {}

    # Each processor module is imported inside its own branch, so a module
    # is only loaded when its task is actually selected.
    if "classify" in supported_tasks:
        from .classify.io_processor import ClassifyIOProcessor

        processors["classify"] = ClassifyIOProcessor

    if "token_classify" in supported_tasks:
        from .classify.io_processor import TokenClassifyIOProcessor

        processors["token_classify"] = TokenClassifyIOProcessor

    if "embed" in supported_tasks:
        from .embed.io_processor import EmbedIOProcessor

        processors["embed"] = EmbedIOProcessor

    if "token_embed" in supported_tasks:
        from .embed.io_processor import TokenEmbedIOProcessor

        processors["token_embed"] = TokenEmbedIOProcessor

    # An available IO processor plugin takes precedence over the plain
    # "plugin" task; both are registered under the same "plugin" key.
    if has_io_processor(
        vllm_config,
        model_config.io_processor_plugin,
    ):
        from .pooling.io_processor import PluginWithIOProcessorPlugins

        processors["plugin"] = PluginWithIOProcessorPlugins
    elif "plugin" in supported_tasks:
        from .pooling.io_processor import PluginWithoutIOProcessorPlugins

        processors["plugin"] = PluginWithoutIOProcessorPlugins

    if enable_scoring_api(supported_tasks, model_config):
        score_type = model_config.score_type

        from .scoring.io_processor import ScoringIOProcessors

        if score_type is not None and score_type in ScoringIOProcessors:
            processors[score_type] = ScoringIOProcessors[score_type]

    return {
        task: processor_cls(
            vllm_config=vllm_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )
        for task, processor_cls in processors.items()
    }
...@@ -3,24 +3,17 @@ ...@@ -3,24 +3,17 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
IOProcessorResponse, from vllm.entrypoints.pooling.pooling.serving import ServingPooling
PoolingBytesResponse,
PoolingRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.utils import load_aware_call, with_cancellation from vllm.entrypoints.utils import load_aware_call, with_cancellation
router = APIRouter() router = APIRouter()
def pooling(request: Request) -> ServingPooling | None:
    """Return the pooling handler stored on the application state.

    Args:
        request: The incoming request; only ``request.app.state`` is read.

    Returns:
        Whatever is stored in ``request.app.state.serving_pooling``. The
        annotation allows ``None`` for the case where no handler is set.
    """
    return request.app.state.serving_pooling
...@@ -39,19 +32,4 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): ...@@ -39,19 +32,4 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
if handler is None: if handler is None:
raise NotImplementedError("The model does not support Pooling API") raise NotImplementedError("The model does not support Pooling API")
generator = await handler.create_pooling(request, raw_request) return await handler(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
return JSONResponse(content=generator.model_dump())
elif isinstance(generator, PoolingBytesResponse):
return StreamingResponse(
content=generator.content,
headers=generator.headers,
media_type=generator.media_type,
)
assert_never(generator)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
from vllm import PoolingParams, PoolingRequestOutput
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.plugins.io_processors import get_io_processor
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from ..typing import OfflineInputsContext, OfflineOutputsContext, PoolingServeContext
from .protocol import IOProcessorRequest, IOProcessorResponse
logger = init_logger(__name__)
class PluginWithoutIOProcessorPlugins(PoolingIOProcessor):
    """IO processor for the ``plugin`` task with no plugin-specific hooks.

    It adds nothing beyond ``PoolingIOProcessor`` except its task name.
    """

    # Task identifier for this processor.
    name = "plugin"
class PluginWithIOProcessorPlugins(PoolingIOProcessor):
    """Pooling IO processor that delegates to an IO Processor plugin.

    IO Processor plugins are a feature that allows pre- and post-processing
    of the model input and output for pooling models. The plugin is looked
    up once in ``__init__`` and used by all online and offline hooks below.
    """

    name = "plugin"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        plugin = get_io_processor(
            self.vllm_config,
            self.renderer,
            self.model_config.io_processor_plugin,
        )
        # Fail fast at construction time if no plugin was resolved.
        assert plugin is not None
        self.io_processor = plugin

    #######################################
    # online APIs

    def pre_process_online(self, ctx: PoolingServeContext):
        """Fill ``ctx`` with engine inputs and pooling params for a request.

        The request data is parsed and pre-processed by the plugin, the
        resulting prompts are rendered into ``ctx.engine_inputs``, and the
        plugin's pooling params are stored in ``ctx.pooling_params``.
        """
        request = ctx.request
        assert isinstance(request, IOProcessorRequest)

        plugin_input = self.io_processor.parse_data(request.data)
        plugin_prompts = self.io_processor.pre_process(
            prompt=plugin_input, request_id=ctx.request_id
        )

        # ``bytes`` prompts are passed through as-is; everything else is
        # parsed against the model config first.
        parsed_prompts = []
        for item in prompt_to_seq(plugin_prompts):
            if not isinstance(item, bytes):
                item = parse_model_prompt(self.model_config, item)
            parsed_prompts.append(item)

        tok_params = request.build_tok_params(self.model_config)

        # Forward only the optional request fields that are actually set.
        extras = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        ctx.engine_inputs = self.renderer.render_cmpl(
            parsed_prompts,
            tok_params,
            prompt_extras=extras,
        )

        params = self.io_processor.merge_pooling_params()
        # Default the task when the plugin did not choose one.
        if params.task is None:
            params.task = "plugin"
        ctx.pooling_params = params

    def post_process_online(
        self,
        ctx: PoolingServeContext,
    ):
        """Post-process the collected results and store the response on ``ctx``.

        If the plugin still provides a callable ``output_to_response``, it
        is used (with a one-time deprecation warning); otherwise the output
        is wrapped in an ``IOProcessorResponse``.
        """
        result = self.io_processor.post_process(
            ctx.final_res_batch,
            request_id=ctx.request_id,
        )

        legacy_converter = getattr(self.io_processor, "output_to_response", None)
        if not callable(legacy_converter):
            ctx.response = IOProcessorResponse(request_id=ctx.request_id, data=result)
            return

        logger.warning_once(
            "`IOProcessor.output_to_response` is deprecated. To ensure "
            "consistency between offline and online APIs, "
            "`IOProcessorResponse` will become a transparent wrapper "
            "around output data from v0.19 onwards.",
        )

        # Backfill the request id only when the output has the field and
        # left it unset.
        if hasattr(result, "request_id") and result.request_id is None:
            result.request_id = ctx.request_id  # type: ignore

        ctx.response = legacy_converter(result)  # type: ignore

    #######################################
    # offline APIs

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        """Expand a ``{"data": ...}`` prompt into engine inputs.

        The plugin parses and pre-processes the data, ``ctx.prompts`` and
        ``ctx.pooling_params`` are replaced with the resulting sequences,
        and the parent class's ``pre_process_offline`` does the rest.

        Raises:
            ValueError: If the ``data`` entry of the prompt is ``None``.
        """
        assert isinstance(ctx.prompts, dict) and "data" in ctx.prompts
        assert ctx.pooling_params is not None

        # Validate the request data is valid for the loaded plugin
        data = ctx.prompts.get("data")
        if data is None:
            raise ValueError(
                "The 'data' field of the prompt is expected to contain "
                "the prompt data and it cannot be None. "
                "Refer to the documentation of the IOProcessor "
                "in use for more details."
            )
        plugin_input = self.io_processor.parse_data(data)

        # obtain the actual model prompts from the pre-processor
        model_prompts = prompt_to_seq(
            self.io_processor.pre_process(prompt=plugin_input)
        )

        # Merge all params first, then default the task in a second pass.
        merged_params: list[PoolingParams] = []
        for base_params in self._params_to_seq(
            ctx.pooling_params,
            len(model_prompts),
        ):
            merged_params.append(self.io_processor.merge_pooling_params(base_params))

        for params in merged_params:
            if params.task is None:
                params.task = "plugin"

        ctx.pooling_params = merged_params
        ctx.prompts = model_prompts
        return super().pre_process_offline(ctx)

    def post_process_offline(
        self,
        ctx: OfflineOutputsContext,
    ) -> list[PoolingRequestOutput]:
        """Post-process all offline outputs into a single wrapped result.

        Returns:
            A one-element list holding a finished ``PoolingRequestOutput``
            whose ``outputs`` is the plugin's post-processed value.
        """
        plugin_output = self.io_processor.post_process(ctx.outputs)

        wrapped = PoolingRequestOutput[Any](
            request_id="",
            outputs=plugin_output,
            num_cached_tokens=getattr(plugin_output, "num_cached_tokens", 0),
            prompt_token_ids=[],
            finished=True,
        )
        return [wrapped]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json import json
import time from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable, Sequence
from functools import partial from functools import partial
from typing import Final, Literal, cast from typing import Literal, cast
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest, IOProcessorRequest,
IOProcessorResponse,
PoolingBytesResponse, PoolingBytesResponse,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingRequest,
PoolingResponse, PoolingResponse,
PoolingResponseData, PoolingResponseData,
) )
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, PoolingServeContext
from vllm.entrypoints.pooling.utils import ( from vllm.entrypoints.pooling.utils import (
encode_pooling_bytes, encode_pooling_bytes,
encode_pooling_output_base64, encode_pooling_output_base64,
encode_pooling_output_float, encode_pooling_output_float,
get_json_response_cls,
) )
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.inputs import EngineInput
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.serial_utils import EmbedDType, Endianness
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
logger = init_logger(__name__) logger = init_logger(__name__)
class OpenAIServingPooling(OpenAIServing): class ServingPooling(PoolingServingBase):
request_id_prefix = "pooling"
def __init__(
    self,
    *args,
    supported_tasks: tuple[SupportedTask, ...],
    **kwargs,
):
    """Initialise the pooling serving class.

    Args:
        *args: Forwarded to the parent constructor.
        supported_tasks: Tasks the model supports (keyword-only). Used to
            pick the default pooling task and to decide which IO
            processors to create.
        **kwargs: Forwarded to the parent constructor.
    """
    super().__init__(*args, **kwargs)

    self.supported_tasks = supported_tasks
    # Default task applied to requests that do not name one.
    self.pooling_task = self.model_config.get_pooling_task(supported_tasks)

    # One IO processor per task, looked up by task name for each request.
    self.io_processors = init_pooling_io_processors(
        supported_tasks=supported_tasks,
        vllm_config=self.vllm_config,
        renderer=self.renderer,
        chat_template_config=self.chat_template_config,
    )
    self.json_response_cls = get_json_response_cls()
async def __call__(
    self,
    request: AnyPoolingRequest,
    raw_request: Request | None = None,
) -> Response:
    """Serve one pooling request with the IO processor for its task.

    The task is resolved and validated first. The matching IO processor
    then drives the pipeline: context creation, online pre-processing,
    pooling-param defaulting, generator preparation, batch collection,
    online post-processing and response building.

    Args:
        request: The pooling request; must be a ``PoolingRequest``.
        raw_request: The originating ``Request`` object, or ``None``.

    Returns:
        The response produced by ``_build_response``.

    Raises:
        ValueError: Propagated from ``_verify_pooling_task`` when the
            request cannot be served.
    """
    assert isinstance(request, PoolingRequest)
    pooling_task = self._verify_pooling_task(request)

    io_processor = self.io_processors[pooling_task]
    ctx = await self._init_ctx(request, raw_request)

    await io_processor.pre_process_online_async(ctx)

    # Pre-processing may already have chosen the pooling params; only
    # derive them from the request when it did not.
    if ctx.pooling_params is None:
        ctx.pooling_params = io_processor.create_pooling_params(request)

    await self._prepare_generators(ctx)
    await self._collect_batch(ctx)
    await io_processor.post_process_online_async(ctx)
    return await self._build_response(ctx)
def _verify_pooling_task(self, request: PoolingRequest) -> str:
    """Resolve and validate the pooling task for ``request``.

    ``request.task`` is updated in place: it defaults to
    ``self.pooling_task`` when unset and is forced to ``"plugin"`` for an
    ``IOProcessorRequest``.

    Args:
        request: The pooling request to inspect.

    Returns:
        The resolved task name.

    Raises:
        ValueError: If ``dimensions`` is set on the request, if the task
            is neither ``"plugin"`` nor the default task and has no IO
            processor, or if the task is ``"plugin"`` but no plugin
            processor is registered.
    """
    if getattr(request, "dimensions", None) is not None:
        raise ValueError("dimensions is currently not supported")

    if request.task is None:
        request.task = self.pooling_task

    if isinstance(request, IOProcessorRequest):
        request.task = "plugin"

    assert request.task is not None
    pooling_task = request.task

    # plugin task uses io_processor.parse_request to verify inputs
    if pooling_task != "plugin" and pooling_task != self.pooling_task:
        if pooling_task not in self.io_processors:
            raise ValueError(
                f"Unsupported task: {pooling_task!r} "
                f"Supported tasks: {self.supported_tasks}"
            )
        else:
            # A non-default task that has a processor is still served,
            # but the caller is warned that this path is going away.
            logger.warning_once(
                "Pooling multitask support is deprecated and will be removed "
                "in v0.20. When the default pooling task is not what you want, you "
                "need to manually specify it via --pooler-config.task %s. ",
                pooling_task,
            )

    if pooling_task == "plugin" and "plugin" not in self.io_processors:
        raise ValueError(
            "No IOProcessor plugin installed. Please refer "
            "to the documentation and to the "
            "'prithvi_geospatial_mae_io_processor' "
            "offline inference example for more details."
        )

    return pooling_task
async def _build_response(
    self,
    ctx: PoolingServeContext,
) -> Response:
    """Turn the finished context into the HTTP response.

    A response already placed on ``ctx.response`` is serialised directly
    with the JSON response class. Otherwise the collected results are
    encoded according to the request's ``encoding_format``.

    Args:
        ctx: The serving context after batch collection and
            post-processing.

    Returns:
        The JSON or bytes response for the request.
    """
    if ctx.response is not None:
        # for IOProcessorResponse
        return self.json_response_cls(content=ctx.response.model_dump())

    encoding_format = ctx.request.encoding_format
    embed_dtype = ctx.request.embed_dtype
    endianness = ctx.request.endianness

    if encoding_format == "float" or encoding_format == "base64":
        return self.request_output_to_pooling_json_response(
            ctx.final_res_batch,
            ctx.request_id,
            ctx.created_time,
            ctx.model_name,
            encoding_format,
            embed_dtype,
            endianness,
        )

    if encoding_format == "bytes" or encoding_format == "bytes_only":
        return self.request_output_to_pooling_bytes_response(
            ctx.final_res_batch,
            ctx.request_id,
            ctx.created_time,
            ctx.model_name,
            encoding_format,
            embed_dtype,
            endianness,
        )

    # Every known encoding format is handled above.
    assert_never(encoding_format)
def request_output_to_pooling_json_response( def request_output_to_pooling_json_response(
self, self,
...@@ -257,7 +162,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -257,7 +162,7 @@ class OpenAIServingPooling(OpenAIServing):
encoding_format: Literal["float", "base64"], encoding_format: Literal["float", "base64"],
embed_dtype: EmbedDType, embed_dtype: EmbedDType,
endianness: Endianness, endianness: Endianness,
) -> PoolingResponse: ) -> JSONResponse:
encode_fn = cast( encode_fn = cast(
Callable[[PoolingRequestOutput], list[float] | str], Callable[[PoolingRequestOutput], list[float] | str],
( (
...@@ -289,13 +194,14 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -289,13 +194,14 @@ class OpenAIServingPooling(OpenAIServing):
total_tokens=num_prompt_tokens, total_tokens=num_prompt_tokens,
) )
return PoolingResponse( response = PoolingResponse(
id=request_id, id=request_id,
created=created_time, created=created_time,
model=model_name, model=model_name,
data=items, data=items,
usage=usage, usage=usage,
) )
return self.json_response_cls(content=response.model_dump())
def request_output_to_pooling_bytes_response( def request_output_to_pooling_bytes_response(
self, self,
...@@ -306,7 +212,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -306,7 +212,7 @@ class OpenAIServingPooling(OpenAIServing):
encoding_format: Literal["bytes", "bytes_only"], encoding_format: Literal["bytes", "bytes_only"],
embed_dtype: EmbedDType, embed_dtype: EmbedDType,
endianness: Endianness, endianness: Endianness,
) -> PoolingBytesResponse: ) -> StreamingResponse:
content, items, usage = encode_pooling_bytes( content, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch, pooling_outputs=final_res_batch,
embed_dtype=embed_dtype, embed_dtype=embed_dtype,
...@@ -329,38 +235,10 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -329,38 +235,10 @@ class OpenAIServingPooling(OpenAIServing):
} }
) )
return PoolingBytesResponse(content=content, headers=headers) response = PoolingBytesResponse(content=content, headers=headers)
def request_output_to_pooling_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: EncodingFormat,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingResponse | PoolingBytesResponse:
if encoding_format == "float" or encoding_format == "base64":
return self.request_output_to_pooling_json_response(
final_res_batch,
request_id,
created_time,
model_name,
encoding_format,
embed_dtype,
endianness,
)
if encoding_format == "bytes" or encoding_format == "bytes_only":
return self.request_output_to_pooling_bytes_response(
final_res_batch,
request_id,
created_time,
model_name,
encoding_format,
embed_dtype,
endianness,
)
assert_never(encoding_format) return StreamingResponse(
content=response.content,
headers=response.headers,
media_type=response.media_type,
)
...@@ -278,7 +278,7 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): ...@@ -278,7 +278,7 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]: def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
assert isinstance(ctx.prompts, ScoringData) assert isinstance(ctx.prompts, ScoringData)
assert not isinstance(ctx.pooling_params, list) assert not isinstance(ctx.pooling_params, Sequence)
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs( tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {}) **(ctx.tokenization_kwargs or {})
......
...@@ -4,15 +4,13 @@ ...@@ -4,15 +4,13 @@
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import ModelConfig from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.renderers import BaseRenderer
from vllm.v1.pool.late_interaction import ( from vllm.v1.pool.late_interaction import (
build_late_interaction_doc_params, build_late_interaction_doc_params,
build_late_interaction_query_params, build_late_interaction_query_params,
...@@ -52,22 +50,17 @@ class ServingScores(PoolingServing): ...@@ -52,22 +50,17 @@ class ServingScores(PoolingServing):
super().__init__(engine_client, *args, **kwargs) super().__init__(engine_client, *args, **kwargs)
def init_io_processor( def init_io_processor(
self, self, vllm_config: VllmConfig, *args, **kwargs
model_config: ModelConfig,
renderer: BaseRenderer,
chat_template_config: ChatTemplateConfig,
) -> PoolingIOProcessor: ) -> PoolingIOProcessor:
model_config = vllm_config.model_config
score_type: str = model_config.score_type score_type: str = model_config.score_type
if self.enable_flash_late_interaction: if self.enable_flash_late_interaction:
score_type = "flash-late-interaction" score_type = "flash-late-interaction"
assert score_type in ScoringIOProcessors assert score_type in ScoringIOProcessors
processor_cls = ScoringIOProcessors[score_type] processor_cls = ScoringIOProcessors[score_type]
return processor_cls( return processor_cls(vllm_config, *args, **kwargs)
model_config=model_config,
renderer=renderer,
chat_template_config=chat_template_config,
)
async def __call__(self, *args, **kwargs) -> Response: async def __call__(self, *args, **kwargs) -> Response:
if not self.enable_flash_late_interaction: if not self.enable_flash_late_interaction:
......
...@@ -30,7 +30,7 @@ from vllm.entrypoints.pooling.pooling.protocol import ( ...@@ -30,7 +30,7 @@ from vllm.entrypoints.pooling.pooling.protocol import (
) )
from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse
from vllm.entrypoints.pooling.scoring.typing import ScoringData from vllm.entrypoints.pooling.scoring.typing import ScoringData
from vllm.inputs import EngineInput from vllm.inputs import DataPrompt, EngineInput
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
PoolingCompletionLikeRequest: TypeAlias = ( PoolingCompletionLikeRequest: TypeAlias = (
...@@ -86,11 +86,14 @@ class PoolingServeContext(Generic[PoolingRequestT]): ...@@ -86,11 +86,14 @@ class PoolingServeContext(Generic[PoolingRequestT]):
## for bi-encoder & late-interaction ## for bi-encoder & late-interaction
n_queries: int | None = None n_queries: int | None = None
## for IOProcessorResponse
response: Any | None = None
@dataclass @dataclass
class OfflineInputsContext: class OfflineInputsContext:
prompts: PromptType | Sequence[PromptType] | ScoringData prompts: PromptType | Sequence[PromptType] | DataPrompt | ScoringData
pooling_params: PoolingParams | list[PoolingParams] | None = None pooling_params: PoolingParams | Sequence[PoolingParams]
tokenization_kwargs: dict[str, Any] | None = None tokenization_kwargs: dict[str, Any] | None = None
chat_template: str | None = None chat_template: str | None = None
......
...@@ -14,7 +14,7 @@ from vllm.config import ModelConfig ...@@ -14,7 +14,7 @@ from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.utils import enable_scoring_api from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.entrypoints.serve.instrumentator.basic import base from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health from vllm.entrypoints.serve.instrumentator.health import health
...@@ -23,7 +23,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask ...@@ -23,7 +23,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13) # (requires typing_extensions >= 4.13)
RequestType = Any RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None] GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServingBase | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
......
...@@ -65,7 +65,6 @@ class OpenAIServingRender: ...@@ -65,7 +65,6 @@ class OpenAIServingRender:
self, self,
model_config: ModelConfig, model_config: ModelConfig,
renderer: BaseRenderer, renderer: BaseRenderer,
io_processor: Any,
model_registry: OpenAIModelRegistry, model_registry: OpenAIModelRegistry,
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
...@@ -81,7 +80,6 @@ class OpenAIServingRender: ...@@ -81,7 +80,6 @@ class OpenAIServingRender:
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.renderer = renderer self.renderer = renderer
self.io_processor = io_processor
self.model_registry = model_registry self.model_registry = model_registry
self.request_logger = request_logger self.request_logger = request_logger
self.chat_template = chat_template self.chat_template = chat_template
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment