Unverified commit c0ecaed9, authored by wang.yuqi and committed by GitHub
Browse files

[Frontend] Offload blocking preprocessing & postprocessing ops to thread pool...


[Frontend] Offload blocking preprocessing & postprocessing ops to thread pool for pooling entrypoints. (#39763)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 0008729a
......@@ -4,6 +4,7 @@
import pytest
from vllm import PoolingParams
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedContent,
......@@ -218,6 +219,7 @@ class TestPreProcessCohereOnline:
def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]:
    """Build a serve context around a Cohere embed request for these tests.

    ``request_kwargs`` are forwarded to ``CohereEmbedRequest`` together with
    the fixed ``model="test"``. The context is given default
    ``PoolingParams`` and fixed model name / request id values.
    """
    # NOTE(review): no decorator is visible in this view; callers invoke this
    # as ``self._make_context(...)``, so it is presumably a staticmethod.
    cohere_request = CohereEmbedRequest(model="test", **request_kwargs)
    return PoolingServeContext(
        request=cohere_request,
        pooling_params=PoolingParams(),
        model_name="test",
        request_id="embd-test",
    )
......@@ -233,13 +235,13 @@ class TestPreProcessCohereOnline:
ctx = self._make_context(texts=["hello"])
calls: list[tuple[str, object]] = []
def preprocess_completion(request, prompt_input, prompt_embeds):
def preprocess_cmpl_online(request, prompt_input, prompt_embeds):
calls.append(("completion", prompt_input))
return ["completion"]
handler._get_task_instruction_prefix = lambda _input_type: None
handler._has_chat_template = lambda: False
handler._preprocess_completion_online = preprocess_completion
handler._preprocess_cmpl_online = preprocess_cmpl_online
handler._batch_render_chat = lambda *_args, **_kwargs: (
pytest.fail("text-only request should not require chat rendering")
)
......@@ -254,7 +256,7 @@ class TestPreProcessCohereOnline:
ctx = self._make_context(texts=["hello"], input_type="query")
calls: list[tuple[str, object]] = []
def preprocess_completion(request, prompt_input, prompt_embeds):
def preprocess_cmpl(request, prompt_input, prompt_embeds):
calls.append(("completion", prompt_input))
return ["fallback"]
......@@ -263,7 +265,7 @@ class TestPreProcessCohereOnline:
handler._batch_render_chat = lambda *_args, **_kwargs: (
pytest.fail("chat rendering should be skipped without a template")
)
handler._preprocess_completion_online = preprocess_completion
handler._preprocess_cmpl_online = preprocess_cmpl
handler._pre_process_cohere_online(ctx)
......@@ -297,7 +299,7 @@ class TestPreProcessCohereOnline:
handler._get_task_instruction_prefix = lambda _input_type: "query: "
handler._has_chat_template = lambda: True
handler._batch_render_chat = batch_render_chat
handler._preprocess_completion_online = lambda *_args, **_kwargs: (
handler._preprocess_cmpl_online = lambda *_args, **_kwargs: (
pytest.fail("completion path should be skipped when a template exists")
)
......
......@@ -72,7 +72,7 @@ class PoolingIOProcessor:
default_template_kwargs=None,
)
elif isinstance(request, PoolingCompletionLikeRequest):
engine_inputs = self._preprocess_completion_online(
engine_inputs = self._preprocess_cmpl_online(
request,
prompt_input=request.input,
prompt_embeds=None,
......@@ -82,21 +82,12 @@ class PoolingIOProcessor:
ctx.engine_inputs = engine_inputs
async def pre_process_online_async(
    self,
    ctx: PoolingServeContext,
):
    """Coroutine form of ``pre_process_online``.

    The synchronous method is called directly inside the coroutine and its
    return value is discarded; results are expected on ``ctx``.
    """
    self.pre_process_online(ctx)
    return None
def post_process_online(self, ctx: PoolingServeContext):
    """Post-processing hook for online requests.

    The base implementation is a no-op and returns ``None``.
    """
    return None
async def post_process_online_async(self, ctx: PoolingServeContext):
    """Coroutine form of ``post_process_online``.

    The synchronous method is called directly inside the coroutine and its
    return value is discarded.
    """
    self.post_process_online(ctx)
    return None
#######################################
# offline APIs
......@@ -109,12 +100,7 @@ class PoolingIOProcessor:
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {})
)
return self._preprocess_completion_offline(
prompts=prompts_seq, tok_params=tok_params
)
async def pre_process_offline_async(
    self,
    ctx: OfflineInputsContext,
):
    """Coroutine form of ``pre_process_offline``; returns its result as-is."""
    engine_inputs = self.pre_process_offline(ctx)
    return engine_inputs
return self._preprocess_cmpl_offline(prompts=prompts_seq, tok_params=tok_params)
def post_process_offline(
self,
......@@ -122,16 +108,10 @@ class PoolingIOProcessor:
) -> list[PoolingRequestOutput]:
return ctx.outputs
async def post_process_offline_async(
    self, ctx: OfflineOutputsContext
) -> list[PoolingRequestOutput]:
    """Coroutine form of ``post_process_offline``; returns its result as-is."""
    outputs = self.post_process_offline(ctx)
    return outputs
#######################################
# helpers
def _preprocess_completion_online(
def _preprocess_cmpl_online(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
......@@ -209,7 +189,7 @@ class PoolingIOProcessor:
return conversation, [engine_input]
def _preprocess_completion_offline(
def _preprocess_cmpl_offline(
self,
prompts: PromptType | Sequence[PromptType],
tok_params: TokenizeParams,
......
......@@ -3,9 +3,11 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Mapping
from concurrent.futures import Executor
from http import HTTPStatus
from typing import ClassVar
import torch
from fastapi import Request
from fastapi.responses import Response
from starlette.datastructures import Headers
......@@ -32,7 +34,7 @@ from vllm.tracing import (
log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.async_utils import make_async, merge_async_iterators
from .io_processor import PoolingIOProcessor
......@@ -67,16 +69,47 @@ class PoolingServingBase(ABC):
trust_request_chat_template=trust_request_chat_template,
)
@abstractmethod
# Shared thread pool executor for preprocessing and postprocessing.
self._executor: Executor = models.renderer._executor
self._preprocessing_async = make_async(
self._preprocessing, executor=self._executor
)
self._postprocessing_async = make_async(
self._postprocessing, executor=self._executor
)
async def __call__(
    self, request: AnyPoolingRequest, raw_request: Request | None = None
) -> Response:
    """Serve a single pooling request and return the HTTP response.

    Steps, in order: select the IO processor for the request, build the
    serve context, preprocess, prepare the generators, collect the batch,
    then postprocess (which also builds the response).
    """
    processor = self.get_io_processor(request)
    serve_ctx = await self._init_ctx(processor, request, raw_request)

    # Pre- and post-processing go through the ``make_async`` wrappers that
    # were bound to ``self._executor`` at construction time.
    await self._preprocessing_async(processor, serve_ctx)
    await self._prepare_generators(serve_ctx)
    await self._collect_batch(serve_ctx)

    response = await self._postprocessing_async(processor, serve_ctx)
    return response
@abstractmethod
def get_io_processor(
    self,
    request: AnyPoolingRequest,
) -> PoolingIOProcessor:
    """Return the IO processor that should handle ``request``.

    Subclasses must override this; the base implementation always raises.

    Raises:
        NotImplementedError: always, when not overridden.
    """
    raise NotImplementedError
@torch.inference_mode()
def _preprocessing(
    self,
    io_processor: PoolingIOProcessor,
    ctx: PoolingServeContext,
):
    """Run the IO processor's online preprocessing under inference mode.

    Synchronous on purpose: ``__init__`` wraps it with ``make_async`` so it
    can be dispatched to the shared executor. Returns whatever
    ``pre_process_online`` returns.
    """
    result = io_processor.pre_process_online(ctx)
    return result
@torch.inference_mode()
def _postprocessing(
    self,
    io_processor: PoolingIOProcessor,
    ctx: PoolingServeContext,
):
    """Postprocess the collected outputs and build the response.

    Synchronous on purpose: ``__init__`` wraps it with ``make_async`` so it
    can be dispatched to the shared executor. The IO processor's hook runs
    first; the response is then built from the same context.
    """
    io_processor.post_process_online(ctx)
    response = self._build_response(ctx)
    return response
async def _init_ctx(
self,
io_processor: PoolingIOProcessor,
request: AnyPoolingRequest,
raw_request: Request | None = None,
):
......@@ -84,10 +117,12 @@ class PoolingServingBase(ABC):
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
await self._check_model(request)
pooling_params = io_processor.create_pooling_params(request)
ctx = PoolingServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
pooling_params=pooling_params,
request_id=request_id,
)
......@@ -175,7 +210,7 @@ class PoolingServingBase(ABC):
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
@abstractmethod
async def _build_response(
def _build_response(
self,
ctx: PoolingServeContext,
) -> Response:
......@@ -362,18 +397,5 @@ class PoolingServing(PoolingServingBase, ABC):
) -> PoolingIOProcessor:
raise NotImplementedError
async def __call__(
    self, request: AnyPoolingRequest, raw_request: Request | None = None
) -> Response:
    """Serve a single pooling request using ``self.io_processor``.

    Steps, in order: build the context, preprocess, fill in pooling params
    if missing, prepare the generators, collect the batch, postprocess, and
    build the response.
    """
    ctx = await self._init_ctx(request, raw_request)
    await self.io_processor.pre_process_online_async(ctx)

    # Derive pooling params from the request only when they are still unset
    # at this point.
    if ctx.pooling_params is None:
        ctx.pooling_params = self.io_processor.create_pooling_params(request)

    await self._prepare_generators(ctx)
    await self._collect_batch(ctx)
    await self.io_processor.post_process_online_async(ctx)

    response = await self._build_response(ctx)
    return response
def get_io_processor(
    self,
    request: AnyPoolingRequest,
) -> PoolingIOProcessor:
    """Return this server's single IO processor; ``request`` is not inspected."""
    processor = self.io_processor
    return processor
......@@ -31,7 +31,7 @@ class ServingClassification(PoolingServing):
def init_io_processor(
    self,
    *args,
    **kwargs,
) -> ClassifyIOProcessor:
    """Create the classification IO processor, forwarding all arguments."""
    processor = ClassifyIOProcessor(*args, **kwargs)
    return processor
async def _build_response(
def _build_response(
self,
ctx: ClassificationServeContext,
) -> JSONResponse:
......
......@@ -470,7 +470,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side,
)
return self._preprocess_completion_online(
return self._preprocess_cmpl_online(
proxy, prompt_input=proxy.input, prompt_embeds=None
)
......@@ -579,7 +579,7 @@ class JinaRankingTokenEmbedIOProcessor(
query=text_prompts[-1], docs=text_prompts[:-1]
)
engine_inputs = self._preprocess_completion_online(
engine_inputs = self._preprocess_cmpl_online(
request,
prompt_input=prompt_input,
prompt_embeds=None,
......
......@@ -53,15 +53,15 @@ class ServingEmbedding(PoolingServing):
def init_io_processor(
    self,
    *args,
    **kwargs,
) -> EmbedIOProcessor:
    """Create the embedding IO processor, forwarding all arguments."""
    processor = EmbedIOProcessor(*args, **kwargs)
    return processor
async def _build_response(
def _build_response(
self,
ctx: PoolingServeContext,
) -> Response:
if isinstance(ctx.request, CohereEmbedRequest):
return self._build_cohere_response_from_ctx(ctx)
return await self._build_openai_response(ctx)
return self._build_openai_response(ctx)
async def _build_openai_response(
def _build_openai_response(
self,
ctx: EmbeddingServeContext,
) -> JSONResponse | StreamingResponse:
......
......@@ -7,11 +7,11 @@ from collections.abc import Callable
from functools import partial
from typing import Literal, cast
from fastapi import Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.pooling.protocol import (
......@@ -57,27 +57,10 @@ class ServingPooling(PoolingServingBase):
)
self.json_response_cls = get_json_response_cls()
async def __call__(
self,
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> Response:
def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
assert isinstance(request, PoolingRequest)
pooling_task = self._verify_pooling_task(request)
io_processor = self.io_processors[pooling_task]
ctx = await self._init_ctx(request, raw_request)
await io_processor.pre_process_online_async(ctx)
if ctx.pooling_params is None:
ctx.pooling_params = io_processor.create_pooling_params(request)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
return self.io_processors[pooling_task]
def _verify_pooling_task(self, request: PoolingRequest) -> str:
if getattr(request, "dimensions", None) is not None:
......@@ -117,7 +100,7 @@ class ServingPooling(PoolingServingBase):
return pooling_task
async def _build_response(
def _build_response(
self,
ctx: PoolingServeContext,
) -> Response:
......
......@@ -220,7 +220,7 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
scoring_data.data_2, "document", self.model_config
)
return self._preprocess_completion_offline(
return self._preprocess_cmpl_offline(
prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
)
......@@ -682,7 +682,7 @@ class JinaRankingIOProcessor(LateInteractionIOProcessor, JinaRankingIOProcessorM
for q, d in zip(queries, docs)
]
return self._preprocess_completion_offline(
return self._preprocess_cmpl_offline(
prompts=prompts, tok_params=tok_params, prompt_extras=prompt_extras
)
......
......@@ -65,7 +65,7 @@ class ServingScores(PoolingServing):
return await self.flash_late_interaction(*args, **kwargs)
async def _build_response(
def _build_response(
self,
ctx: ScoringServeContext,
) -> JSONResponse:
......@@ -183,17 +183,15 @@ class ServingScores(PoolingServing):
### Can significantly improve late-interaction scoring performance.
async def flash_late_interaction(self, *args, **kwargs) -> Response:
ctx = await self._init_ctx(*args, **kwargs)
ctx.pooling_params = self.io_processor.create_pooling_params(ctx.request)
await self.io_processor.pre_process_online_async(ctx)
ctx = await self._init_ctx(self.io_processor, *args, **kwargs)
await self._preprocessing_async(self.io_processor, ctx)
# stage 1: encode queries and cache token embeddings on workers.
await self._flash_late_interaction_encode_queries(ctx)
# stage 2: encode docs and return scalar scores from workers.
await self._flash_late_interaction_encode_docs(ctx)
await self.io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
return await self._postprocessing_async(self.io_processor, ctx)
async def _flash_late_interaction_encode_queries(self, ctx: ScoringServeContext):
assert ctx.n_queries is not None
......
......@@ -69,9 +69,9 @@ class PoolingServeContext(Generic[PoolingRequestT]):
raw_request: Request | None = None
model_name: str
request_id: str
pooling_params: PoolingParams | list[PoolingParams]
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
pooling_params: PoolingParams | list[PoolingParams] | None = None
engine_inputs: Sequence[EngineInput] | None = None
prompt_request_ids: list[str] | None = None
intermediates: Any | None = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment