Unverified Commit c0ecaed9 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Frontend] Offload blocking preprocessing & postprocessing ops to thread pool...


[Frontend] Offload blocking preprocessing & postprocessing ops to thread pool for pooling entrypoints. (#39763)
Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: default avatarwang.yuqi <noooop@126.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 0008729a
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import pytest import pytest
from vllm import PoolingParams
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedContent, CohereEmbedContent,
...@@ -218,6 +219,7 @@ class TestPreProcessCohereOnline: ...@@ -218,6 +219,7 @@ class TestPreProcessCohereOnline:
def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]: def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]:
return PoolingServeContext( return PoolingServeContext(
request=CohereEmbedRequest(model="test", **request_kwargs), request=CohereEmbedRequest(model="test", **request_kwargs),
pooling_params=PoolingParams(),
model_name="test", model_name="test",
request_id="embd-test", request_id="embd-test",
) )
...@@ -233,13 +235,13 @@ class TestPreProcessCohereOnline: ...@@ -233,13 +235,13 @@ class TestPreProcessCohereOnline:
ctx = self._make_context(texts=["hello"]) ctx = self._make_context(texts=["hello"])
calls: list[tuple[str, object]] = [] calls: list[tuple[str, object]] = []
def preprocess_completion(request, prompt_input, prompt_embeds): def preprocess_cmpl_online(request, prompt_input, prompt_embeds):
calls.append(("completion", prompt_input)) calls.append(("completion", prompt_input))
return ["completion"] return ["completion"]
handler._get_task_instruction_prefix = lambda _input_type: None handler._get_task_instruction_prefix = lambda _input_type: None
handler._has_chat_template = lambda: False handler._has_chat_template = lambda: False
handler._preprocess_completion_online = preprocess_completion handler._preprocess_cmpl_online = preprocess_cmpl_online
handler._batch_render_chat = lambda *_args, **_kwargs: ( handler._batch_render_chat = lambda *_args, **_kwargs: (
pytest.fail("text-only request should not require chat rendering") pytest.fail("text-only request should not require chat rendering")
) )
...@@ -254,7 +256,7 @@ class TestPreProcessCohereOnline: ...@@ -254,7 +256,7 @@ class TestPreProcessCohereOnline:
ctx = self._make_context(texts=["hello"], input_type="query") ctx = self._make_context(texts=["hello"], input_type="query")
calls: list[tuple[str, object]] = [] calls: list[tuple[str, object]] = []
def preprocess_completion(request, prompt_input, prompt_embeds): def preprocess_cmpl(request, prompt_input, prompt_embeds):
calls.append(("completion", prompt_input)) calls.append(("completion", prompt_input))
return ["fallback"] return ["fallback"]
...@@ -263,7 +265,7 @@ class TestPreProcessCohereOnline: ...@@ -263,7 +265,7 @@ class TestPreProcessCohereOnline:
handler._batch_render_chat = lambda *_args, **_kwargs: ( handler._batch_render_chat = lambda *_args, **_kwargs: (
pytest.fail("chat rendering should be skipped without a template") pytest.fail("chat rendering should be skipped without a template")
) )
handler._preprocess_completion_online = preprocess_completion handler._preprocess_cmpl_online = preprocess_cmpl
handler._pre_process_cohere_online(ctx) handler._pre_process_cohere_online(ctx)
...@@ -297,7 +299,7 @@ class TestPreProcessCohereOnline: ...@@ -297,7 +299,7 @@ class TestPreProcessCohereOnline:
handler._get_task_instruction_prefix = lambda _input_type: "query: " handler._get_task_instruction_prefix = lambda _input_type: "query: "
handler._has_chat_template = lambda: True handler._has_chat_template = lambda: True
handler._batch_render_chat = batch_render_chat handler._batch_render_chat = batch_render_chat
handler._preprocess_completion_online = lambda *_args, **_kwargs: ( handler._preprocess_cmpl_online = lambda *_args, **_kwargs: (
pytest.fail("completion path should be skipped when a template exists") pytest.fail("completion path should be skipped when a template exists")
) )
......
...@@ -72,7 +72,7 @@ class PoolingIOProcessor: ...@@ -72,7 +72,7 @@ class PoolingIOProcessor:
default_template_kwargs=None, default_template_kwargs=None,
) )
elif isinstance(request, PoolingCompletionLikeRequest): elif isinstance(request, PoolingCompletionLikeRequest):
engine_inputs = self._preprocess_completion_online( engine_inputs = self._preprocess_cmpl_online(
request, request,
prompt_input=request.input, prompt_input=request.input,
prompt_embeds=None, prompt_embeds=None,
...@@ -82,21 +82,12 @@ class PoolingIOProcessor: ...@@ -82,21 +82,12 @@ class PoolingIOProcessor:
ctx.engine_inputs = engine_inputs ctx.engine_inputs = engine_inputs
async def pre_process_online_async(self, ctx: PoolingServeContext):
self.pre_process_online(ctx)
def post_process_online( def post_process_online(
self, self,
ctx: PoolingServeContext, ctx: PoolingServeContext,
): ):
pass pass
async def post_process_online_async(
self,
ctx: PoolingServeContext,
):
self.post_process_online(ctx)
####################################### #######################################
# offline APIs # offline APIs
...@@ -109,12 +100,7 @@ class PoolingIOProcessor: ...@@ -109,12 +100,7 @@ class PoolingIOProcessor:
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs( tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {}) **(ctx.tokenization_kwargs or {})
) )
return self._preprocess_completion_offline( return self._preprocess_cmpl_offline(prompts=prompts_seq, tok_params=tok_params)
prompts=prompts_seq, tok_params=tok_params
)
async def pre_process_offline_async(self, ctx: OfflineInputsContext):
return self.pre_process_offline(ctx)
def post_process_offline( def post_process_offline(
self, self,
...@@ -122,16 +108,10 @@ class PoolingIOProcessor: ...@@ -122,16 +108,10 @@ class PoolingIOProcessor:
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
return ctx.outputs return ctx.outputs
async def post_process_offline_async(
self,
ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]:
return self.post_process_offline(ctx)
####################################### #######################################
# helpers # helpers
def _preprocess_completion_online( def _preprocess_cmpl_online(
self, self,
request: RendererRequest, request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
...@@ -209,7 +189,7 @@ class PoolingIOProcessor: ...@@ -209,7 +189,7 @@ class PoolingIOProcessor:
return conversation, [engine_input] return conversation, [engine_input]
def _preprocess_completion_offline( def _preprocess_cmpl_offline(
self, self,
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
tok_params: TokenizeParams, tok_params: TokenizeParams,
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Mapping
from concurrent.futures import Executor
from http import HTTPStatus from http import HTTPStatus
from typing import ClassVar from typing import ClassVar
import torch
from fastapi import Request from fastapi import Request
from fastapi.responses import Response from fastapi.responses import Response
from starlette.datastructures import Headers from starlette.datastructures import Headers
...@@ -32,7 +34,7 @@ from vllm.tracing import ( ...@@ -32,7 +34,7 @@ from vllm.tracing import (
log_tracing_disabled_warning, log_tracing_disabled_warning,
) )
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import make_async, merge_async_iterators
from .io_processor import PoolingIOProcessor from .io_processor import PoolingIOProcessor
...@@ -67,16 +69,47 @@ class PoolingServingBase(ABC): ...@@ -67,16 +69,47 @@ class PoolingServingBase(ABC):
trust_request_chat_template=trust_request_chat_template, trust_request_chat_template=trust_request_chat_template,
) )
@abstractmethod # Shared thread pool executor for preprocessing and postprocessing.
self._executor: Executor = models.renderer._executor
self._preprocessing_async = make_async(
self._preprocessing, executor=self._executor
)
self._postprocessing_async = make_async(
self._postprocessing, executor=self._executor
)
async def __call__( async def __call__(
self, self,
request: AnyPoolingRequest, request: AnyPoolingRequest,
raw_request: Request | None = None, raw_request: Request | None = None,
) -> Response: ) -> Response:
io_processor = self.get_io_processor(request)
ctx = await self._init_ctx(io_processor, request, raw_request)
await self._preprocessing_async(io_processor, ctx)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
return await self._postprocessing_async(io_processor, ctx)
@abstractmethod
def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
raise NotImplementedError raise NotImplementedError
@torch.inference_mode()
def _preprocessing(
self, io_processor: PoolingIOProcessor, ctx: PoolingServeContext
):
return io_processor.pre_process_online(ctx)
@torch.inference_mode()
def _postprocessing(
self, io_processor: PoolingIOProcessor, ctx: PoolingServeContext
):
io_processor.post_process_online(ctx)
return self._build_response(ctx)
async def _init_ctx( async def _init_ctx(
self, self,
io_processor: PoolingIOProcessor,
request: AnyPoolingRequest, request: AnyPoolingRequest,
raw_request: Request | None = None, raw_request: Request | None = None,
): ):
...@@ -84,10 +117,12 @@ class PoolingServingBase(ABC): ...@@ -84,10 +117,12 @@ class PoolingServingBase(ABC):
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
await self._check_model(request) await self._check_model(request)
pooling_params = io_processor.create_pooling_params(request)
ctx = PoolingServeContext( ctx = PoolingServeContext(
request=request, request=request,
raw_request=raw_request, raw_request=raw_request,
model_name=model_name, model_name=model_name,
pooling_params=pooling_params,
request_id=request_id, request_id=request_id,
) )
...@@ -175,7 +210,7 @@ class PoolingServingBase(ABC): ...@@ -175,7 +210,7 @@ class PoolingServingBase(ABC):
ctx.final_res_batch = [res for res in final_res_batch if res is not None] ctx.final_res_batch = [res for res in final_res_batch if res is not None]
@abstractmethod @abstractmethod
async def _build_response( def _build_response(
self, self,
ctx: PoolingServeContext, ctx: PoolingServeContext,
) -> Response: ) -> Response:
...@@ -362,18 +397,5 @@ class PoolingServing(PoolingServingBase, ABC): ...@@ -362,18 +397,5 @@ class PoolingServing(PoolingServingBase, ABC):
) -> PoolingIOProcessor: ) -> PoolingIOProcessor:
raise NotImplementedError raise NotImplementedError
async def __call__( def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
self, return self.io_processor
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> Response:
ctx = await self._init_ctx(request, raw_request)
await self.io_processor.pre_process_online_async(ctx)
if ctx.pooling_params is None:
ctx.pooling_params = self.io_processor.create_pooling_params(request)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await self.io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
...@@ -31,7 +31,7 @@ class ServingClassification(PoolingServing): ...@@ -31,7 +31,7 @@ class ServingClassification(PoolingServing):
def init_io_processor(self, *args, **kwargs) -> ClassifyIOProcessor: def init_io_processor(self, *args, **kwargs) -> ClassifyIOProcessor:
return ClassifyIOProcessor(*args, **kwargs) return ClassifyIOProcessor(*args, **kwargs)
async def _build_response( def _build_response(
self, self,
ctx: ClassificationServeContext, ctx: ClassificationServeContext,
) -> JSONResponse: ) -> JSONResponse:
......
...@@ -470,7 +470,7 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -470,7 +470,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side, truncation_side=truncation_side,
) )
return self._preprocess_completion_online( return self._preprocess_cmpl_online(
proxy, prompt_input=proxy.input, prompt_embeds=None proxy, prompt_input=proxy.input, prompt_embeds=None
) )
...@@ -579,7 +579,7 @@ class JinaRankingTokenEmbedIOProcessor( ...@@ -579,7 +579,7 @@ class JinaRankingTokenEmbedIOProcessor(
query=text_prompts[-1], docs=text_prompts[:-1] query=text_prompts[-1], docs=text_prompts[:-1]
) )
engine_inputs = self._preprocess_completion_online( engine_inputs = self._preprocess_cmpl_online(
request, request,
prompt_input=prompt_input, prompt_input=prompt_input,
prompt_embeds=None, prompt_embeds=None,
......
...@@ -53,15 +53,15 @@ class ServingEmbedding(PoolingServing): ...@@ -53,15 +53,15 @@ class ServingEmbedding(PoolingServing):
def init_io_processor(self, *args, **kwargs) -> EmbedIOProcessor: def init_io_processor(self, *args, **kwargs) -> EmbedIOProcessor:
return EmbedIOProcessor(*args, **kwargs) return EmbedIOProcessor(*args, **kwargs)
async def _build_response( def _build_response(
self, self,
ctx: PoolingServeContext, ctx: PoolingServeContext,
) -> Response: ) -> Response:
if isinstance(ctx.request, CohereEmbedRequest): if isinstance(ctx.request, CohereEmbedRequest):
return self._build_cohere_response_from_ctx(ctx) return self._build_cohere_response_from_ctx(ctx)
return await self._build_openai_response(ctx) return self._build_openai_response(ctx)
async def _build_openai_response( def _build_openai_response(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
) -> JSONResponse | StreamingResponse: ) -> JSONResponse | StreamingResponse:
......
...@@ -7,11 +7,11 @@ from collections.abc import Callable ...@@ -7,11 +7,11 @@ from collections.abc import Callable
from functools import partial from functools import partial
from typing import Literal, cast from typing import Literal, cast
from fastapi import Request
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServingBase from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import (
...@@ -57,27 +57,10 @@ class ServingPooling(PoolingServingBase): ...@@ -57,27 +57,10 @@ class ServingPooling(PoolingServingBase):
) )
self.json_response_cls = get_json_response_cls() self.json_response_cls = get_json_response_cls()
async def __call__( def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
self,
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> Response:
assert isinstance(request, PoolingRequest) assert isinstance(request, PoolingRequest)
pooling_task = self._verify_pooling_task(request) pooling_task = self._verify_pooling_task(request)
return self.io_processors[pooling_task]
io_processor = self.io_processors[pooling_task]
ctx = await self._init_ctx(request, raw_request)
await io_processor.pre_process_online_async(ctx)
if ctx.pooling_params is None:
ctx.pooling_params = io_processor.create_pooling_params(request)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
def _verify_pooling_task(self, request: PoolingRequest) -> str: def _verify_pooling_task(self, request: PoolingRequest) -> str:
if getattr(request, "dimensions", None) is not None: if getattr(request, "dimensions", None) is not None:
...@@ -117,7 +100,7 @@ class ServingPooling(PoolingServingBase): ...@@ -117,7 +100,7 @@ class ServingPooling(PoolingServingBase):
return pooling_task return pooling_task
async def _build_response( def _build_response(
self, self,
ctx: PoolingServeContext, ctx: PoolingServeContext,
) -> Response: ) -> Response:
......
...@@ -220,7 +220,7 @@ class BiEncoderIOProcessor(ScoringIOProcessor): ...@@ -220,7 +220,7 @@ class BiEncoderIOProcessor(ScoringIOProcessor):
scoring_data.data_2, "document", self.model_config scoring_data.data_2, "document", self.model_config
) )
return self._preprocess_completion_offline( return self._preprocess_cmpl_offline(
prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
) )
...@@ -682,7 +682,7 @@ class JinaRankingIOProcessor(LateInteractionIOProcessor, JinaRankingIOProcessorM ...@@ -682,7 +682,7 @@ class JinaRankingIOProcessor(LateInteractionIOProcessor, JinaRankingIOProcessorM
for q, d in zip(queries, docs) for q, d in zip(queries, docs)
] ]
return self._preprocess_completion_offline( return self._preprocess_cmpl_offline(
prompts=prompts, tok_params=tok_params, prompt_extras=prompt_extras prompts=prompts, tok_params=tok_params, prompt_extras=prompt_extras
) )
......
...@@ -65,7 +65,7 @@ class ServingScores(PoolingServing): ...@@ -65,7 +65,7 @@ class ServingScores(PoolingServing):
return await self.flash_late_interaction(*args, **kwargs) return await self.flash_late_interaction(*args, **kwargs)
async def _build_response( def _build_response(
self, self,
ctx: ScoringServeContext, ctx: ScoringServeContext,
) -> JSONResponse: ) -> JSONResponse:
...@@ -183,17 +183,15 @@ class ServingScores(PoolingServing): ...@@ -183,17 +183,15 @@ class ServingScores(PoolingServing):
### Can significantly improve late-interaction scoring performance. ### Can significantly improve late-interaction scoring performance.
async def flash_late_interaction(self, *args, **kwargs) -> Response: async def flash_late_interaction(self, *args, **kwargs) -> Response:
ctx = await self._init_ctx(*args, **kwargs) ctx = await self._init_ctx(self.io_processor, *args, **kwargs)
ctx.pooling_params = self.io_processor.create_pooling_params(ctx.request) await self._preprocessing_async(self.io_processor, ctx)
await self.io_processor.pre_process_online_async(ctx)
# stage 1: encode queries and cache token embeddings on workers. # stage 1: encode queries and cache token embeddings on workers.
await self._flash_late_interaction_encode_queries(ctx) await self._flash_late_interaction_encode_queries(ctx)
# stage 2: encode docs and return scalar scores from workers. # stage 2: encode docs and return scalar scores from workers.
await self._flash_late_interaction_encode_docs(ctx) await self._flash_late_interaction_encode_docs(ctx)
await self.io_processor.post_process_online_async(ctx) return await self._postprocessing_async(self.io_processor, ctx)
return await self._build_response(ctx)
async def _flash_late_interaction_encode_queries(self, ctx: ScoringServeContext): async def _flash_late_interaction_encode_queries(self, ctx: ScoringServeContext):
assert ctx.n_queries is not None assert ctx.n_queries is not None
......
...@@ -69,9 +69,9 @@ class PoolingServeContext(Generic[PoolingRequestT]): ...@@ -69,9 +69,9 @@ class PoolingServeContext(Generic[PoolingRequestT]):
raw_request: Request | None = None raw_request: Request | None = None
model_name: str model_name: str
request_id: str request_id: str
pooling_params: PoolingParams | list[PoolingParams]
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
pooling_params: PoolingParams | list[PoolingParams] | None = None
engine_inputs: Sequence[EngineInput] | None = None engine_inputs: Sequence[EngineInput] | None = None
prompt_request_ids: list[str] | None = None prompt_request_ids: list[str] | None = None
intermediates: Any | None = None intermediates: Any | None = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment