Unverified commit f91b42b9, authored by Graham King and committed by GitHub
Browse files

feat(frontend): Reduce Python-side overhead in the vLLM chat path (#6437)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent fd839b8d
......@@ -82,7 +82,8 @@ class FrontendConfig(ConfigBase):
event_plane: str
chat_processor: str
enable_anthropic_api: bool
exp_python_factory: bool
debug_perf: bool
preprocess_workers: int
def validate(self) -> None:
if bool(self.tls_cert_path) ^ bool(self.tls_key_path): # ^ is XOR
......@@ -515,9 +516,10 @@ class FrontendArgGroup(ArgGroup):
)
add_argument(
g,
flag_name="--chat-processor",
flag_name="--dyn-chat-processor",
env_var="DYN_CHAT_PROCESSOR",
default="dynamo",
dest="chat_processor",
help=(
"[EXPERIMENTAL] When set to 'vllm', use local vllm for the pre and post "
"processor."
......@@ -527,11 +529,28 @@ class FrontendArgGroup(ArgGroup):
add_negatable_bool_argument(
g,
flag_name="--exp-python-factory",
env_var="DYN_EXP_PYTHON_FACTORY",
flag_name="--dyn-debug-perf",
env_var="DYN_DEBUG_PERF",
default=False,
dest="debug_perf",
help=(
"[EXPERIMENTAL] Enable Python-based engine factory. When set, engines will be "
"created via a Python callback instead of the default Rust pipeline."
"[EXPERIMENTAL] Enable performance instrumentation for diagnosing preprocessing bottlenecks. "
"Logs per-function timing, request concurrency, and hot-path section durations. "
"'--dyn-chat-processor vllm' only."
),
)
add_argument(
g,
flag_name="--dyn-preprocess-workers",
env_var="DYN_PREPROCESS_WORKERS",
default=0,
dest="preprocess_workers",
help=(
"[EXPERIMENTAL] Number of worker processes for preprocessing and output processing. "
"When > 0, offloads CPU-bound work (tokenization, template rendering, "
"detokenization) to a ProcessPoolExecutor with N workers, each with its "
"own GIL. 0 (default) keeps all processing on the main event loop. '--dyn-chat-processor vllm' only."
),
arg_type=int,
)
......@@ -57,7 +57,7 @@ def setup_engine_factory(
"""
from .vllm_processor import EngineFactory
return EngineFactory(runtime, router_config, config, vllm_flags)
return EngineFactory(runtime, router_config, config, vllm_flags, config.debug_perf)
def parse_args() -> tuple[FrontendConfig, Optional[Namespace]]:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Performance instrumentation for diagnosing frontend preprocessing bottlenecks.
Activated by passing --dyn-debug-perf to dynamo.frontend.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Concurrency gauge
# ---------------------------------------------------------------------------
_active_requests = 0
_peak_requests = 0


def enter_generator() -> int:
    """Register one more in-flight request and return the new active count.

    The module-level active counter is bumped by one, and the recorded peak
    is raised whenever the new value exceeds it.

    No lock is taken. NOTE(review): ``+=`` on a module global is a separate
    load, add and store, so this is only race-free when every caller runs on
    one thread (for example a single asyncio event loop) -- confirm against
    the call sites.
    """
    global _active_requests, _peak_requests
    _active_requests += 1
    _peak_requests = max(_peak_requests, _active_requests)
    return _active_requests


def exit_generator() -> int:
    """Unregister one in-flight request and return the remaining active count.

    The counter is not clamped; calling this more often than
    ``enter_generator`` drives it below zero.
    """
    global _active_requests
    _active_requests = _active_requests - 1
    return _active_requests


def get_active_requests() -> int:
    """Return the current number of registered in-flight requests."""
    return _active_requests


def get_peak_requests() -> int:
    """Return the highest active count seen since this module was imported."""
    return _peak_requests
......@@ -3,6 +3,7 @@
from __future__ import annotations
import os
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
......@@ -13,6 +14,7 @@ from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
@dataclass
......@@ -24,6 +26,19 @@ class PreprocessResult:
prompt_token_ids: list[int]
# Module-wide cache of async tokenizer wrappers, keyed by id() of the wrapped
# tokenizer object. Entries are added lazily by _get_async_tokenizer and are
# never evicted here.
_ASYNC_TOKENIZER_POOL: dict[int, AsyncMicrobatchTokenizer] = {}
# True when DYN_VLLM_SKIP_REQUEST_VALIDATION is unset or exactly "1"; any other
# value disables the skip. The environment is read once, at import time.
SKIP_REQUEST_VALIDATION = os.getenv("DYN_VLLM_SKIP_REQUEST_VALIDATION", "1") == "1"
def _get_async_tokenizer(tokenizer: TokenizerLike) -> AsyncMicrobatchTokenizer:
    """Return the async wrapper for ``tokenizer``, building it on first use.

    Wrappers are memoized in ``_ASYNC_TOKENIZER_POOL`` under ``id(tokenizer)``,
    so repeated calls with the same tokenizer object share one
    ``AsyncMicrobatchTokenizer``.
    """
    pool_key = id(tokenizer)
    cached = _ASYNC_TOKENIZER_POOL.get(pool_key)
    if cached is not None:
        return cached
    wrapper = AsyncMicrobatchTokenizer(tokenizer)
    _ASYNC_TOKENIZER_POOL[pool_key] = wrapper
    return wrapper
def _materialize_assistant_tool_calls(
messages: Sequence[Any],
) -> list[dict[str, Any] | Any]:
......@@ -53,13 +68,33 @@ def _materialize_assistant_tool_calls(
return normalized
async def preprocess_chat_request(
request: dict[str, Any],
def _prepare_request(
request: dict[str, Any] | ChatCompletionRequest,
*,
tokenizer: TokenizerLike,
renderer,
tool_parser_class: type[ToolParser] | None,
) -> PreprocessResult:
) -> tuple[
ChatCompletionRequest, ToolParser | None, dict[str, Any], Any, dict[str, Any]
]:
"""Validate request and build arguments for template rendering.
Returns:
request_for_sampling: Validated ChatCompletionRequest.
tool_parser: Instantiated tool parser, or None.
chat_template_kwargs: Template kwargs (for PreprocessResult).
messages_for_render: Messages to pass as first arg to render_messages.
render_kwargs: Keyword arguments for render_messages / render_messages_async.
"""
if isinstance(request, ChatCompletionRequest):
request_for_sampling = request
elif SKIP_REQUEST_VALIDATION:
# Trusted fast path; caller must provide OpenAI-compatible payload.
request_for_sampling = ChatCompletionRequest.model_construct(**request)
if request_for_sampling.tools and any(
not hasattr(tool, "model_dump") for tool in request_for_sampling.tools
):
request_for_sampling = ChatCompletionRequest.model_validate(request)
else:
request_for_sampling = ChatCompletionRequest.model_validate(request)
tool_parser: ToolParser | None = None
......@@ -88,8 +123,7 @@ async def preprocess_chat_request(
else request_for_sampling.messages
)
_, engine_prompt = await renderer.render_messages_async(
messages_for_render,
render_kwargs = dict(
chat_template=request_for_sampling.chat_template,
chat_template_content_format="auto",
add_generation_prompt=request_for_sampling.add_generation_prompt,
......@@ -100,6 +134,73 @@ async def preprocess_chat_request(
**chat_template_kwargs,
)
return (
request_for_sampling,
tool_parser,
chat_template_kwargs,
messages_for_render,
render_kwargs,
)
async def preprocess_chat_request(
    request: dict[str, Any] | ChatCompletionRequest,
    *,
    tokenizer: TokenizerLike,
    renderer,
    tool_parser_class: type[ToolParser] | None,
) -> PreprocessResult:
    """Render a chat request into an engine prompt and its prompt token IDs.

    The request is first handed to ``_prepare_request``; the messages and
    render kwargs it returns are passed to ``renderer.render_messages_async``.
    If the rendered prompt already carries ``prompt_token_ids`` those are
    copied into a list; otherwise ``engine_prompt["prompt"]`` is encoded with
    the shared async tokenizer wrapper for ``tokenizer``.

    Returns:
        A ``PreprocessResult`` bundling the prepared request, tool parser,
        chat-template kwargs, rendered engine prompt and token ID list.
    """
    prepared = _prepare_request(
        request, tokenizer=tokenizer, tool_parser_class=tool_parser_class
    )
    (
        validated_request,
        parser,
        template_kwargs,
        render_input,
        render_kwargs,
    ) = prepared

    _ignored, engine_prompt = await renderer.render_messages_async(
        render_input, **render_kwargs
    )

    if "prompt_token_ids" not in engine_prompt:
        # The renderer produced text only: tokenize it here.
        encode = _get_async_tokenizer(tokenizer)
        encoding = await encode(
            engine_prompt["prompt"],
            add_special_tokens=validated_request.add_special_tokens,
        )
        token_ids = list(encoding.input_ids)
    else:
        token_ids = list(engine_prompt["prompt_token_ids"])

    return PreprocessResult(
        request_for_sampling=validated_request,
        tool_parser=parser,
        chat_template_kwargs=template_kwargs,
        engine_prompt=engine_prompt,
        prompt_token_ids=token_ids,
    )
def preprocess_chat_request_sync(
request: dict[str, Any] | ChatCompletionRequest,
*,
tokenizer: TokenizerLike,
renderer,
tool_parser_class: type[ToolParser] | None,
) -> PreprocessResult:
"""Sync version of preprocess_chat_request for worker processes."""
(
request_for_sampling,
tool_parser,
chat_template_kwargs,
messages,
render_kwargs,
) = _prepare_request(
request, tokenizer=tokenizer, tool_parser_class=tool_parser_class
)
_, engine_prompt = renderer.render_messages(messages, **render_kwargs)
if "prompt_token_ids" in engine_prompt:
tokens = list(engine_prompt["prompt_token_ids"])
else:
......@@ -141,6 +242,9 @@ class StreamingPostProcessor:
if reasoning_parser_class
else None
)
self._fast_plain_text = (
self.tool_parser is None and self.reasoning_parser is None
)
self._control_markers = tuple(
t for t in getattr(tokenizer, "all_special_tokens", ()) if t
......@@ -191,6 +295,23 @@ class StreamingPostProcessor:
# to text. Re-detokenizing from token_ids can reintroduce stop markers.
delta_text = output.text or ""
if self._fast_plain_text:
if delta_text:
delta: dict[str, Any] = {
"role": "assistant",
"content": delta_text,
}
elif output.finish_reason:
delta = {}
else:
return None
return {
"index": output.index,
"delta": delta,
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
current_text = self.previous_text + delta_text
current_token_ids = self.previous_token_ids + delta_token_ids
......
......@@ -12,6 +12,9 @@ import time
import uuid
from argparse import Namespace
from collections.abc import AsyncGenerator
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import wait as _futures_wait
from dataclasses import dataclass
from typing import Any
from vllm.config import CacheConfig, LoadConfig, ModelConfig, VllmConfig
......@@ -36,7 +39,11 @@ from dynamo.llm import (
)
from dynamo.runtime import DistributedRuntime
from .prepost import StreamingPostProcessor, preprocess_chat_request
from .prepost import (
StreamingPostProcessor,
preprocess_chat_request,
preprocess_chat_request_sync,
)
logger = logging.getLogger(__name__)
......@@ -72,6 +79,209 @@ def map_finish_reason(raw_reason: str | None) -> FinishReason | None:
return mapped
# --- Worker process globals (initialized once per process by _init_worker) ---
# Until _init_worker has run in a process, these keep the defaults below.
# InputProcessor built by _init_worker; _preprocess_worker reads its renderer,
# generation_config_fields and process_inputs.
_w_input_processor: InputProcessor | None = None
# Tokenizer returned by the worker's InputProcessor.get_tokenizer().
_w_tokenizer: Any = None
# Tool parser class looked up by name in _init_worker, or None if no name given.
_w_tool_parser_class: type[ToolParser] | None = None
# Reasoning parser class looked up by name in _init_worker, or None if no name given.
_w_reasoning_parser_class: type[ReasoningParser] | None = None
# Stream interval handed to _init_worker, stored there as max(1, value).
_w_stream_interval: int = 20
class _PreprocessError(Exception):
    """Preprocessing failure that carries a ready-to-send error payload.

    ``_preprocess_worker`` raises this for user-facing problems (e.g. n != 1).
    The payload is kept on ``error_dict``; its ``str()`` form becomes the
    exception message.
    """

    def __init__(self, error_dict: dict[str, Any]):
        super().__init__(str(error_dict))
        self.error_dict = error_dict
@dataclass
class PreprocessWorkerResult:
    """Picklable return value from the preprocess worker.

    Built by ``_preprocess_worker`` and handed back to the submitting process
    through the executor Future, so every field must survive pickling.
    """

    # Request payload assembled by the worker (model, token_ids,
    # stop_conditions, sampling_options, output_options, eos_token_ids,
    # annotations).
    dynamo_preproc: dict[str, Any]
    # Prompt token IDs produced by preprocessing.
    tokens: list[int]
    # Result of the worker's InputProcessor.process_inputs call.
    vllm_preproc: EngineCoreRequest
    # SamplingParams the worker built and passed to process_inputs.
    sampling_params: SamplingParams
    request_for_sampling: Any  # ChatCompletionRequest (Pydantic model, picklable)
    # Chat-template kwargs reported by preprocessing.
    chat_template_kwargs: dict[str, Any]
def _init_worker(
    model_path: str,
    tokenizer_mode: str,
    config_format: str,
    load_format: str,
    tool_parser_name: str | None,
    reasoning_parser_name: str | None,
    stream_interval: int,
) -> None:
    """Populate this process's ``_w_*`` globals for later preprocessing calls.

    Builds a ``VllmConfig`` (model, load and default cache config), wraps it
    in an ``InputProcessor``, and stores that processor and its tokenizer.
    Parser classes are resolved by name through their managers, or left as
    ``None`` when the name is empty. The stream interval is stored with a
    floor of 1.
    """
    global _w_input_processor, _w_tokenizer, _w_tool_parser_class
    global _w_reasoning_parser_class, _w_stream_interval

    processor = InputProcessor(
        VllmConfig(
            model_config=ModelConfig(
                model=model_path,
                tokenizer_mode=tokenizer_mode,
                config_format=config_format,
            ),
            load_config=LoadConfig(load_format=load_format),
            cache_config=CacheConfig(),
        )
    )
    _w_input_processor = processor
    _w_tokenizer = processor.get_tokenizer()

    _w_tool_parser_class = (
        ToolParserManager.get_tool_parser(tool_parser_name)
        if tool_parser_name
        else None
    )
    _w_reasoning_parser_class = (
        ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
        if reasoning_parser_name
        else None
    )

    _w_stream_interval = max(1, stream_interval)
def _worker_warmup() -> bool:
    """No-op pool task: does no work and always returns ``True``."""
    return True
def _preprocess_worker(
    request: dict[str, Any],
    request_id: str,
    model_name: str,
) -> PreprocessWorkerResult:
    """Run the full request preprocessing inside a pool worker.

    Uses the ``_w_*`` globals set up by ``_init_worker``. The request is
    rendered and tokenized synchronously, ``SamplingParams`` are assembled
    (generation-config defaults first, then request overrides, then logprobs),
    and the prompt is passed through ``InputProcessor.process_inputs``.

    Args:
        request: Raw chat request payload.
        request_id: ID passed to ``process_inputs``.
        model_name: Value stored under ``"model"`` in the outgoing payload.

    Returns:
        A ``PreprocessWorkerResult`` holding everything the caller needs.

    Raises:
        _PreprocessError: The processed sampling params have ``n != 1``.
    """
    pre = preprocess_chat_request_sync(
        request,
        tokenizer=_w_tokenizer,
        renderer=_w_input_processor.renderer,
        tool_parser_class=_w_tool_parser_class,
    )
    chat_request = pre.request_for_sampling
    engine_prompt = pre.engine_prompt
    token_ids = pre.prompt_token_ids

    # max_completion_tokens wins over max_tokens; both may be unset.
    token_limit = chat_request.max_completion_tokens
    if token_limit is None:
        token_limit = chat_request.max_tokens

    sampling_params = SamplingParams(
        output_kind=RequestOutputKind.DELTA,
        max_tokens=token_limit,
    )

    # Generation-config values are applied first so the request can override.
    for name, default in _w_input_processor.generation_config_fields.items():
        if hasattr(sampling_params, name):
            setattr(sampling_params, name, default)

    # Copy every non-None request field that SamplingParams also declares,
    # except the ones handled explicitly in this function.
    declared = set(getattr(SamplingParams, "__annotations__", ()))
    requestable = set(type(chat_request).model_fields)
    handled_elsewhere = {"max_tokens", "logprobs", "output_kind"}
    for name in sorted((declared & requestable) - handled_elsewhere):
        override = getattr(chat_request, name, None)
        if override is not None:
            setattr(sampling_params, name, override)

    want_logprobs = chat_request.logprobs
    top_n = chat_request.top_logprobs
    if want_logprobs is True:
        sampling_params.logprobs = top_n or 1
    elif isinstance(want_logprobs, int) and not isinstance(want_logprobs, bool):
        sampling_params.logprobs = want_logprobs
    elif top_n not in (None, 0):
        sampling_params.logprobs = top_n

    prompt_inputs = TokensPrompt(prompt_token_ids=token_ids)
    for key in ("multi_modal_data", "multi_modal_uuids"):
        if key in engine_prompt:
            prompt_inputs[key] = engine_prompt[key]
    for attr in ("cache_salt", "mm_processor_kwargs"):
        extra = getattr(chat_request, attr)
        if extra is not None:
            prompt_inputs[attr] = extra

    vllm_preproc: EngineCoreRequest = _w_input_processor.process_inputs(
        request_id,
        prompt_inputs,
        sampling_params,
    )
    InputProcessor.assign_request_id(vllm_preproc)

    sp = vllm_preproc.sampling_params
    if sp.n != 1:
        unsupported_n = {
            "error": {
                "message": (
                    f"Unsupported value: 'n={sp.n}'. "
                    "This endpoint currently supports only n=1."
                ),
                "type": "invalid_request_error",
                "param": "n",
                "code": "unsupported_value",
            }
        }
        raise _PreprocessError(unsupported_n)

    stop_conditions = {
        name: getattr(sp, name)
        for name in (
            "max_tokens",
            "stop",
            "stop_token_ids",
            "min_tokens",
            "ignore_eos",
        )
    }
    sampling_options = {
        name: getattr(sp, name)
        for name in (
            "n",
            "presence_penalty",
            "frequency_penalty",
            "repetition_penalty",
            "temperature",
            "top_p",
            "top_k",
            "min_p",
            "seed",
        )
    }
    output_options = {
        name: getattr(sp, name)
        for name in ("logprobs", "prompt_logprobs", "skip_special_tokens")
    }
    eos_token_ids = (
        [vllm_preproc.eos_token_id] if vllm_preproc.eos_token_id is not None else []
    )

    dynamo_preproc = {
        "model": model_name,
        "token_ids": token_ids,
        "stop_conditions": stop_conditions,
        "sampling_options": sampling_options,
        "output_options": output_options,
        "eos_token_ids": eos_token_ids,
        "annotations": [],
    }

    return PreprocessWorkerResult(
        dynamo_preproc=dynamo_preproc,
        tokens=token_ids,
        vllm_preproc=vllm_preproc,
        sampling_params=sampling_params,
        request_for_sampling=chat_request,
        chat_template_kwargs=pre.chat_template_kwargs,
    )
class VllmProcessor:
def __init__(
self,
......@@ -81,6 +291,9 @@ class VllmProcessor:
output_processor: OutputProcessor,
tool_parser_class: type[ToolParser] | None,
reasoning_parser_class: type[ReasoningParser] | None,
debug_perf: bool = False,
preprocess_pool: ProcessPoolExecutor | None = None,
preprocess_workers: int = 0,
):
self.tokenizer = tokenizer
self.input_processor = input_processor
......@@ -89,6 +302,16 @@ class VllmProcessor:
self.output_processor = output_processor
self.tool_parser_class = tool_parser_class
self.reasoning_parser_class = reasoning_parser_class
self.debug_perf = debug_perf
self.preprocess_pool = preprocess_pool
if preprocess_pool is not None:
# Allow a small buffer beyond the worker count so the pool's
# internal queue always has work ready when a worker finishes.
self._worker_semaphore: asyncio.Semaphore | None = asyncio.Semaphore(
preprocess_workers + 2
)
else:
self._worker_semaphore = None
# Ideally we would map NVCreateChatCompletionRequest into Python so it can be type checked, but
# it has a lot of fields.
......@@ -103,12 +326,55 @@ class VllmProcessor:
# ** VllmProcessor.generator called: {'messages': [{'role': 'user', 'content': 'What is the capital of Tuvalu?'}], 'model': '/home/grahamk/llms/Qwen3-0.6B', 'max_completion_tokens': 1000, 'stream': False}
if self.debug_perf:
from .perf_instrumentation import enter_generator, exit_generator
active = enter_generator()
t_start = time.monotonic()
logger.info("[perf] generator enter: active_requests=%d", active)
try:
if self.preprocess_pool is None:
# Single process
async for item in self._generator_inner(request):
yield item
else:
# Multi process
async for item in self._generator_inner_pool(request):
yield item
finally:
if self.debug_perf:
active = exit_generator()
elapsed_ms = (time.monotonic() - t_start) * 1000.0
logger.info(
"[perf] generator exit: total=%.2fms active_requests=%d",
elapsed_ms,
active,
)
async def _generator_inner(
self, request: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
request_id = random_uuid()
if self.debug_perf:
t0 = time.monotonic()
pre = await preprocess_chat_request(
request,
tokenizer=self.tokenizer,
renderer=self.input_processor.renderer,
tool_parser_class=self.tool_parser_class,
)
if self.debug_perf:
t1 = time.monotonic()
logger.info(
"[perf] preprocess_chat_request: %.2fms (request=%s)",
(t1 - t0) * 1000.0,
request_id,
)
request_for_sampling = pre.request_for_sampling
tool_parser = pre.tool_parser
chat_template_kwargs = pre.chat_template_kwargs
......@@ -155,7 +421,6 @@ class VllmProcessor:
"Logprobs requested but not supported in distributed inference mode"
)
request_id = random_uuid()
# This calls update_from_generation_config and update_from_tokenizer on SamplingParams
prompt_inputs = TokensPrompt(prompt_token_ids=tokens)
if "multi_modal_data" in engine_prompt:
......@@ -168,6 +433,10 @@ class VllmProcessor:
prompt_inputs[
"mm_processor_kwargs"
] = request_for_sampling.mm_processor_kwargs
if self.debug_perf:
t2 = time.monotonic()
vllm_preproc: EngineCoreRequest = self.input_processor.process_inputs(
request_id,
prompt_inputs,
......@@ -179,6 +448,16 @@ class VllmProcessor:
# priority: int = 0,
# data_parallel_rank: int | None = None,
)
if self.debug_perf:
t3 = time.monotonic()
logger.info(
"[perf] input_processor.process_inputs: %.2fms (request=%s tokens=%d)",
(t3 - t2) * 1000.0,
request_id,
len(tokens),
)
InputProcessor.assign_request_id(vllm_preproc)
# Processed: EngineCoreRequest(request_id='a2b76a85cd65e151', prompt_token_ids=[3838, 374, 279, 6722, 315, 28649, 25510, 30], mm_features=None, sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[151643], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=16, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args=None), pooling_params=None, eos_token_id=151645, arrival_time=1769036937.9417946, lora_request=None, cache_salt=None, data_parallel_rank=None, prompt_embeds=None, client_index=0, current_wave=0, priority=0, trace_headers=None)
......@@ -200,18 +479,9 @@ class VllmProcessor:
}
return
prompt = None
self.output_processor.add_request(
vllm_preproc,
prompt,
# parent_req: ParentRequest | None = None,
# request_index: int = 0,
# queue: RequestOutputCollector | None = None,
)
dynamo_preproc = {
"model": request["model"],
"token_ids": tokens,
# protocols.common.StopConditions
"stop_conditions": {
"max_tokens": sp.max_tokens,
"stop": sp.stop,
......@@ -219,8 +489,6 @@ class VllmProcessor:
"min_tokens": sp.min_tokens,
"ignore_eos": sp.ignore_eos,
},
# protocols.common.SamplingOptions
# Is there a better way than typing it out like this?
"sampling_options": {
"n": sp.n,
"presence_penalty": sp.presence_penalty,
......@@ -232,7 +500,6 @@ class VllmProcessor:
"min_p": sp.min_p,
"seed": sp.seed,
},
# protocols.common.OutputOptions
"output_options": {
"logprobs": sp.logprobs,
"prompt_logprobs": sp.prompt_logprobs,
......@@ -242,7 +509,6 @@ class VllmProcessor:
if vllm_preproc.eos_token_id is not None
else [],
"annotations": [],
# "prompt_embeds": vllm_preproc.prompt_embeds,
}
post = StreamingPostProcessor(
......@@ -255,10 +521,33 @@ class VllmProcessor:
chat_template_kwargs=chat_template_kwargs,
)
# dynamo_response: Annotated
async for item in self._generate_and_stream(
request_id,
request,
dynamo_preproc,
tokens,
vllm_preproc,
post,
):
yield item
async def _generate_and_stream(
self,
request_id: str,
request: dict[str, Any],
dynamo_preproc: dict[str, Any],
tokens: list[int],
vllm_preproc: EngineCoreRequest,
post: StreamingPostProcessor,
) -> AsyncGenerator[dict[str, Any], None]:
"""Shared streaming logic for both single-process and pool paths."""
self.output_processor.add_request(vllm_preproc, None)
token_count = 0
output_proc_total_ms = 0.0
post_proc_total_ms = 0.0
try:
# Dynamo Router. This goes to the backend, waits, gets the streaming response, returns it.
# Stream is AsyncResponseStream
if self.is_kv_router:
dynamo_stream = await self.router.generate(
token_ids=tokens,
......@@ -268,37 +557,25 @@ class VllmProcessor:
output_options=dynamo_preproc["output_options"],
)
else:
# Round robin or random, depending on cmd line flag
dynamo_stream = await self.router.generate(dynamo_preproc)
dynamo_stream = await self.router.generate(
dynamo_preproc, annotated=False
)
async for dynamo_response in dynamo_stream:
# dynamo_response looks like this for regular router:
# Stream got: Annotated(data={'token_ids': [7281]}, event=None, comment=[], id=None)
# For KV router is is only the inner map: {'token_ids': [7281]}
if self.is_kv_router:
engine_response = dynamo_response
else:
elif hasattr(dynamo_response, "data"):
engine_response = dynamo_response.data()
# engine_response:
# Normal: {'token_ids': [151658]}
# Last: {'token_ids': [151645], 'finish_reason': 'stop', 'completion_usage': {'prompt_tokens': 190, 'completion_tokens': 168, 'total_tokens': 358, 'prompt_tokens_details': {'cached_tokens': 176}}}
else:
engine_response = dynamo_response
if engine_response is None or "token_ids" not in engine_response:
logger.error("No outputs from engine for request %s", request_id)
yield {
"id": request_id,
"choices": [
{
"index": 0,
"delta": {},
"finish_reason": "error",
"error": {
"message": f"Invalid engine response for request {request_id}",
"type": "internal_error",
}
],
"created": int(time.time()),
"model": request["model"],
"object": "chat.completion.chunk",
}
break
......@@ -311,28 +588,22 @@ class VllmProcessor:
new_token_ids=engine_response["token_ids"],
finish_reason=finish_reason,
stop_reason=stop_reason,
# new_logprobs=new_logprobs,
# new_prompt_logprobs_tensors=prompt_logprobs_tensors,
# pooling_output=pooler_output,
# events=request.take_events(),
# kv_transfer_params=kv_transfer_params,
# trace_headers=request.trace_headers,
# num_cached_tokens=request.num_cached_tokens,
# num_nans_in_logits=request.num_nans_in_logits,
)
# Let vllm handle all post-processing
if self.debug_perf:
t_op0 = time.monotonic()
vllm_out: OutputProcessorOutput = self.output_processor.process_outputs(
[vllm_response]
)
if self.debug_perf:
t_op1 = time.monotonic()
output_proc_total_ms += (t_op1 - t_op0) * 1000.0
if vllm_out.reqs_to_abort:
# Router has no abort API; we cannot propagate aborts.
pass
# vllm
# RequestOutput: OutputProcessorOutput(request_outputs=[RequestOutput(request_id=9dbe240d8de78db3, prompt='What is the capital of Tuvalu?', prompt_token_ids=[3838, 374, 279, 6722, 315, 28649, 25510, 30], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' The', token_ids=[576], cumulative_logprob=None, logprobs=None, finish_reason=None, stop_reason=None)], finished=False, metrics=RequestStateStats(num_generation_tokens=0, arrival_time=1769118902.2172132, queued_ts=0.0, scheduled_ts=0.0, first_token_ts=0.0, last_token_ts=0.0, first_token_latency=0.0, is_corrupted=False), lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})], reqs_to_abort=[])
# Vec<ChatChoiceStream>
choices = []
if not vllm_out.request_outputs:
continue
......@@ -341,8 +612,12 @@ class VllmProcessor:
if choice:
choices.append(choice)
if self.debug_perf:
t_op2 = time.monotonic()
post_proc_total_ms += (t_op2 - t_op1) * 1000.0
token_count += len(engine_response["token_ids"])
if choices:
# dynamo_out: NvCreateChatCompletionStreamResponse
dynamo_out = {
"id": request_id,
"choices": choices,
......@@ -351,16 +626,95 @@ class VllmProcessor:
"object": "chat.completion.chunk",
}
if usage := engine_response.get("completion_usage"):
# The engine only includes this on the last response
dynamo_out["usage"] = usage
# Rust handles HTTP / Server Sent Events back to user
yield dynamo_out
finally:
if vllm_preproc.request_id in self.output_processor.request_states:
self.output_processor.abort_requests(
[vllm_preproc.request_id], internal=True
)
if self.debug_perf and token_count > 0:
logger.info(
"[perf] stream done: request=%s tokens=%d "
"output_processor_total=%.2fms (%.3fms/tok) "
"post_processor_total=%.2fms (%.3fms/tok)",
request_id,
token_count,
output_proc_total_ms,
output_proc_total_ms / token_count,
post_proc_total_ms,
post_proc_total_ms / token_count,
)
async def _generator_inner_pool(
    self, request: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
    """Serve one request with preprocessing done in the worker pool.

    Step 1 submits ``_preprocess_worker`` to ``self.preprocess_pool`` and
    awaits it while holding ``self._worker_semaphore``. A ``_PreprocessError``
    yields its ``error_dict``; any other exception is logged and yields an
    ``internal_error`` payload. Either way the generator then stops.

    Step 2 rebuilds the tool parser and ``StreamingPostProcessor`` in this
    process from the worker result, then delegates to
    ``self._generate_and_stream`` and re-yields its items.
    """
    request_id = random_uuid()

    try:
        # The semaphore is held only for the submit + await of the worker.
        async with self._worker_semaphore:
            pending = self.preprocess_pool.submit(
                _preprocess_worker, request, request_id, request["model"]
            )
            worker_out: PreprocessWorkerResult = await asyncio.wrap_future(pending)
    except _PreprocessError as exc:
        yield exc.error_dict
        return
    except Exception as exc:
        logger.exception("Worker preprocessing failed for request %s", request_id)
        failure = {
            "error": {
                "message": f"Worker error: {exc}",
                "type": "internal_error",
            }
        }
        yield failure
        return

    chat_request = worker_out.request_for_sampling
    prompt_tokens = worker_out.tokens

    # Instantiate a tool parser only when one is configured and the request
    # actually asks for tools.
    wants_tool_parser = (
        self.tool_parser_class
        and chat_request.tools
        and chat_request.tool_choice != "none"
    )
    tool_parser = self.tool_parser_class(self.tokenizer) if wants_tool_parser else None

    post = StreamingPostProcessor(
        tokenizer=self.tokenizer,
        request_for_sampling=chat_request,
        sampling_params=worker_out.sampling_params,
        prompt_token_ids=prompt_tokens,
        tool_parser=tool_parser,
        reasoning_parser_class=self.reasoning_parser_class,
        chat_template_kwargs=worker_out.chat_template_kwargs,
    )

    stream = self._generate_and_stream(
        request_id,
        request,
        worker_out.dynamo_preproc,
        prompt_tokens,
        worker_out.vllm_preproc,
        post,
    )
    async for item in stream:
        yield item
class EngineFactory:
......@@ -370,11 +724,24 @@ class EngineFactory:
router_config: RouterConfig,
config: FrontendConfig,
flags: Namespace,
debug_perf: bool = False,
):
self.runtime = runtime
self.router_config = router_config
self.config = config
self.flags = flags
self.debug_perf = debug_perf
self.stream_interval = 20
raw_stream_interval = os.getenv("DYN_VLLM_STREAM_INTERVAL")
if raw_stream_interval:
try:
self.stream_interval = max(1, int(raw_stream_interval))
except ValueError:
logger.warning(
"Invalid DYN_VLLM_STREAM_INTERVAL=%r, using default=%d",
raw_stream_interval,
self.stream_interval,
)
async def chat_engine_factory(
self,
......@@ -416,8 +783,9 @@ class EngineFactory:
output_processor = OutputProcessor(
tokenizer,
log_stats=False,
stream_interval=1,
stream_interval=self.stream_interval,
)
logger.info("vLLM OutputProcessor stream_interval=%d", self.stream_interval)
tool_parser_name = self.flags.tool_call_parser or mdc.runtime_config().get(
"tool_call_parser"
......@@ -453,6 +821,48 @@ class EngineFactory:
router_mode=self.router_config.router_mode
)
preprocess_pool = None
preprocess_workers = self.config.preprocess_workers
if preprocess_workers > 0:
logger.info(
"Creating preprocess worker pool with %d workers for model %s",
preprocess_workers,
source_path,
)
preprocess_pool = ProcessPoolExecutor(
max_workers=preprocess_workers,
initializer=_init_worker,
initargs=(
source_path,
tokenizer_mode,
config_format,
load_format,
tool_parser_name,
reasoning_parser_name,
self.stream_interval,
),
)
# Warm up all workers to ensure initialization completes
futures = [
preprocess_pool.submit(_worker_warmup)
for _ in range(preprocess_workers)
]
done, not_done = _futures_wait(futures, timeout=120)
if not_done:
for f in not_done:
f.cancel()
preprocess_pool.shutdown(wait=False, cancel_futures=True)
raise RuntimeError(
"Timed out waiting for preprocess worker pool warmup"
)
try:
for f in done:
f.result() # Raises if initializer failed
except Exception:
preprocess_pool.shutdown(wait=False, cancel_futures=True)
raise
logger.info("Preprocess worker pool ready (%d workers)", preprocess_workers)
gen = VllmProcessor(
tokenizer,
input_processor,
......@@ -460,6 +870,9 @@ class EngineFactory:
output_processor,
tool_parser_class,
reasoning_parser_class,
debug_perf=self.debug_perf,
preprocess_pool=preprocess_pool,
preprocess_workers=preprocess_workers,
)
return PythonAsyncEngine(gen.generator, loop)
......@@ -159,7 +159,11 @@ async def launch_workers(args, extra_engine_args_path):
logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}")
# Create a separate DistributedRuntime for this worker (on same event loop)
runtime = DistributedRuntime(loop, args.discovery_backend, args.request_plane)
runtime = DistributedRuntime(
loop,
args.discovery_backend,
args.request_plane,
)
runtimes.append(runtime)
# Determine which engine args file to use
......
......@@ -559,3 +559,75 @@ class TestVllmRendererApi:
"ReasoningParser.is_reasoning_end_streaming signature changed; "
f"expected ['self', 'input_ids', 'delta_ids'], got {end_params}"
)
def test_preprocess_worker_result_picklability(self):
    """Check that a PreprocessWorkerResult can be pickled and restored.

    The pool path returns this dataclass from ``_preprocess_worker`` through
    a ProcessPoolExecutor Future, so an unpicklable field would break it.
    """
    import pickle

    from dynamo.frontend.vllm_processor import PreprocessWorkerResult

    stop_conditions = {
        "max_tokens": 100,
        "stop": [],
        "stop_token_ids": [2],
        "min_tokens": 0,
        "ignore_eos": False,
    }
    sampling_options = {
        "n": 1,
        "presence_penalty": 0.0,
        "frequency_penalty": 0.0,
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
        "seed": None,
    }
    output_options = {
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }
    payload = {
        "model": "test-model",
        "token_ids": [1, 2, 3],
        "stop_conditions": stop_conditions,
        "sampling_options": sampling_options,
        "output_options": output_options,
        "eos_token_ids": [2],
        "annotations": [],
    }
    engine_request = EngineCoreRequest(
        request_id="test-123",
        prompt_token_ids=[1, 2, 3],
        mm_features=None,
        sampling_params=SamplingParams(),
        pooling_params=None,
        eos_token_id=2,
        arrival_time=0.0,
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
        prompt_embeds=None,
        client_index=0,
        current_wave=0,
        priority=0,
        trace_headers=None,
    )

    original = PreprocessWorkerResult(
        dynamo_preproc=payload,
        tokens=[1, 2, 3],
        vllm_preproc=engine_request,
        sampling_params=SamplingParams(),
        request_for_sampling={"model": "test-model", "tools": None},
        chat_template_kwargs={"reasoning_effort": None},
    )

    restored = pickle.loads(pickle.dumps(original))

    assert restored.dynamo_preproc == original.dynamo_preproc
    assert restored.tokens == original.tokens
    assert restored.vllm_preproc.request_id == "test-123"
    assert restored.request_for_sampling == original.request_for_sampling
    assert restored.chat_template_kwargs == original.chat_template_kwargs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment