Unverified Commit f91b42b9 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(frontend): Reduce Python-side overhead in the vLLM chat path (#6437)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent fd839b8d
...@@ -82,7 +82,8 @@ class FrontendConfig(ConfigBase): ...@@ -82,7 +82,8 @@ class FrontendConfig(ConfigBase):
event_plane: str event_plane: str
chat_processor: str chat_processor: str
enable_anthropic_api: bool enable_anthropic_api: bool
exp_python_factory: bool debug_perf: bool
preprocess_workers: int
def validate(self) -> None: def validate(self) -> None:
if bool(self.tls_cert_path) ^ bool(self.tls_key_path): # ^ is XOR if bool(self.tls_cert_path) ^ bool(self.tls_key_path): # ^ is XOR
...@@ -515,9 +516,10 @@ class FrontendArgGroup(ArgGroup): ...@@ -515,9 +516,10 @@ class FrontendArgGroup(ArgGroup):
) )
add_argument( add_argument(
g, g,
flag_name="--chat-processor", flag_name="--dyn-chat-processor",
env_var="DYN_CHAT_PROCESSOR", env_var="DYN_CHAT_PROCESSOR",
default="dynamo", default="dynamo",
dest="chat_processor",
help=( help=(
"[EXPERIMENTAL] When set to 'vllm', use local vllm for the pre and post " "[EXPERIMENTAL] When set to 'vllm', use local vllm for the pre and post "
"processor." "processor."
...@@ -527,11 +529,28 @@ class FrontendArgGroup(ArgGroup): ...@@ -527,11 +529,28 @@ class FrontendArgGroup(ArgGroup):
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--exp-python-factory", flag_name="--dyn-debug-perf",
env_var="DYN_EXP_PYTHON_FACTORY", env_var="DYN_DEBUG_PERF",
default=False, default=False,
dest="debug_perf",
help=( help=(
"[EXPERIMENTAL] Enable Python-based engine factory. When set, engines will be " "[EXPERIMENTAL] Enable performance instrumentation for diagnosing preprocessing bottlenecks. "
"created via a Python callback instead of the default Rust pipeline." "Logs per-function timing, request concurrency, and hot-path section durations. "
"'--dyn-chat-processor vllm' only."
), ),
) )
add_argument(
g,
flag_name="--dyn-preprocess-workers",
env_var="DYN_PREPROCESS_WORKERS",
default=0,
dest="preprocess_workers",
help=(
"[EXPERIMENTAL] Number of worker processes for preprocessing and output processing. "
"When > 0, offloads CPU-bound work (tokenization, template rendering, "
"detokenization) to a ProcessPoolExecutor with N workers, each with its "
"own GIL. 0 (default) keeps all processing on the main event loop. '--dyn-chat-processor vllm' only."
),
arg_type=int,
)
...@@ -57,7 +57,7 @@ def setup_engine_factory( ...@@ -57,7 +57,7 @@ def setup_engine_factory(
""" """
from .vllm_processor import EngineFactory from .vllm_processor import EngineFactory
return EngineFactory(runtime, router_config, config, vllm_flags) return EngineFactory(runtime, router_config, config, vllm_flags, config.debug_perf)
def parse_args() -> tuple[FrontendConfig, Optional[Namespace]]: def parse_args() -> tuple[FrontendConfig, Optional[Namespace]]:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Performance instrumentation for diagnosing frontend preprocessing bottlenecks.
Activated by passing --dyn-debug-perf to dynamo.frontend.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Concurrency gauge
# ---------------------------------------------------------------------------
_active_requests = 0
_peak_requests = 0
def enter_generator() -> int:
"""Increment active request count. Returns current count.
Safe without a lock: only called while the GIL is held (all callers are
in Python code), so the read-modify-write on the global int is atomic
with respect to other Python threads.
"""
global _active_requests, _peak_requests
_active_requests += 1
count = _active_requests
if count > _peak_requests:
_peak_requests = count
return count
def exit_generator() -> int:
"""Decrement active request count. Returns current count."""
global _active_requests
_active_requests -= 1
return _active_requests
def get_active_requests() -> int:
return _active_requests
def get_peak_requests() -> int:
return _peak_requests
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import os
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
...@@ -13,6 +14,7 @@ from vllm.reasoning import ReasoningParser ...@@ -13,6 +14,7 @@ from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
@dataclass @dataclass
...@@ -24,6 +26,19 @@ class PreprocessResult: ...@@ -24,6 +26,19 @@ class PreprocessResult:
prompt_token_ids: list[int] prompt_token_ids: list[int]
_ASYNC_TOKENIZER_POOL: dict[int, AsyncMicrobatchTokenizer] = {}
SKIP_REQUEST_VALIDATION = os.getenv("DYN_VLLM_SKIP_REQUEST_VALIDATION", "1") == "1"
def _get_async_tokenizer(tokenizer: TokenizerLike) -> AsyncMicrobatchTokenizer:
key = id(tokenizer)
async_tokenizer = _ASYNC_TOKENIZER_POOL.get(key)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
_ASYNC_TOKENIZER_POOL[key] = async_tokenizer
return async_tokenizer
def _materialize_assistant_tool_calls( def _materialize_assistant_tool_calls(
messages: Sequence[Any], messages: Sequence[Any],
) -> list[dict[str, Any] | Any]: ) -> list[dict[str, Any] | Any]:
...@@ -53,13 +68,33 @@ def _materialize_assistant_tool_calls( ...@@ -53,13 +68,33 @@ def _materialize_assistant_tool_calls(
return normalized return normalized
async def preprocess_chat_request( def _prepare_request(
request: dict[str, Any], request: dict[str, Any] | ChatCompletionRequest,
*, *,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
renderer,
tool_parser_class: type[ToolParser] | None, tool_parser_class: type[ToolParser] | None,
) -> PreprocessResult: ) -> tuple[
ChatCompletionRequest, ToolParser | None, dict[str, Any], Any, dict[str, Any]
]:
"""Validate request and build arguments for template rendering.
Returns:
request_for_sampling: Validated ChatCompletionRequest.
tool_parser: Instantiated tool parser, or None.
chat_template_kwargs: Template kwargs (for PreprocessResult).
messages_for_render: Messages to pass as first arg to render_messages.
render_kwargs: Keyword arguments for render_messages / render_messages_async.
"""
if isinstance(request, ChatCompletionRequest):
request_for_sampling = request
elif SKIP_REQUEST_VALIDATION:
# Trusted fast path; caller must provide OpenAI-compatible payload.
request_for_sampling = ChatCompletionRequest.model_construct(**request)
if request_for_sampling.tools and any(
not hasattr(tool, "model_dump") for tool in request_for_sampling.tools
):
request_for_sampling = ChatCompletionRequest.model_validate(request)
else:
request_for_sampling = ChatCompletionRequest.model_validate(request) request_for_sampling = ChatCompletionRequest.model_validate(request)
tool_parser: ToolParser | None = None tool_parser: ToolParser | None = None
...@@ -88,8 +123,7 @@ async def preprocess_chat_request( ...@@ -88,8 +123,7 @@ async def preprocess_chat_request(
else request_for_sampling.messages else request_for_sampling.messages
) )
_, engine_prompt = await renderer.render_messages_async( render_kwargs = dict(
messages_for_render,
chat_template=request_for_sampling.chat_template, chat_template=request_for_sampling.chat_template,
chat_template_content_format="auto", chat_template_content_format="auto",
add_generation_prompt=request_for_sampling.add_generation_prompt, add_generation_prompt=request_for_sampling.add_generation_prompt,
...@@ -100,6 +134,73 @@ async def preprocess_chat_request( ...@@ -100,6 +134,73 @@ async def preprocess_chat_request(
**chat_template_kwargs, **chat_template_kwargs,
) )
return (
request_for_sampling,
tool_parser,
chat_template_kwargs,
messages_for_render,
render_kwargs,
)
async def preprocess_chat_request(
request: dict[str, Any] | ChatCompletionRequest,
*,
tokenizer: TokenizerLike,
renderer,
tool_parser_class: type[ToolParser] | None,
) -> PreprocessResult:
(
request_for_sampling,
tool_parser,
chat_template_kwargs,
messages,
render_kwargs,
) = _prepare_request(
request, tokenizer=tokenizer, tool_parser_class=tool_parser_class
)
_, engine_prompt = await renderer.render_messages_async(messages, **render_kwargs)
if "prompt_token_ids" in engine_prompt:
tokens = list(engine_prompt["prompt_token_ids"])
else:
async_tokenizer = _get_async_tokenizer(tokenizer)
encoded = await async_tokenizer(
engine_prompt["prompt"],
add_special_tokens=request_for_sampling.add_special_tokens,
)
tokens = list(encoded.input_ids)
return PreprocessResult(
request_for_sampling=request_for_sampling,
tool_parser=tool_parser,
chat_template_kwargs=chat_template_kwargs,
engine_prompt=engine_prompt,
prompt_token_ids=tokens,
)
def preprocess_chat_request_sync(
request: dict[str, Any] | ChatCompletionRequest,
*,
tokenizer: TokenizerLike,
renderer,
tool_parser_class: type[ToolParser] | None,
) -> PreprocessResult:
"""Sync version of preprocess_chat_request for worker processes."""
(
request_for_sampling,
tool_parser,
chat_template_kwargs,
messages,
render_kwargs,
) = _prepare_request(
request, tokenizer=tokenizer, tool_parser_class=tool_parser_class
)
_, engine_prompt = renderer.render_messages(messages, **render_kwargs)
if "prompt_token_ids" in engine_prompt: if "prompt_token_ids" in engine_prompt:
tokens = list(engine_prompt["prompt_token_ids"]) tokens = list(engine_prompt["prompt_token_ids"])
else: else:
...@@ -141,6 +242,9 @@ class StreamingPostProcessor: ...@@ -141,6 +242,9 @@ class StreamingPostProcessor:
if reasoning_parser_class if reasoning_parser_class
else None else None
) )
self._fast_plain_text = (
self.tool_parser is None and self.reasoning_parser is None
)
self._control_markers = tuple( self._control_markers = tuple(
t for t in getattr(tokenizer, "all_special_tokens", ()) if t t for t in getattr(tokenizer, "all_special_tokens", ()) if t
...@@ -191,6 +295,23 @@ class StreamingPostProcessor: ...@@ -191,6 +295,23 @@ class StreamingPostProcessor:
# to text. Re-detokenizing from token_ids can reintroduce stop markers. # to text. Re-detokenizing from token_ids can reintroduce stop markers.
delta_text = output.text or "" delta_text = output.text or ""
if self._fast_plain_text:
if delta_text:
delta: dict[str, Any] = {
"role": "assistant",
"content": delta_text,
}
elif output.finish_reason:
delta = {}
else:
return None
return {
"index": output.index,
"delta": delta,
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
current_text = self.previous_text + delta_text current_text = self.previous_text + delta_text
current_token_ids = self.previous_token_ids + delta_token_ids current_token_ids = self.previous_token_ids + delta_token_ids
......
...@@ -12,6 +12,9 @@ import time ...@@ -12,6 +12,9 @@ import time
import uuid import uuid
from argparse import Namespace from argparse import Namespace
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import wait as _futures_wait
from dataclasses import dataclass
from typing import Any from typing import Any
from vllm.config import CacheConfig, LoadConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, LoadConfig, ModelConfig, VllmConfig
...@@ -36,7 +39,11 @@ from dynamo.llm import ( ...@@ -36,7 +39,11 @@ from dynamo.llm import (
) )
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from .prepost import StreamingPostProcessor, preprocess_chat_request from .prepost import (
StreamingPostProcessor,
preprocess_chat_request,
preprocess_chat_request_sync,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -72,6 +79,209 @@ def map_finish_reason(raw_reason: str | None) -> FinishReason | None: ...@@ -72,6 +79,209 @@ def map_finish_reason(raw_reason: str | None) -> FinishReason | None:
return mapped return mapped
# --- Worker process globals (initialized once per process by _init_worker) ---
_w_input_processor: InputProcessor | None = None
_w_tokenizer: Any = None
_w_tool_parser_class: type[ToolParser] | None = None
_w_reasoning_parser_class: type[ReasoningParser] | None = None
_w_stream_interval: int = 20
class _PreprocessError(Exception):
"""Raised by _preprocess_worker for user-facing errors (e.g., n!=1)."""
def __init__(self, error_dict: dict[str, Any]):
self.error_dict = error_dict
super().__init__(str(error_dict))
@dataclass
class PreprocessWorkerResult:
"""Picklable return value from the preprocess worker."""
dynamo_preproc: dict[str, Any]
tokens: list[int]
vllm_preproc: EngineCoreRequest
sampling_params: SamplingParams
request_for_sampling: Any # ChatCompletionRequest (Pydantic model, picklable)
chat_template_kwargs: dict[str, Any]
def _init_worker(
model_path: str,
tokenizer_mode: str,
config_format: str,
load_format: str,
tool_parser_name: str | None,
reasoning_parser_name: str | None,
stream_interval: int,
) -> None:
"""Initialize a worker process with its own VllmConfig and InputProcessor."""
global _w_input_processor, _w_tokenizer, _w_tool_parser_class
global _w_reasoning_parser_class, _w_stream_interval
model_config = ModelConfig(
model=model_path,
tokenizer_mode=tokenizer_mode,
config_format=config_format,
)
vllm_config = VllmConfig(
model_config=model_config,
load_config=LoadConfig(load_format=load_format),
cache_config=CacheConfig(),
)
_w_input_processor = InputProcessor(vllm_config)
_w_tokenizer = _w_input_processor.get_tokenizer()
if tool_parser_name:
_w_tool_parser_class = ToolParserManager.get_tool_parser(tool_parser_name)
else:
_w_tool_parser_class = None
if reasoning_parser_name:
_w_reasoning_parser_class = ReasoningParserManager.get_reasoning_parser(
reasoning_parser_name
)
else:
_w_reasoning_parser_class = None
_w_stream_interval = max(1, stream_interval)
def _worker_warmup() -> bool:
"""Dummy task to ensure worker process is fully initialized."""
return True
def _preprocess_worker(
request: dict[str, Any],
request_id: str,
model_name: str,
) -> PreprocessWorkerResult:
"""Preprocess a request in a worker process and return a picklable result.
This replaces _request_handler's Phase A. No queues — errors propagate
naturally via the Future.
"""
pre = preprocess_chat_request_sync(
request,
tokenizer=_w_tokenizer,
renderer=_w_input_processor.renderer,
tool_parser_class=_w_tool_parser_class,
)
request_for_sampling = pre.request_for_sampling
engine_prompt = pre.engine_prompt
tokens = pre.prompt_token_ids
if request_for_sampling.max_completion_tokens is not None:
max_tokens = request_for_sampling.max_completion_tokens
elif request_for_sampling.max_tokens is not None:
max_tokens = request_for_sampling.max_tokens
else:
max_tokens = None
sampling_params = SamplingParams(
output_kind=RequestOutputKind.DELTA,
max_tokens=max_tokens,
)
for k, v in _w_input_processor.generation_config_fields.items():
if hasattr(sampling_params, k):
setattr(sampling_params, k, v)
sampling_fields = (
set(getattr(SamplingParams, "__annotations__", ()))
& set(type(request_for_sampling).model_fields)
) - {"max_tokens", "logprobs", "output_kind"}
for k in sorted(sampling_fields):
v = getattr(request_for_sampling, k, None)
if v is not None:
setattr(sampling_params, k, v)
logprobs = request_for_sampling.logprobs
top_logprobs = request_for_sampling.top_logprobs
if logprobs is True:
sampling_params.logprobs = top_logprobs or 1
elif isinstance(logprobs, int) and not isinstance(logprobs, bool):
sampling_params.logprobs = logprobs
elif top_logprobs not in (None, 0):
sampling_params.logprobs = top_logprobs
prompt_inputs = TokensPrompt(prompt_token_ids=tokens)
if "multi_modal_data" in engine_prompt:
prompt_inputs["multi_modal_data"] = engine_prompt["multi_modal_data"]
if "multi_modal_uuids" in engine_prompt:
prompt_inputs["multi_modal_uuids"] = engine_prompt["multi_modal_uuids"]
if request_for_sampling.cache_salt is not None:
prompt_inputs["cache_salt"] = request_for_sampling.cache_salt
if request_for_sampling.mm_processor_kwargs is not None:
prompt_inputs["mm_processor_kwargs"] = request_for_sampling.mm_processor_kwargs
vllm_preproc: EngineCoreRequest = _w_input_processor.process_inputs(
request_id,
prompt_inputs,
sampling_params,
)
InputProcessor.assign_request_id(vllm_preproc)
sp = vllm_preproc.sampling_params
if sp.n != 1:
raise _PreprocessError(
{
"error": {
"message": (
f"Unsupported value: 'n={sp.n}'. "
"This endpoint currently supports only n=1."
),
"type": "invalid_request_error",
"param": "n",
"code": "unsupported_value",
}
}
)
dynamo_preproc = {
"model": model_name,
"token_ids": tokens,
"stop_conditions": {
"max_tokens": sp.max_tokens,
"stop": sp.stop,
"stop_token_ids": sp.stop_token_ids,
"min_tokens": sp.min_tokens,
"ignore_eos": sp.ignore_eos,
},
"sampling_options": {
"n": sp.n,
"presence_penalty": sp.presence_penalty,
"frequency_penalty": sp.frequency_penalty,
"repetition_penalty": sp.repetition_penalty,
"temperature": sp.temperature,
"top_p": sp.top_p,
"top_k": sp.top_k,
"min_p": sp.min_p,
"seed": sp.seed,
},
"output_options": {
"logprobs": sp.logprobs,
"prompt_logprobs": sp.prompt_logprobs,
"skip_special_tokens": sp.skip_special_tokens,
},
"eos_token_ids": [vllm_preproc.eos_token_id]
if vllm_preproc.eos_token_id is not None
else [],
"annotations": [],
}
return PreprocessWorkerResult(
dynamo_preproc=dynamo_preproc,
tokens=tokens,
vllm_preproc=vllm_preproc,
sampling_params=sampling_params,
request_for_sampling=request_for_sampling,
chat_template_kwargs=pre.chat_template_kwargs,
)
class VllmProcessor: class VllmProcessor:
def __init__( def __init__(
self, self,
...@@ -81,6 +291,9 @@ class VllmProcessor: ...@@ -81,6 +291,9 @@ class VllmProcessor:
output_processor: OutputProcessor, output_processor: OutputProcessor,
tool_parser_class: type[ToolParser] | None, tool_parser_class: type[ToolParser] | None,
reasoning_parser_class: type[ReasoningParser] | None, reasoning_parser_class: type[ReasoningParser] | None,
debug_perf: bool = False,
preprocess_pool: ProcessPoolExecutor | None = None,
preprocess_workers: int = 0,
): ):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.input_processor = input_processor self.input_processor = input_processor
...@@ -89,6 +302,16 @@ class VllmProcessor: ...@@ -89,6 +302,16 @@ class VllmProcessor:
self.output_processor = output_processor self.output_processor = output_processor
self.tool_parser_class = tool_parser_class self.tool_parser_class = tool_parser_class
self.reasoning_parser_class = reasoning_parser_class self.reasoning_parser_class = reasoning_parser_class
self.debug_perf = debug_perf
self.preprocess_pool = preprocess_pool
if preprocess_pool is not None:
# Allow a small buffer beyond the worker count so the pool's
# internal queue always has work ready when a worker finishes.
self._worker_semaphore: asyncio.Semaphore | None = asyncio.Semaphore(
preprocess_workers + 2
)
else:
self._worker_semaphore = None
# Ideally we would map NVCreateChatCompletionRequest into Python so it can be type checked, but # Ideally we would map NVCreateChatCompletionRequest into Python so it can be type checked, but
# it has a lot of fields. # it has a lot of fields.
...@@ -103,12 +326,55 @@ class VllmProcessor: ...@@ -103,12 +326,55 @@ class VllmProcessor:
# ** VllmProcessor.generator called: {'messages': [{'role': 'user', 'content': 'What is the capital of Tuvalu?'}], 'model': '/home/grahamk/llms/Qwen3-0.6B', 'max_completion_tokens': 1000, 'stream': False} # ** VllmProcessor.generator called: {'messages': [{'role': 'user', 'content': 'What is the capital of Tuvalu?'}], 'model': '/home/grahamk/llms/Qwen3-0.6B', 'max_completion_tokens': 1000, 'stream': False}
if self.debug_perf:
from .perf_instrumentation import enter_generator, exit_generator
active = enter_generator()
t_start = time.monotonic()
logger.info("[perf] generator enter: active_requests=%d", active)
try:
if self.preprocess_pool is None:
# Single process
async for item in self._generator_inner(request):
yield item
else:
# Multi process
async for item in self._generator_inner_pool(request):
yield item
finally:
if self.debug_perf:
active = exit_generator()
elapsed_ms = (time.monotonic() - t_start) * 1000.0
logger.info(
"[perf] generator exit: total=%.2fms active_requests=%d",
elapsed_ms,
active,
)
async def _generator_inner(
self, request: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
request_id = random_uuid()
if self.debug_perf:
t0 = time.monotonic()
pre = await preprocess_chat_request( pre = await preprocess_chat_request(
request, request,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
renderer=self.input_processor.renderer, renderer=self.input_processor.renderer,
tool_parser_class=self.tool_parser_class, tool_parser_class=self.tool_parser_class,
) )
if self.debug_perf:
t1 = time.monotonic()
logger.info(
"[perf] preprocess_chat_request: %.2fms (request=%s)",
(t1 - t0) * 1000.0,
request_id,
)
request_for_sampling = pre.request_for_sampling request_for_sampling = pre.request_for_sampling
tool_parser = pre.tool_parser tool_parser = pre.tool_parser
chat_template_kwargs = pre.chat_template_kwargs chat_template_kwargs = pre.chat_template_kwargs
...@@ -155,7 +421,6 @@ class VllmProcessor: ...@@ -155,7 +421,6 @@ class VllmProcessor:
"Logprobs requested but not supported in distributed inference mode" "Logprobs requested but not supported in distributed inference mode"
) )
request_id = random_uuid()
# This calls update_from_generation_config and update_from_tokenizer on SamplingParams # This calls update_from_generation_config and update_from_tokenizer on SamplingParams
prompt_inputs = TokensPrompt(prompt_token_ids=tokens) prompt_inputs = TokensPrompt(prompt_token_ids=tokens)
if "multi_modal_data" in engine_prompt: if "multi_modal_data" in engine_prompt:
...@@ -168,6 +433,10 @@ class VllmProcessor: ...@@ -168,6 +433,10 @@ class VllmProcessor:
prompt_inputs[ prompt_inputs[
"mm_processor_kwargs" "mm_processor_kwargs"
] = request_for_sampling.mm_processor_kwargs ] = request_for_sampling.mm_processor_kwargs
if self.debug_perf:
t2 = time.monotonic()
vllm_preproc: EngineCoreRequest = self.input_processor.process_inputs( vllm_preproc: EngineCoreRequest = self.input_processor.process_inputs(
request_id, request_id,
prompt_inputs, prompt_inputs,
...@@ -179,6 +448,16 @@ class VllmProcessor: ...@@ -179,6 +448,16 @@ class VllmProcessor:
# priority: int = 0, # priority: int = 0,
# data_parallel_rank: int | None = None, # data_parallel_rank: int | None = None,
) )
if self.debug_perf:
t3 = time.monotonic()
logger.info(
"[perf] input_processor.process_inputs: %.2fms (request=%s tokens=%d)",
(t3 - t2) * 1000.0,
request_id,
len(tokens),
)
InputProcessor.assign_request_id(vllm_preproc) InputProcessor.assign_request_id(vllm_preproc)
# Processed: EngineCoreRequest(request_id='a2b76a85cd65e151', prompt_token_ids=[3838, 374, 279, 6722, 315, 28649, 25510, 30], mm_features=None, sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[151643], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=16, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args=None), pooling_params=None, eos_token_id=151645, arrival_time=1769036937.9417946, lora_request=None, cache_salt=None, data_parallel_rank=None, prompt_embeds=None, client_index=0, current_wave=0, priority=0, trace_headers=None) # Processed: EngineCoreRequest(request_id='a2b76a85cd65e151', prompt_token_ids=[3838, 374, 279, 6722, 315, 28649, 25510, 30], mm_features=None, sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[151643], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=16, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args=None), pooling_params=None, eos_token_id=151645, arrival_time=1769036937.9417946, lora_request=None, cache_salt=None, data_parallel_rank=None, prompt_embeds=None, client_index=0, current_wave=0, priority=0, trace_headers=None)
...@@ -200,18 +479,9 @@ class VllmProcessor: ...@@ -200,18 +479,9 @@ class VllmProcessor:
} }
return return
prompt = None
self.output_processor.add_request(
vllm_preproc,
prompt,
# parent_req: ParentRequest | None = None,
# request_index: int = 0,
# queue: RequestOutputCollector | None = None,
)
dynamo_preproc = { dynamo_preproc = {
"model": request["model"], "model": request["model"],
"token_ids": tokens, "token_ids": tokens,
# protocols.common.StopConditions
"stop_conditions": { "stop_conditions": {
"max_tokens": sp.max_tokens, "max_tokens": sp.max_tokens,
"stop": sp.stop, "stop": sp.stop,
...@@ -219,8 +489,6 @@ class VllmProcessor: ...@@ -219,8 +489,6 @@ class VllmProcessor:
"min_tokens": sp.min_tokens, "min_tokens": sp.min_tokens,
"ignore_eos": sp.ignore_eos, "ignore_eos": sp.ignore_eos,
}, },
# protocols.common.SamplingOptions
# Is there a better way than typing it out like this?
"sampling_options": { "sampling_options": {
"n": sp.n, "n": sp.n,
"presence_penalty": sp.presence_penalty, "presence_penalty": sp.presence_penalty,
...@@ -232,7 +500,6 @@ class VllmProcessor: ...@@ -232,7 +500,6 @@ class VllmProcessor:
"min_p": sp.min_p, "min_p": sp.min_p,
"seed": sp.seed, "seed": sp.seed,
}, },
# protocols.common.OutputOptions
"output_options": { "output_options": {
"logprobs": sp.logprobs, "logprobs": sp.logprobs,
"prompt_logprobs": sp.prompt_logprobs, "prompt_logprobs": sp.prompt_logprobs,
...@@ -242,7 +509,6 @@ class VllmProcessor: ...@@ -242,7 +509,6 @@ class VllmProcessor:
if vllm_preproc.eos_token_id is not None if vllm_preproc.eos_token_id is not None
else [], else [],
"annotations": [], "annotations": [],
# "prompt_embeds": vllm_preproc.prompt_embeds,
} }
post = StreamingPostProcessor( post = StreamingPostProcessor(
...@@ -255,10 +521,33 @@ class VllmProcessor: ...@@ -255,10 +521,33 @@ class VllmProcessor:
chat_template_kwargs=chat_template_kwargs, chat_template_kwargs=chat_template_kwargs,
) )
# dynamo_response: Annotated async for item in self._generate_and_stream(
request_id,
request,
dynamo_preproc,
tokens,
vllm_preproc,
post,
):
yield item
async def _generate_and_stream(
self,
request_id: str,
request: dict[str, Any],
dynamo_preproc: dict[str, Any],
tokens: list[int],
vllm_preproc: EngineCoreRequest,
post: StreamingPostProcessor,
) -> AsyncGenerator[dict[str, Any], None]:
"""Shared streaming logic for both single-process and pool paths."""
self.output_processor.add_request(vllm_preproc, None)
token_count = 0
output_proc_total_ms = 0.0
post_proc_total_ms = 0.0
try: try:
# Dynamo Router. This goes to the backend, waits, gets the streaming response, returns it.
# Stream is AsyncResponseStream
if self.is_kv_router: if self.is_kv_router:
dynamo_stream = await self.router.generate( dynamo_stream = await self.router.generate(
token_ids=tokens, token_ids=tokens,
...@@ -268,37 +557,25 @@ class VllmProcessor: ...@@ -268,37 +557,25 @@ class VllmProcessor:
output_options=dynamo_preproc["output_options"], output_options=dynamo_preproc["output_options"],
) )
else: else:
# Round robin or random, depending on cmd line flag dynamo_stream = await self.router.generate(
dynamo_stream = await self.router.generate(dynamo_preproc) dynamo_preproc, annotated=False
)
async for dynamo_response in dynamo_stream: async for dynamo_response in dynamo_stream:
# dynamo_response looks like this for regular router:
# Stream got: Annotated(data={'token_ids': [7281]}, event=None, comment=[], id=None)
# For KV router is is only the inner map: {'token_ids': [7281]}
if self.is_kv_router: if self.is_kv_router:
engine_response = dynamo_response engine_response = dynamo_response
else: elif hasattr(dynamo_response, "data"):
engine_response = dynamo_response.data() engine_response = dynamo_response.data()
else:
# engine_response: engine_response = dynamo_response
# Normal: {'token_ids': [151658]}
# Last: {'token_ids': [151645], 'finish_reason': 'stop', 'completion_usage': {'prompt_tokens': 190, 'completion_tokens': 168, 'total_tokens': 358, 'prompt_tokens_details': {'cached_tokens': 176}}}
if engine_response is None or "token_ids" not in engine_response: if engine_response is None or "token_ids" not in engine_response:
logger.error("No outputs from engine for request %s", request_id) logger.error("No outputs from engine for request %s", request_id)
yield { yield {
"id": request_id, "error": {
"choices": [ "message": f"Invalid engine response for request {request_id}",
{ "type": "internal_error",
"index": 0,
"delta": {},
"finish_reason": "error",
} }
],
"created": int(time.time()),
"model": request["model"],
"object": "chat.completion.chunk",
} }
break break
...@@ -311,28 +588,22 @@ class VllmProcessor: ...@@ -311,28 +588,22 @@ class VllmProcessor:
new_token_ids=engine_response["token_ids"], new_token_ids=engine_response["token_ids"],
finish_reason=finish_reason, finish_reason=finish_reason,
stop_reason=stop_reason, stop_reason=stop_reason,
# new_logprobs=new_logprobs,
# new_prompt_logprobs_tensors=prompt_logprobs_tensors,
# pooling_output=pooler_output,
# events=request.take_events(),
# kv_transfer_params=kv_transfer_params,
# trace_headers=request.trace_headers,
# num_cached_tokens=request.num_cached_tokens,
# num_nans_in_logits=request.num_nans_in_logits,
) )
# Let vllm handle all post-processing if self.debug_perf:
t_op0 = time.monotonic()
vllm_out: OutputProcessorOutput = self.output_processor.process_outputs( vllm_out: OutputProcessorOutput = self.output_processor.process_outputs(
[vllm_response] [vllm_response]
) )
if self.debug_perf:
t_op1 = time.monotonic()
output_proc_total_ms += (t_op1 - t_op0) * 1000.0
if vllm_out.reqs_to_abort: if vllm_out.reqs_to_abort:
# Router has no abort API; we cannot propagate aborts.
pass pass
# vllm
# RequestOutput: OutputProcessorOutput(request_outputs=[RequestOutput(request_id=9dbe240d8de78db3, prompt='What is the capital of Tuvalu?', prompt_token_ids=[3838, 374, 279, 6722, 315, 28649, 25510, 30], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' The', token_ids=[576], cumulative_logprob=None, logprobs=None, finish_reason=None, stop_reason=None)], finished=False, metrics=RequestStateStats(num_generation_tokens=0, arrival_time=1769118902.2172132, queued_ts=0.0, scheduled_ts=0.0, first_token_ts=0.0, last_token_ts=0.0, first_token_latency=0.0, is_corrupted=False), lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})], reqs_to_abort=[])
# Vec<ChatChoiceStream>
choices = [] choices = []
if not vllm_out.request_outputs: if not vllm_out.request_outputs:
continue continue
...@@ -341,8 +612,12 @@ class VllmProcessor: ...@@ -341,8 +612,12 @@ class VllmProcessor:
if choice: if choice:
choices.append(choice) choices.append(choice)
if self.debug_perf:
t_op2 = time.monotonic()
post_proc_total_ms += (t_op2 - t_op1) * 1000.0
token_count += len(engine_response["token_ids"])
if choices: if choices:
# dynamo_out: NvCreateChatCompletionStreamResponse
dynamo_out = { dynamo_out = {
"id": request_id, "id": request_id,
"choices": choices, "choices": choices,
...@@ -351,16 +626,95 @@ class VllmProcessor: ...@@ -351,16 +626,95 @@ class VllmProcessor:
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
} }
if usage := engine_response.get("completion_usage"): if usage := engine_response.get("completion_usage"):
# The engine only includes this on the last response
dynamo_out["usage"] = usage dynamo_out["usage"] = usage
# Rust handles HTTP / Server Sent Events back to user
yield dynamo_out yield dynamo_out
finally: finally:
if vllm_preproc.request_id in self.output_processor.request_states: if vllm_preproc.request_id in self.output_processor.request_states:
self.output_processor.abort_requests( self.output_processor.abort_requests(
[vllm_preproc.request_id], internal=True [vllm_preproc.request_id], internal=True
) )
if self.debug_perf and token_count > 0:
logger.info(
"[perf] stream done: request=%s tokens=%d "
"output_processor_total=%.2fms (%.3fms/tok) "
"post_processor_total=%.2fms (%.3fms/tok)",
request_id,
token_count,
output_proc_total_ms,
output_proc_total_ms / token_count,
post_proc_total_ms,
post_proc_total_ms / token_count,
)
async def _generator_inner_pool(
self, request: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
"""Process a request using the worker pool.
Phase 1: Preprocess in a worker process (semaphore held).
Phase 2: Remote inference via router (no worker held).
Phase 3: Post-process tokens in the main process.
"""
request_id = random_uuid()
# --- Phase 1: Preprocess (semaphore held) ---
try:
async with self._worker_semaphore:
future = self.preprocess_pool.submit(
_preprocess_worker, request, request_id, request["model"]
)
preproc_result: PreprocessWorkerResult = await asyncio.wrap_future(
future
)
# Semaphore + worker released here
except _PreprocessError as exc:
yield exc.error_dict
return
except Exception as exc:
logger.exception("Worker preprocessing failed for request %s", request_id)
yield {
"error": {
"message": f"Worker error: {exc}",
"type": "internal_error",
}
}
return
# --- Between phases: reconstruct main-process objects ---
dynamo_preproc = preproc_result.dynamo_preproc
tokens = preproc_result.tokens
vllm_preproc = preproc_result.vllm_preproc
sampling_params = preproc_result.sampling_params
request_for_sampling = preproc_result.request_for_sampling
tool_parser = None
if (
self.tool_parser_class
and request_for_sampling.tools
and request_for_sampling.tool_choice != "none"
):
tool_parser = self.tool_parser_class(self.tokenizer)
post = StreamingPostProcessor(
tokenizer=self.tokenizer,
request_for_sampling=request_for_sampling,
sampling_params=sampling_params,
prompt_token_ids=tokens,
tool_parser=tool_parser,
reasoning_parser_class=self.reasoning_parser_class,
chat_template_kwargs=preproc_result.chat_template_kwargs,
)
async for item in self._generate_and_stream(
request_id,
request,
dynamo_preproc,
tokens,
vllm_preproc,
post,
):
yield item
class EngineFactory: class EngineFactory:
...@@ -370,11 +724,24 @@ class EngineFactory: ...@@ -370,11 +724,24 @@ class EngineFactory:
router_config: RouterConfig, router_config: RouterConfig,
config: FrontendConfig, config: FrontendConfig,
flags: Namespace, flags: Namespace,
debug_perf: bool = False,
): ):
self.runtime = runtime self.runtime = runtime
self.router_config = router_config self.router_config = router_config
self.config = config self.config = config
self.flags = flags self.flags = flags
self.debug_perf = debug_perf
self.stream_interval = 20
raw_stream_interval = os.getenv("DYN_VLLM_STREAM_INTERVAL")
if raw_stream_interval:
try:
self.stream_interval = max(1, int(raw_stream_interval))
except ValueError:
logger.warning(
"Invalid DYN_VLLM_STREAM_INTERVAL=%r, using default=%d",
raw_stream_interval,
self.stream_interval,
)
async def chat_engine_factory( async def chat_engine_factory(
self, self,
...@@ -416,8 +783,9 @@ class EngineFactory: ...@@ -416,8 +783,9 @@ class EngineFactory:
output_processor = OutputProcessor( output_processor = OutputProcessor(
tokenizer, tokenizer,
log_stats=False, log_stats=False,
stream_interval=1, stream_interval=self.stream_interval,
) )
logger.info("vLLM OutputProcessor stream_interval=%d", self.stream_interval)
tool_parser_name = self.flags.tool_call_parser or mdc.runtime_config().get( tool_parser_name = self.flags.tool_call_parser or mdc.runtime_config().get(
"tool_call_parser" "tool_call_parser"
...@@ -453,6 +821,48 @@ class EngineFactory: ...@@ -453,6 +821,48 @@ class EngineFactory:
router_mode=self.router_config.router_mode router_mode=self.router_config.router_mode
) )
preprocess_pool = None
preprocess_workers = self.config.preprocess_workers
if preprocess_workers > 0:
logger.info(
"Creating preprocess worker pool with %d workers for model %s",
preprocess_workers,
source_path,
)
preprocess_pool = ProcessPoolExecutor(
max_workers=preprocess_workers,
initializer=_init_worker,
initargs=(
source_path,
tokenizer_mode,
config_format,
load_format,
tool_parser_name,
reasoning_parser_name,
self.stream_interval,
),
)
# Warm up all workers to ensure initialization completes
futures = [
preprocess_pool.submit(_worker_warmup)
for _ in range(preprocess_workers)
]
done, not_done = _futures_wait(futures, timeout=120)
if not_done:
for f in not_done:
f.cancel()
preprocess_pool.shutdown(wait=False, cancel_futures=True)
raise RuntimeError(
"Timed out waiting for preprocess worker pool warmup"
)
try:
for f in done:
f.result() # Raises if initializer failed
except Exception:
preprocess_pool.shutdown(wait=False, cancel_futures=True)
raise
logger.info("Preprocess worker pool ready (%d workers)", preprocess_workers)
gen = VllmProcessor( gen = VllmProcessor(
tokenizer, tokenizer,
input_processor, input_processor,
...@@ -460,6 +870,9 @@ class EngineFactory: ...@@ -460,6 +870,9 @@ class EngineFactory:
output_processor, output_processor,
tool_parser_class, tool_parser_class,
reasoning_parser_class, reasoning_parser_class,
debug_perf=self.debug_perf,
preprocess_pool=preprocess_pool,
preprocess_workers=preprocess_workers,
) )
return PythonAsyncEngine(gen.generator, loop) return PythonAsyncEngine(gen.generator, loop)
...@@ -159,7 +159,11 @@ async def launch_workers(args, extra_engine_args_path): ...@@ -159,7 +159,11 @@ async def launch_workers(args, extra_engine_args_path):
logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}") logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}")
# Create a separate DistributedRuntime for this worker (on same event loop) # Create a separate DistributedRuntime for this worker (on same event loop)
runtime = DistributedRuntime(loop, args.discovery_backend, args.request_plane) runtime = DistributedRuntime(
loop,
args.discovery_backend,
args.request_plane,
)
runtimes.append(runtime) runtimes.append(runtime)
# Determine which engine args file to use # Determine which engine args file to use
......
...@@ -559,3 +559,75 @@ class TestVllmRendererApi: ...@@ -559,3 +559,75 @@ class TestVllmRendererApi:
"ReasoningParser.is_reasoning_end_streaming signature changed; " "ReasoningParser.is_reasoning_end_streaming signature changed; "
f"expected ['self', 'input_ids', 'delta_ids'], got {end_params}" f"expected ['self', 'input_ids', 'delta_ids'], got {end_params}"
) )
def test_preprocess_worker_result_picklability(self):
"""Verify PreprocessWorkerResult survives pickle round-trip.
_preprocess_worker returns this dataclass via a ProcessPoolExecutor
Future. If any field becomes unpicklable, the pool path breaks.
"""
import pickle
from dynamo.frontend.vllm_processor import PreprocessWorkerResult
result = PreprocessWorkerResult(
dynamo_preproc={
"model": "test-model",
"token_ids": [1, 2, 3],
"stop_conditions": {
"max_tokens": 100,
"stop": [],
"stop_token_ids": [2],
"min_tokens": 0,
"ignore_eos": False,
},
"sampling_options": {
"n": 1,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": 0,
"min_p": 0.0,
"seed": None,
},
"output_options": {
"logprobs": None,
"prompt_logprobs": None,
"skip_special_tokens": True,
},
"eos_token_ids": [2],
"annotations": [],
},
tokens=[1, 2, 3],
vllm_preproc=EngineCoreRequest(
request_id="test-123",
prompt_token_ids=[1, 2, 3],
mm_features=None,
sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=2,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
prompt_embeds=None,
client_index=0,
current_wave=0,
priority=0,
trace_headers=None,
),
sampling_params=SamplingParams(),
request_for_sampling={"model": "test-model", "tools": None},
chat_template_kwargs={"reasoning_effort": None},
)
data = pickle.dumps(result)
restored = pickle.loads(data)
assert restored.dynamo_preproc == result.dynamo_preproc
assert restored.tokens == result.tokens
assert restored.vllm_preproc.request_id == "test-123"
assert restored.request_for_sampling == result.request_for_sampling
assert restored.chat_template_kwargs == result.chat_template_kwargs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment