"...git@developer.sourcefind.cn:2222/OpenDAS/colossalai.git" did not exist on "b5f9e37c709656b286940f1b5e05abddfa257e3d"
Unverified Commit fda022b1 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

test(frontend): Minimal integration test for vllm processor (#7173)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 96a55928
...@@ -11,9 +11,6 @@ import os ...@@ -11,9 +11,6 @@ import os
import time import time
from argparse import Namespace from argparse import Namespace
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import wait as _futures_wait
from dataclasses import dataclass
from typing import Any from typing import Any
from vllm.config import CacheConfig, LoadConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, LoadConfig, ModelConfig, VllmConfig
...@@ -38,12 +35,8 @@ from dynamo.llm import ( ...@@ -38,12 +35,8 @@ from dynamo.llm import (
) )
from dynamo.runtime import Client, DistributedRuntime from dynamo.runtime import Client, DistributedRuntime
from .prepost import ( from .prepost import StreamingPostProcessor, preprocess_chat_request
StreamingPostProcessor, from .utils import random_uuid
preprocess_chat_request,
preprocess_chat_request_sync,
)
from .utils import PreprocessError, random_uuid, worker_warmup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -74,181 +67,6 @@ def map_finish_reason(raw_reason: str | None) -> FinishReason | None: ...@@ -74,181 +67,6 @@ def map_finish_reason(raw_reason: str | None) -> FinishReason | None:
return mapped return mapped
# --- Worker process globals (initialized once per process by _init_worker) ---
_w_input_processor: InputProcessor | None = None
_w_tokenizer: Any = None
_w_tool_parser_class: type[ToolParser] | None = None
@dataclass
class PreprocessWorkerResult:
"""Picklable return value from the preprocess worker."""
dynamo_preproc: dict[str, Any]
tokens: list[int]
vllm_preproc: EngineCoreRequest
sampling_params: SamplingParams
request_for_sampling: Any # ChatCompletionRequest (Pydantic model, picklable)
chat_template_kwargs: dict[str, Any]
def _init_worker(
model_path: str,
tokenizer_mode: str,
config_format: str,
load_format: str,
tool_parser_name: str | None,
) -> None:
"""Initialize a worker process with its own VllmConfig and InputProcessor."""
global _w_input_processor, _w_tokenizer, _w_tool_parser_class
global _w_reasoning_parser_class
model_config = ModelConfig(
model=model_path,
tokenizer_mode=tokenizer_mode,
config_format=config_format,
)
vllm_config = VllmConfig(
model_config=model_config,
load_config=LoadConfig(load_format=load_format),
cache_config=CacheConfig(),
)
_w_input_processor = InputProcessor(vllm_config)
_w_tokenizer = _w_input_processor.get_tokenizer()
if tool_parser_name:
_w_tool_parser_class = ToolParserManager.get_tool_parser(tool_parser_name)
else:
_w_tool_parser_class = None
def _preprocess_worker(
request: dict[str, Any],
request_id: str,
model_name: str,
) -> PreprocessWorkerResult:
"""Preprocess a request in a worker process and return a picklable result."""
assert _w_input_processor is not None
pre = preprocess_chat_request_sync(
request,
tokenizer=_w_tokenizer,
renderer=_w_input_processor.renderer,
tool_parser_class=_w_tool_parser_class,
)
request_for_sampling = pre.request_for_sampling
engine_prompt = pre.engine_prompt
tokens = pre.prompt_token_ids
if request_for_sampling.max_completion_tokens is not None:
max_tokens = request_for_sampling.max_completion_tokens
elif request_for_sampling.max_tokens is not None:
max_tokens = request_for_sampling.max_tokens
else:
max_tokens = None
sampling_params = SamplingParams(
output_kind=RequestOutputKind.DELTA,
max_tokens=max_tokens,
)
for k, v in _w_input_processor.generation_config_fields.items():
if hasattr(sampling_params, k):
setattr(sampling_params, k, v)
sampling_fields = (
set(getattr(SamplingParams, "__annotations__", ()))
& set(type(request_for_sampling).model_fields)
) - {"max_tokens", "logprobs", "output_kind"}
for k in sorted(sampling_fields):
v = getattr(request_for_sampling, k, None)
if v is not None:
setattr(sampling_params, k, v)
logprobs = request_for_sampling.logprobs
top_logprobs = request_for_sampling.top_logprobs
if logprobs is True:
sampling_params.logprobs = top_logprobs or 1
elif isinstance(logprobs, int) and not isinstance(logprobs, bool):
sampling_params.logprobs = logprobs
elif top_logprobs not in (None, 0):
sampling_params.logprobs = top_logprobs
prompt_inputs = TokensPrompt(prompt_token_ids=tokens)
if "multi_modal_data" in engine_prompt:
prompt_inputs["multi_modal_data"] = engine_prompt["multi_modal_data"]
if "multi_modal_uuids" in engine_prompt:
prompt_inputs["multi_modal_uuids"] = engine_prompt["multi_modal_uuids"]
if request_for_sampling.cache_salt is not None:
prompt_inputs["cache_salt"] = request_for_sampling.cache_salt
if request_for_sampling.mm_processor_kwargs is not None:
prompt_inputs["mm_processor_kwargs"] = request_for_sampling.mm_processor_kwargs
vllm_preproc: EngineCoreRequest = _w_input_processor.process_inputs(
request_id,
prompt_inputs,
sampling_params,
)
InputProcessor.assign_request_id(vllm_preproc)
sp = vllm_preproc.sampling_params
if sp.n != 1:
raise PreprocessError(
{
"error": {
"message": (
f"Unsupported value: 'n={sp.n}'. "
"This endpoint currently supports only n=1."
),
"type": "invalid_request_error",
"param": "n",
"code": "unsupported_value",
}
}
)
dynamo_preproc = {
"model": model_name,
"token_ids": tokens,
"stop_conditions": {
"max_tokens": sp.max_tokens,
"stop": sp.stop,
"stop_token_ids": sp.stop_token_ids,
"min_tokens": sp.min_tokens,
"ignore_eos": sp.ignore_eos,
},
"sampling_options": {
"n": sp.n,
"presence_penalty": sp.presence_penalty,
"frequency_penalty": sp.frequency_penalty,
"repetition_penalty": sp.repetition_penalty,
"temperature": sp.temperature,
"top_p": sp.top_p,
"top_k": sp.top_k,
"min_p": sp.min_p,
"seed": sp.seed,
},
"output_options": {
"logprobs": sp.logprobs,
"prompt_logprobs": sp.prompt_logprobs,
"skip_special_tokens": sp.skip_special_tokens,
},
"eos_token_ids": (
[vllm_preproc.eos_token_id] if vllm_preproc.eos_token_id is not None else []
),
"annotations": [],
}
return PreprocessWorkerResult(
dynamo_preproc=dynamo_preproc,
tokens=tokens,
vllm_preproc=vllm_preproc,
sampling_params=sampling_params,
request_for_sampling=request_for_sampling,
chat_template_kwargs=pre.chat_template_kwargs,
)
class VllmProcessor: class VllmProcessor:
def __init__( def __init__(
self, self,
...@@ -526,77 +344,6 @@ class VllmProcessor: ...@@ -526,77 +344,6 @@ class VllmProcessor:
[vllm_preproc.request_id], internal=True [vllm_preproc.request_id], internal=True
) )
async def _generator_inner_pool(
self, request: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
"""Process a request using the worker pool.
Phase 1: Preprocess in a worker process (semaphore held).
Phase 2: Remote inference via router (no worker held).
Phase 3: Post-process tokens in the main process.
"""
request_id = random_uuid()
# --- Phase 1: Preprocess (semaphore held) ---
try:
assert self._worker_semaphore is not None
async with self._worker_semaphore:
assert self.preprocess_pool is not None
future = self.preprocess_pool.submit(
_preprocess_worker, request, request_id, request["model"]
)
preproc_result: PreprocessWorkerResult = await asyncio.wrap_future(
future
)
# Semaphore + worker released here
except PreprocessError as exc:
yield exc.error_dict
return
except Exception as exc:
logger.exception("Worker preprocessing failed for request %s", request_id)
yield {
"error": {
"message": f"Worker error: {exc}",
"type": "internal_error",
}
}
return
# --- Between phases: reconstruct main-process objects ---
dynamo_preproc = preproc_result.dynamo_preproc
tokens = preproc_result.tokens
vllm_preproc = preproc_result.vllm_preproc
sampling_params = preproc_result.sampling_params
request_for_sampling = preproc_result.request_for_sampling
tool_parser = None
if (
self.tool_parser_class
and request_for_sampling.tools
and request_for_sampling.tool_choice != "none"
):
tool_parser = self.tool_parser_class(self.tokenizer)
post = StreamingPostProcessor(
tokenizer=self.tokenizer,
request_for_sampling=request_for_sampling,
sampling_params=sampling_params,
prompt_token_ids=tokens,
tool_parser=tool_parser,
reasoning_parser_class=self.reasoning_parser_class,
chat_template_kwargs=preproc_result.chat_template_kwargs,
)
async for item in self._generate_and_stream(
request_id,
request,
dynamo_preproc,
tokens,
vllm_preproc,
post,
):
yield item
class EngineFactory: class EngineFactory:
def __init__( def __init__(
...@@ -705,45 +452,6 @@ class EngineFactory: ...@@ -705,45 +452,6 @@ class EngineFactory:
router_mode=self.router_config.router_mode router_mode=self.router_config.router_mode
) )
preprocess_pool = None
preprocess_workers = self.config.preprocess_workers
if preprocess_workers > 0:
logger.info(
"Creating preprocess worker pool with %d workers for model %s",
preprocess_workers,
source_path,
)
preprocess_pool = ProcessPoolExecutor(
max_workers=preprocess_workers,
initializer=_init_worker,
initargs=(
source_path,
tokenizer_mode,
config_format,
load_format,
tool_parser_name,
),
)
# Warm up all workers to ensure initialization completes
futures = [
preprocess_pool.submit(worker_warmup) for _ in range(preprocess_workers)
]
done, not_done = _futures_wait(futures, timeout=120)
if not_done:
for f in not_done:
f.cancel()
preprocess_pool.shutdown(wait=False, cancel_futures=True)
raise RuntimeError(
"Timed out waiting for preprocess worker pool warmup"
)
try:
for f in done:
f.result() # Raises if initializer failed
except Exception:
preprocess_pool.shutdown(wait=False, cancel_futures=True)
raise
logger.info("Preprocess worker pool ready (%d workers)", preprocess_workers)
gen = VllmProcessor( gen = VllmProcessor(
tokenizer, tokenizer,
input_processor, input_processor,
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Generator
import pytest
import requests
from tests.utils.constants import QWEN
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from tests.utils.port_utils import ServicePorts
logger = logging.getLogger(__name__)
TEST_MODEL = QWEN
CAPTURE_PATH_ENV = "DYN_VLLM_PREPOST_CAPTURE_PATH"
SEARCH_TOOL = {
"type": "function",
"function": {
"name": "search_gutenberg_books",
"description": "Search for books in the Project Gutenberg library",
"parameters": {
"type": "object",
"properties": {
"search_terms": {
"type": "array",
"items": {"type": "string"},
"description": "List of search terms to find books",
}
},
"required": ["search_terms"],
},
},
}
pytestmark = [
pytest.mark.vllm,
# vllm frontend doesn't need or use the GPU, but in CI pytorch seems to look for the Device
pytest.mark.gpu_1,
pytest.mark.pre_merge,
pytest.mark.integration,
pytest.mark.parallel,
pytest.mark.model(TEST_MODEL),
]
class MockVllmPrepostWorkerProcess(ManagedProcess):
"""Test worker that captures frontend tokenized requests."""
def __init__(
self,
request,
*,
frontend_port: int,
capture_path: Path,
worker_id: str = "vllm-prepost-worker",
) -> None:
env = os.environ.copy()
env[CAPTURE_PATH_ENV] = str(capture_path)
super().__init__(
command=["python3", "-m", "tests.frontend.vllm_prepost_worker"],
env=env,
health_check_urls=[
(
f"http://localhost:{frontend_port}/v1/models",
self._check_models_api,
)
],
timeout=60,
display_output=True,
terminate_all_matching_process_names=False,
straggler_commands=["-m tests.frontend.vllm_prepost_worker"],
log_dir=f"{request.node.name}_{worker_id}",
)
@staticmethod
def _check_models_api(response: requests.Response) -> bool:
try:
if response.status_code != 200:
return False
data = response.json()
except (ValueError, KeyError):
return False
for model in data.get("data", []):
if model.get("id") == TEST_MODEL:
return True
return False
def _read_captured_request(path: Path, timeout_s: float = 20.0) -> dict[str, Any]:
deadline = time.time() + timeout_s
while time.time() < deadline:
if path.exists():
return json.loads(path.read_text(encoding="utf-8"))
time.sleep(0.1)
raise AssertionError(f"Timed out waiting for captured request at {path}")
def _collect_stream_chunks(response: requests.Response) -> list[dict[str, Any]]:
response.raise_for_status()
chunks: list[dict[str, Any]] = []
saw_done = False
for line in response.iter_lines(decode_unicode=True):
if not line:
continue
assert line.startswith("data: "), f"Unexpected SSE line: {line!r}"
payload = line[len("data: ") :]
if payload == "[DONE]":
saw_done = True
break
chunks.append(json.loads(payload))
assert saw_done, "Missing [DONE] marker in SSE stream"
assert chunks, "Expected streamed chunks but got none"
return chunks
def _collect_reasoning(chunks: list[dict[str, Any]]) -> str:
parts: list[str] = []
for chunk in chunks:
for choice in chunk.get("choices", []):
reasoning = (choice.get("delta") or {}).get("reasoning_content")
if reasoning is not None:
parts.append(reasoning)
return "".join(parts)
def _collect_tool_calls(chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
merged: dict[int, dict[str, Any]] = {}
for chunk in chunks:
for choice in chunk.get("choices", []):
for tool_call in (choice.get("delta") or {}).get("tool_calls") or []:
idx = tool_call["index"]
if idx not in merged:
merged[idx] = {
"id": tool_call.get("id"),
"type": tool_call.get("type"),
"function": {
"name": tool_call.get("function", {}).get("name"),
"arguments": tool_call.get("function", {}).get(
"arguments", ""
),
},
}
continue
existing = merged[idx]
if tool_call.get("id") and not existing["id"]:
existing["id"] = tool_call["id"]
if tool_call.get("type") and not existing["type"]:
existing["type"] = tool_call["type"]
incoming_fn = tool_call.get("function", {})
if incoming_fn.get("name") and not existing["function"]["name"]:
existing["function"]["name"] = incoming_fn["name"]
if incoming_fn.get("arguments"):
existing["function"]["arguments"] += incoming_fn["arguments"]
return [merged[idx] for idx in sorted(merged)]
@pytest.fixture(scope="function")
def start_services(
request,
runtime_services_dynamic_ports,
dynamo_dynamic_ports: ServicePorts,
tmp_path: Path,
) -> Generator[tuple[int, Path], None, None]:
_ = runtime_services_dynamic_ports
frontend_port = dynamo_dynamic_ports.frontend_port
capture_path = tmp_path / "captured_request.json"
with DynamoFrontendProcess(
request,
frontend_port=frontend_port,
extra_args=[
"--dyn-chat-processor",
"vllm",
"--discovery-backend",
"etcd", # Started by the fixture
"--request-plane",
"tcp",
"--enable-auto-tool-choice",
"--tool-call-parser",
"hermes",
"--reasoning-parser",
"qwen3",
],
extra_env={"DYN_VLLM_STREAM_INTERVAL": "20"},
terminate_all_matching_process_names=False,
):
logger.info("Frontend started on port %s", frontend_port)
with MockVllmPrepostWorkerProcess(
request,
frontend_port=frontend_port,
capture_path=capture_path,
):
logger.info("vLLM pre/post test worker registered model %s", TEST_MODEL)
yield frontend_port, capture_path
@pytest.mark.timeout(120)
def test_vllm_chat_processor_tokenizes_and_streams_tool_calls(
start_services: tuple[int, Path],
) -> None:
frontend_port, capture_path = start_services
payload = {
"model": TEST_MODEL,
"messages": [
{
"role": "user",
"content": "What are the titles of some James Joyce books? Use the tool to search.",
}
],
"tools": [SEARCH_TOOL],
"tool_choice": "auto",
"stream": True,
"max_tokens": 128,
}
response = requests.post(
f"http://localhost:{frontend_port}/v1/chat/completions",
json=payload,
timeout=60,
stream=True,
)
chunks = _collect_stream_chunks(response)
captured = _read_captured_request(capture_path)
assert captured["model"] == TEST_MODEL
assert isinstance(captured["token_ids"], list) and captured["token_ids"]
decoded_prompt = captured["decoded_prompt"]
assert "What are the titles of some James Joyce books?" in decoded_prompt
assert "search_gutenberg_books" in decoded_prompt
reasoning = _collect_reasoning(chunks)
assert "titles of some James Joyce books" in reasoning
tool_calls = _collect_tool_calls(chunks)
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["function"]["name"] == "search_gutenberg_books"
assert json.loads(tool_call["function"]["arguments"]) == {
"search_terms": ["James Joyce", "Project Gutenberg"],
}
content = "".join(
(choice.get("delta") or {}).get("content") or ""
for chunk in chunks
for choice in chunk.get("choices", [])
)
assert "<tool_call>" not in content
assert "</tool_call>" not in content
finish_reasons = [
choice.get("finish_reason")
for chunk in chunks
for choice in chunk.get("choices", [])
if choice.get("finish_reason")
]
assert finish_reasons, "Expected at least one finish_reason"
assert set(finish_reasons) <= {"stop", "tool_calls"}
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Lightweight token-based worker for vLLM frontend pre/post integration tests."""
from __future__ import annotations
import asyncio
import json
import os
from pathlib import Path
from typing import Any
import uvloop
from transformers import AutoTokenizer
from dynamo.llm import ModelInput, ModelType, register_model
from dynamo.runtime import DistributedRuntime
from tests.frontend.test_prepost import OUTPUTS_INTERVAL_20
from tests.frontend.test_vllm_prepost_integration import CAPTURE_PATH_ENV
from tests.utils.constants import QWEN
class VllmPrepostTestHandler:
"""Captures tokenized requests and streams a fixed token response."""
def __init__(self, model_name: str = QWEN):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def _write_capture(self, request: dict[str, Any]) -> None:
capture_path = os.environ.get(CAPTURE_PATH_ENV)
if not capture_path:
return
token_ids = request.get("token_ids", [])
captured = {
"model": request.get("model"),
"token_ids": token_ids,
"stop_conditions": request.get("stop_conditions"),
"sampling_options": request.get("sampling_options"),
"output_options": request.get("output_options"),
"eos_token_ids": request.get("eos_token_ids"),
"decoded_prompt": self.tokenizer.decode(
token_ids,
skip_special_tokens=False,
),
}
path = Path(capture_path)
path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.with_suffix(path.suffix + ".tmp")
tmp_path.write_text(json.dumps(captured), encoding="utf-8")
tmp_path.replace(path)
async def generate(self, request: dict[str, Any], context):
self._write_capture(request)
for output in OUTPUTS_INTERVAL_20:
chunk = {"token_ids": list(output.token_ids)}
if output.finish_reason is not None:
chunk["finish_reason"] = output.finish_reason
if output.stop_reason is not None:
chunk["stop_reason"] = output.stop_reason
yield chunk
async def main():
"""Register a token-based chat model and stream deterministic responses."""
runtime = DistributedRuntime(
asyncio.get_running_loop(), "etcd", "tcp", enable_nats=False
)
endpoint = runtime.endpoint("test.vllm-prepost.generate")
await register_model(
ModelInput.Tokens,
ModelType.Chat,
endpoint,
QWEN,
model_name=QWEN,
)
handler = VllmPrepostTestHandler(QWEN)
await endpoint.serve_endpoint(handler.generate)
if __name__ == "__main__":
uvloop.run(main())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment