Unverified Commit fda022b1 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

test(frontend): Minimal integration test for vllm processor (#7173)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 96a55928
......@@ -11,9 +11,6 @@ import os
import time
from argparse import Namespace
from collections.abc import AsyncGenerator
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import wait as _futures_wait
from dataclasses import dataclass
from typing import Any
from vllm.config import CacheConfig, LoadConfig, ModelConfig, VllmConfig
......@@ -38,12 +35,8 @@ from dynamo.llm import (
)
from dynamo.runtime import Client, DistributedRuntime
from .prepost import (
StreamingPostProcessor,
preprocess_chat_request,
preprocess_chat_request_sync,
)
from .utils import PreprocessError, random_uuid, worker_warmup
from .prepost import StreamingPostProcessor, preprocess_chat_request
from .utils import random_uuid
logger = logging.getLogger(__name__)
......@@ -74,181 +67,6 @@ def map_finish_reason(raw_reason: str | None) -> FinishReason | None:
return mapped
# --- Worker process globals (initialized once per process by _init_worker) ---
_w_input_processor: InputProcessor | None = None
_w_tokenizer: Any = None
_w_tool_parser_class: type[ToolParser] | None = None
@dataclass
class PreprocessWorkerResult:
"""Picklable return value from the preprocess worker."""
dynamo_preproc: dict[str, Any]
tokens: list[int]
vllm_preproc: EngineCoreRequest
sampling_params: SamplingParams
request_for_sampling: Any # ChatCompletionRequest (Pydantic model, picklable)
chat_template_kwargs: dict[str, Any]
def _init_worker(
model_path: str,
tokenizer_mode: str,
config_format: str,
load_format: str,
tool_parser_name: str | None,
) -> None:
"""Initialize a worker process with its own VllmConfig and InputProcessor."""
global _w_input_processor, _w_tokenizer, _w_tool_parser_class
global _w_reasoning_parser_class
model_config = ModelConfig(
model=model_path,
tokenizer_mode=tokenizer_mode,
config_format=config_format,
)
vllm_config = VllmConfig(
model_config=model_config,
load_config=LoadConfig(load_format=load_format),
cache_config=CacheConfig(),
)
_w_input_processor = InputProcessor(vllm_config)
_w_tokenizer = _w_input_processor.get_tokenizer()
if tool_parser_name:
_w_tool_parser_class = ToolParserManager.get_tool_parser(tool_parser_name)
else:
_w_tool_parser_class = None
def _preprocess_worker(
request: dict[str, Any],
request_id: str,
model_name: str,
) -> PreprocessWorkerResult:
"""Preprocess a request in a worker process and return a picklable result."""
assert _w_input_processor is not None
pre = preprocess_chat_request_sync(
request,
tokenizer=_w_tokenizer,
renderer=_w_input_processor.renderer,
tool_parser_class=_w_tool_parser_class,
)
request_for_sampling = pre.request_for_sampling
engine_prompt = pre.engine_prompt
tokens = pre.prompt_token_ids
if request_for_sampling.max_completion_tokens is not None:
max_tokens = request_for_sampling.max_completion_tokens
elif request_for_sampling.max_tokens is not None:
max_tokens = request_for_sampling.max_tokens
else:
max_tokens = None
sampling_params = SamplingParams(
output_kind=RequestOutputKind.DELTA,
max_tokens=max_tokens,
)
for k, v in _w_input_processor.generation_config_fields.items():
if hasattr(sampling_params, k):
setattr(sampling_params, k, v)
sampling_fields = (
set(getattr(SamplingParams, "__annotations__", ()))
& set(type(request_for_sampling).model_fields)
) - {"max_tokens", "logprobs", "output_kind"}
for k in sorted(sampling_fields):
v = getattr(request_for_sampling, k, None)
if v is not None:
setattr(sampling_params, k, v)
logprobs = request_for_sampling.logprobs
top_logprobs = request_for_sampling.top_logprobs
if logprobs is True:
sampling_params.logprobs = top_logprobs or 1
elif isinstance(logprobs, int) and not isinstance(logprobs, bool):
sampling_params.logprobs = logprobs
elif top_logprobs not in (None, 0):
sampling_params.logprobs = top_logprobs
prompt_inputs = TokensPrompt(prompt_token_ids=tokens)
if "multi_modal_data" in engine_prompt:
prompt_inputs["multi_modal_data"] = engine_prompt["multi_modal_data"]
if "multi_modal_uuids" in engine_prompt:
prompt_inputs["multi_modal_uuids"] = engine_prompt["multi_modal_uuids"]
if request_for_sampling.cache_salt is not None:
prompt_inputs["cache_salt"] = request_for_sampling.cache_salt
if request_for_sampling.mm_processor_kwargs is not None:
prompt_inputs["mm_processor_kwargs"] = request_for_sampling.mm_processor_kwargs
vllm_preproc: EngineCoreRequest = _w_input_processor.process_inputs(
request_id,
prompt_inputs,
sampling_params,
)
InputProcessor.assign_request_id(vllm_preproc)
sp = vllm_preproc.sampling_params
if sp.n != 1:
raise PreprocessError(
{
"error": {
"message": (
f"Unsupported value: 'n={sp.n}'. "
"This endpoint currently supports only n=1."
),
"type": "invalid_request_error",
"param": "n",
"code": "unsupported_value",
}
}
)
dynamo_preproc = {
"model": model_name,
"token_ids": tokens,
"stop_conditions": {
"max_tokens": sp.max_tokens,
"stop": sp.stop,
"stop_token_ids": sp.stop_token_ids,
"min_tokens": sp.min_tokens,
"ignore_eos": sp.ignore_eos,
},
"sampling_options": {
"n": sp.n,
"presence_penalty": sp.presence_penalty,
"frequency_penalty": sp.frequency_penalty,
"repetition_penalty": sp.repetition_penalty,
"temperature": sp.temperature,
"top_p": sp.top_p,
"top_k": sp.top_k,
"min_p": sp.min_p,
"seed": sp.seed,
},
"output_options": {
"logprobs": sp.logprobs,
"prompt_logprobs": sp.prompt_logprobs,
"skip_special_tokens": sp.skip_special_tokens,
},
"eos_token_ids": (
[vllm_preproc.eos_token_id] if vllm_preproc.eos_token_id is not None else []
),
"annotations": [],
}
return PreprocessWorkerResult(
dynamo_preproc=dynamo_preproc,
tokens=tokens,
vllm_preproc=vllm_preproc,
sampling_params=sampling_params,
request_for_sampling=request_for_sampling,
chat_template_kwargs=pre.chat_template_kwargs,
)
class VllmProcessor:
def __init__(
self,
......@@ -526,77 +344,6 @@ class VllmProcessor:
[vllm_preproc.request_id], internal=True
)
async def _generator_inner_pool(
self, request: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
"""Process a request using the worker pool.
Phase 1: Preprocess in a worker process (semaphore held).
Phase 2: Remote inference via router (no worker held).
Phase 3: Post-process tokens in the main process.
"""
request_id = random_uuid()
# --- Phase 1: Preprocess (semaphore held) ---
try:
assert self._worker_semaphore is not None
async with self._worker_semaphore:
assert self.preprocess_pool is not None
future = self.preprocess_pool.submit(
_preprocess_worker, request, request_id, request["model"]
)
preproc_result: PreprocessWorkerResult = await asyncio.wrap_future(
future
)
# Semaphore + worker released here
except PreprocessError as exc:
yield exc.error_dict
return
except Exception as exc:
logger.exception("Worker preprocessing failed for request %s", request_id)
yield {
"error": {
"message": f"Worker error: {exc}",
"type": "internal_error",
}
}
return
# --- Between phases: reconstruct main-process objects ---
dynamo_preproc = preproc_result.dynamo_preproc
tokens = preproc_result.tokens
vllm_preproc = preproc_result.vllm_preproc
sampling_params = preproc_result.sampling_params
request_for_sampling = preproc_result.request_for_sampling
tool_parser = None
if (
self.tool_parser_class
and request_for_sampling.tools
and request_for_sampling.tool_choice != "none"
):
tool_parser = self.tool_parser_class(self.tokenizer)
post = StreamingPostProcessor(
tokenizer=self.tokenizer,
request_for_sampling=request_for_sampling,
sampling_params=sampling_params,
prompt_token_ids=tokens,
tool_parser=tool_parser,
reasoning_parser_class=self.reasoning_parser_class,
chat_template_kwargs=preproc_result.chat_template_kwargs,
)
async for item in self._generate_and_stream(
request_id,
request,
dynamo_preproc,
tokens,
vllm_preproc,
post,
):
yield item
class EngineFactory:
def __init__(
......@@ -705,45 +452,6 @@ class EngineFactory:
router_mode=self.router_config.router_mode
)
preprocess_pool = None
preprocess_workers = self.config.preprocess_workers
if preprocess_workers > 0:
logger.info(
"Creating preprocess worker pool with %d workers for model %s",
preprocess_workers,
source_path,
)
preprocess_pool = ProcessPoolExecutor(
max_workers=preprocess_workers,
initializer=_init_worker,
initargs=(
source_path,
tokenizer_mode,
config_format,
load_format,
tool_parser_name,
),
)
# Warm up all workers to ensure initialization completes
futures = [
preprocess_pool.submit(worker_warmup) for _ in range(preprocess_workers)
]
done, not_done = _futures_wait(futures, timeout=120)
if not_done:
for f in not_done:
f.cancel()
preprocess_pool.shutdown(wait=False, cancel_futures=True)
raise RuntimeError(
"Timed out waiting for preprocess worker pool warmup"
)
try:
for f in done:
f.result() # Raises if initializer failed
except Exception:
preprocess_pool.shutdown(wait=False, cancel_futures=True)
raise
logger.info("Preprocess worker pool ready (%d workers)", preprocess_workers)
gen = VllmProcessor(
tokenizer,
input_processor,
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Generator
import pytest
import requests
from tests.utils.constants import QWEN
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from tests.utils.port_utils import ServicePorts
logger = logging.getLogger(__name__)
TEST_MODEL = QWEN
CAPTURE_PATH_ENV = "DYN_VLLM_PREPOST_CAPTURE_PATH"
SEARCH_TOOL = {
"type": "function",
"function": {
"name": "search_gutenberg_books",
"description": "Search for books in the Project Gutenberg library",
"parameters": {
"type": "object",
"properties": {
"search_terms": {
"type": "array",
"items": {"type": "string"},
"description": "List of search terms to find books",
}
},
"required": ["search_terms"],
},
},
}
pytestmark = [
pytest.mark.vllm,
# vllm frontend doesn't need or use the GPU, but in CI pytorch seems to look for the Device
pytest.mark.gpu_1,
pytest.mark.pre_merge,
pytest.mark.integration,
pytest.mark.parallel,
pytest.mark.model(TEST_MODEL),
]
class MockVllmPrepostWorkerProcess(ManagedProcess):
"""Test worker that captures frontend tokenized requests."""
def __init__(
self,
request,
*,
frontend_port: int,
capture_path: Path,
worker_id: str = "vllm-prepost-worker",
) -> None:
env = os.environ.copy()
env[CAPTURE_PATH_ENV] = str(capture_path)
super().__init__(
command=["python3", "-m", "tests.frontend.vllm_prepost_worker"],
env=env,
health_check_urls=[
(
f"http://localhost:{frontend_port}/v1/models",
self._check_models_api,
)
],
timeout=60,
display_output=True,
terminate_all_matching_process_names=False,
straggler_commands=["-m tests.frontend.vllm_prepost_worker"],
log_dir=f"{request.node.name}_{worker_id}",
)
@staticmethod
def _check_models_api(response: requests.Response) -> bool:
try:
if response.status_code != 200:
return False
data = response.json()
except (ValueError, KeyError):
return False
for model in data.get("data", []):
if model.get("id") == TEST_MODEL:
return True
return False
def _read_captured_request(path: Path, timeout_s: float = 20.0) -> dict[str, Any]:
deadline = time.time() + timeout_s
while time.time() < deadline:
if path.exists():
return json.loads(path.read_text(encoding="utf-8"))
time.sleep(0.1)
raise AssertionError(f"Timed out waiting for captured request at {path}")
def _collect_stream_chunks(response: requests.Response) -> list[dict[str, Any]]:
response.raise_for_status()
chunks: list[dict[str, Any]] = []
saw_done = False
for line in response.iter_lines(decode_unicode=True):
if not line:
continue
assert line.startswith("data: "), f"Unexpected SSE line: {line!r}"
payload = line[len("data: ") :]
if payload == "[DONE]":
saw_done = True
break
chunks.append(json.loads(payload))
assert saw_done, "Missing [DONE] marker in SSE stream"
assert chunks, "Expected streamed chunks but got none"
return chunks
def _collect_reasoning(chunks: list[dict[str, Any]]) -> str:
parts: list[str] = []
for chunk in chunks:
for choice in chunk.get("choices", []):
reasoning = (choice.get("delta") or {}).get("reasoning_content")
if reasoning is not None:
parts.append(reasoning)
return "".join(parts)
def _collect_tool_calls(chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
merged: dict[int, dict[str, Any]] = {}
for chunk in chunks:
for choice in chunk.get("choices", []):
for tool_call in (choice.get("delta") or {}).get("tool_calls") or []:
idx = tool_call["index"]
if idx not in merged:
merged[idx] = {
"id": tool_call.get("id"),
"type": tool_call.get("type"),
"function": {
"name": tool_call.get("function", {}).get("name"),
"arguments": tool_call.get("function", {}).get(
"arguments", ""
),
},
}
continue
existing = merged[idx]
if tool_call.get("id") and not existing["id"]:
existing["id"] = tool_call["id"]
if tool_call.get("type") and not existing["type"]:
existing["type"] = tool_call["type"]
incoming_fn = tool_call.get("function", {})
if incoming_fn.get("name") and not existing["function"]["name"]:
existing["function"]["name"] = incoming_fn["name"]
if incoming_fn.get("arguments"):
existing["function"]["arguments"] += incoming_fn["arguments"]
return [merged[idx] for idx in sorted(merged)]
@pytest.fixture(scope="function")
def start_services(
request,
runtime_services_dynamic_ports,
dynamo_dynamic_ports: ServicePorts,
tmp_path: Path,
) -> Generator[tuple[int, Path], None, None]:
_ = runtime_services_dynamic_ports
frontend_port = dynamo_dynamic_ports.frontend_port
capture_path = tmp_path / "captured_request.json"
with DynamoFrontendProcess(
request,
frontend_port=frontend_port,
extra_args=[
"--dyn-chat-processor",
"vllm",
"--discovery-backend",
"etcd", # Started by the fixture
"--request-plane",
"tcp",
"--enable-auto-tool-choice",
"--tool-call-parser",
"hermes",
"--reasoning-parser",
"qwen3",
],
extra_env={"DYN_VLLM_STREAM_INTERVAL": "20"},
terminate_all_matching_process_names=False,
):
logger.info("Frontend started on port %s", frontend_port)
with MockVllmPrepostWorkerProcess(
request,
frontend_port=frontend_port,
capture_path=capture_path,
):
logger.info("vLLM pre/post test worker registered model %s", TEST_MODEL)
yield frontend_port, capture_path
@pytest.mark.timeout(120)
def test_vllm_chat_processor_tokenizes_and_streams_tool_calls(
start_services: tuple[int, Path],
) -> None:
frontend_port, capture_path = start_services
payload = {
"model": TEST_MODEL,
"messages": [
{
"role": "user",
"content": "What are the titles of some James Joyce books? Use the tool to search.",
}
],
"tools": [SEARCH_TOOL],
"tool_choice": "auto",
"stream": True,
"max_tokens": 128,
}
response = requests.post(
f"http://localhost:{frontend_port}/v1/chat/completions",
json=payload,
timeout=60,
stream=True,
)
chunks = _collect_stream_chunks(response)
captured = _read_captured_request(capture_path)
assert captured["model"] == TEST_MODEL
assert isinstance(captured["token_ids"], list) and captured["token_ids"]
decoded_prompt = captured["decoded_prompt"]
assert "What are the titles of some James Joyce books?" in decoded_prompt
assert "search_gutenberg_books" in decoded_prompt
reasoning = _collect_reasoning(chunks)
assert "titles of some James Joyce books" in reasoning
tool_calls = _collect_tool_calls(chunks)
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["function"]["name"] == "search_gutenberg_books"
assert json.loads(tool_call["function"]["arguments"]) == {
"search_terms": ["James Joyce", "Project Gutenberg"],
}
content = "".join(
(choice.get("delta") or {}).get("content") or ""
for chunk in chunks
for choice in chunk.get("choices", [])
)
assert "<tool_call>" not in content
assert "</tool_call>" not in content
finish_reasons = [
choice.get("finish_reason")
for chunk in chunks
for choice in chunk.get("choices", [])
if choice.get("finish_reason")
]
assert finish_reasons, "Expected at least one finish_reason"
assert set(finish_reasons) <= {"stop", "tool_calls"}
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Lightweight token-based worker for vLLM frontend pre/post integration tests."""
from __future__ import annotations
import asyncio
import json
import os
from pathlib import Path
from typing import Any
import uvloop
from transformers import AutoTokenizer
from dynamo.llm import ModelInput, ModelType, register_model
from dynamo.runtime import DistributedRuntime
from tests.frontend.test_prepost import OUTPUTS_INTERVAL_20
from tests.frontend.test_vllm_prepost_integration import CAPTURE_PATH_ENV
from tests.utils.constants import QWEN
class VllmPrepostTestHandler:
"""Captures tokenized requests and streams a fixed token response."""
def __init__(self, model_name: str = QWEN):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def _write_capture(self, request: dict[str, Any]) -> None:
capture_path = os.environ.get(CAPTURE_PATH_ENV)
if not capture_path:
return
token_ids = request.get("token_ids", [])
captured = {
"model": request.get("model"),
"token_ids": token_ids,
"stop_conditions": request.get("stop_conditions"),
"sampling_options": request.get("sampling_options"),
"output_options": request.get("output_options"),
"eos_token_ids": request.get("eos_token_ids"),
"decoded_prompt": self.tokenizer.decode(
token_ids,
skip_special_tokens=False,
),
}
path = Path(capture_path)
path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.with_suffix(path.suffix + ".tmp")
tmp_path.write_text(json.dumps(captured), encoding="utf-8")
tmp_path.replace(path)
async def generate(self, request: dict[str, Any], context):
self._write_capture(request)
for output in OUTPUTS_INTERVAL_20:
chunk = {"token_ids": list(output.token_ids)}
if output.finish_reason is not None:
chunk["finish_reason"] = output.finish_reason
if output.stop_reason is not None:
chunk["stop_reason"] = output.stop_reason
yield chunk
async def main():
"""Register a token-based chat model and stream deterministic responses."""
runtime = DistributedRuntime(
asyncio.get_running_loop(), "etcd", "tcp", enable_nats=False
)
endpoint = runtime.endpoint("test.vllm-prepost.generate")
await register_model(
ModelInput.Tokens,
ModelType.Chat,
endpoint,
QWEN,
model_name=QWEN,
)
handler = VllmPrepostTestHandler(QWEN)
await endpoint.serve_endpoint(handler.generate)
if __name__ == "__main__":
uvloop.run(main())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment