Unverified Commit fe10dbfd authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: vLLM pre/postprocessing in-framework (#4529)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent 01819b87
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
class InputParamManager:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def get_input_param(self, request: dict, use_tokenizer: bool):
"""
Get the input parameter for the request.
"""
if use_tokenizer:
print(f"Request: {request}")
if self.tokenizer is None:
raise ValueError("Tokenizer is not available")
if "messages" in request:
return self.tokenizer.apply_chat_template(
request["messages"], tokenize=False, add_generation_prompt=True
)
elif "prompt" in request:
return request["prompt"]
elif "text" in request:
return request["text"]
else:
raise ValueError("No input parameter found in request")
return request.get("token_ids")
......@@ -70,7 +70,7 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"flags": ["--use-sglang-tokenizer"],
"action": "store_true",
"default": False,
"help": "Use SGLang's tokenizer. This will skip tokenization of the input and output and only v1/chat/completions will be available when using the dynamo frontend. Cannot be used with --custom-jinja-template.",
"help": "Use SGLang's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend. Cannot be used with --custom-jinja-template.",
},
"multimodal-processor": {
"flags": ["--multimodal-processor"],
......
......@@ -53,7 +53,9 @@ class SglangHealthCheckPayload(HealthCheckPayload):
Provides SGLang defaults and inherits environment override support from base class.
"""
def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
def __init__(
self, engine: Optional[sgl.Engine] = None, use_text_input: bool = False
) -> None:
"""Initialize SGLang health check payload with model-specific BOS token.
Args:
......@@ -62,7 +64,6 @@ class SglangHealthCheckPayload(HealthCheckPayload):
bos_token_id = _get_bos_token_id_from_engine(engine)
self.default_payload = {
"token_ids": [bos_token_id],
"stop_conditions": {
"max_tokens": 1, # Generate only 1 token
"ignore_eos": False,
......@@ -75,6 +76,12 @@ class SglangHealthCheckPayload(HealthCheckPayload):
"eos_token_ids": [],
"annotations": [],
}
if use_text_input:
self.default_payload["prompt"] = "Test"
else:
self.default_payload["token_ids"] = [bos_token_id]
super().__init__()
......@@ -84,7 +91,9 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
The prefill handler expects a wrapped structure with 'request' and 'sampling_params'.
"""
def __init__(self, engine: Optional[sgl.Engine] = None) -> None:
def __init__(
self, engine: Optional[sgl.Engine] = None, use_text_input: bool = False
) -> None:
"""Initialize SGLang prefill health check payload with proper wrapped structure.
Args:
......@@ -93,9 +102,7 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
bos_token_id = _get_bos_token_id_from_engine(engine)
self.default_payload = {
"request": {
"token_ids": [bos_token_id],
},
"request": {},
"sampling_params": {
"max_new_tokens": 1, # Generate only 1 token
"temperature": 0.0,
......@@ -104,4 +111,10 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
"ignore_eos": False,
},
}
if use_text_input:
self.default_payload["request"]["prompt"] = "Test" # type: ignore
else:
self.default_payload["request"]["token_ids"] = [bos_token_id] # type: ignore
super().__init__()
......@@ -168,8 +168,10 @@ async def init(runtime: DistributedRuntime, config: Config):
handler = DecodeWorkerHandler(
component, engine, config, publisher, prefill_client, prefill_router_client
)
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict()
logging.info(
f"Registering model with endpoint types: {dynamo_args.dyn_endpoint_types}"
......@@ -319,7 +321,9 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
ready_event = asyncio.Event()
handler = EmbeddingWorkerHandler(component, engine, config, publisher)
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict()
try:
# Start endpoint immediately and register model concurrently
......
......@@ -16,6 +16,7 @@ from sglang.srt.tracing import trace as sglang_trace
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Client, Component, Context
from dynamo.common.utils.input_params import InputParamManager
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -54,6 +55,12 @@ class BaseWorkerHandler(ABC):
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
self.enable_trace = config.server_args.enable_trace
self.input_param_manager = InputParamManager(
self.engine.tokenizer_manager.tokenizer
if not self.skip_tokenizer_init
else None
)
@abstractmethod
async def generate(self, request: Dict[str, Any], context: Context):
"""Generate response from request.
......@@ -72,23 +79,13 @@ class BaseWorkerHandler(ABC):
pass
def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Get the appropriate input parameter for SGLang engine.
Args:
request: Request dict with token_ids or messages.
request_input = self.input_param_manager.get_input_param(
request, use_tokenizer=not self.skip_tokenizer_init
)
Returns:
Dict with either input_ids or prompt for engine.
"""
if self.skip_tokenizer_init:
return {"input_ids": request["token_ids"]}
else:
# use sglang's chat templating itself but leave tokenization to the
# interal engine's TokenizerManager
prompt = self.engine.tokenizer_manager.tokenizer.apply_chat_template(
request["messages"], tokenize=False, add_generation_prompt=True
)
return {"prompt": prompt}
return {
"prompt" if isinstance(request_input, str) else "input_ids": request_input
}
@staticmethod
def _generate_bootstrap_room() -> int:
......
......@@ -69,6 +69,9 @@ class Config:
# dump config to file
dump_config_to: Optional[str] = None
# Use vLLM's tokenizer for pre/post processing
use_vllm_tokenizer: bool = False
def has_connector(self, connector_name: str) -> bool:
"""
Check if a specific connector is enabled.
......@@ -201,6 +204,12 @@ def parse_args() -> Config:
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
parser.add_argument(
"--use-vllm-tokenizer",
action="store_true",
default=False,
help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.",
)
add_config_dump_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
......@@ -303,6 +312,7 @@ def parse_args() -> Config:
config.mm_prompt_template = args.mm_prompt_template
config.store_kv = args.store_kv
config.request_plane = args.request_plane
config.use_vllm_tokenizer = args.use_vllm_tokenizer
# Validate custom Jinja template file exists if provided
if config.custom_jinja_template is not None:
......
......@@ -5,16 +5,18 @@ import asyncio
import logging
import os
import tempfile
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final
from vllm.inputs import TokensPrompt
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.common.utils.input_params import InputParamManager
from dynamo.llm import (
ModelInput,
ModelType,
......@@ -70,7 +72,7 @@ def build_sampling_params(
model_max_len: int | None = None,
) -> SamplingParams:
"""
Build SamplingParams from a PreprocessedRequest.
Build SamplingParams from a PreprocessedRequest (internal protocol format).
Args:
request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions',
......@@ -164,6 +166,61 @@ def build_sampling_params(
return sampling_params
def build_sampling_params_openai(
request: Dict[str, Any],
default_sampling_params: Dict[str, Any],
) -> SamplingParams:
"""
Build SamplingParams from an OpenAI-compatible request format.
Args:
request: The OpenAI-style request dict with parameters like temperature, max_tokens, etc.
default_sampling_params: Default sampling parameters to initialize with
Returns:
SamplingParams configured from the request
"""
sampling_params = SamplingParams(**default_sampling_params)
sampling_params.detokenize = True
# Map common OpenAI parameters to SamplingParams
openai_mapping = {
"temperature": "temperature",
"top_p": "top_p",
"presence_penalty": "presence_penalty",
"frequency_penalty": "frequency_penalty",
"seed": "seed",
"top_k": "top_k",
"repetition_penalty": "repetition_penalty",
"min_p": "min_p",
"length_penalty": "length_penalty",
"use_beam_search": "use_beam_search",
}
for req_key, param_key in openai_mapping.items():
if req_key in request and request[req_key] is not None:
if hasattr(sampling_params, param_key):
setattr(sampling_params, param_key, request[req_key])
# Handle max_tokens
if "max_tokens" in request and request["max_tokens"] is not None:
sampling_params.max_tokens = request["max_tokens"]
# Handle stop sequences
if "stop" in request and request["stop"] is not None:
sampling_params.stop = request["stop"]
# Handle ignore_eos (custom extension)
if "ignore_eos" in request and request["ignore_eos"] is not None:
sampling_params.ignore_eos = request["ignore_eos"]
# Handle min_tokens (custom extension)
if "min_tokens" in request and request["min_tokens"] is not None:
sampling_params.min_tokens = request["min_tokens"]
return sampling_params
class BaseWorkerHandler(ABC):
"""
Request handler for the generate and clear_kv_blocks endpoints.
......@@ -179,6 +236,7 @@ class BaseWorkerHandler(ABC):
enable_multimodal: bool = False,
generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False,
):
self.runtime = runtime
self.component = component
......@@ -196,6 +254,14 @@ class BaseWorkerHandler(ABC):
self.lora_id_for_name: dict[str, int] = {}
self.lora_name_to_path: dict[str, str] = {}
self.use_vllm_tokenizer = use_vllm_tokenizer
# Initialize InputParamManager for text-in-text-out mode
tokenizer = None
if use_vllm_tokenizer and hasattr(engine, "tokenizer"):
tokenizer = engine.tokenizer
self.input_param_manager = InputParamManager(tokenizer)
@abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
raise NotImplementedError
......@@ -775,6 +841,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
enable_multimodal: bool = False,
generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False,
):
super().__init__(
runtime,
......@@ -785,6 +852,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
enable_multimodal,
generate_endpoint,
config,
use_vllm_tokenizer,
)
async def generate(self, request, context):
......@@ -792,6 +860,17 @@ class DecodeWorkerHandler(BaseWorkerHandler):
request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}")
if self.use_vllm_tokenizer:
# Text-in-text-out mode: use InputParamManager and OpenAI-compatible format
async for chunk in self._generate_text_mode(request, context, request_id):
yield chunk
else:
# Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
async def _generate_token_mode(self, request, context, request_id):
"""Generate tokens using internal protocol format (token-in-token-out)."""
# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)
......@@ -865,6 +944,81 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self.runtime.shutdown()
os._exit(1)
async def _generate_text_mode(self, request, context, request_id):
"""Generate text using OpenAI-compatible format (text-in-text-out)."""
# Get text input using InputParamManager
input_text = self.input_param_manager.get_input_param(
request, use_tokenizer=True
)
# Build prompt for vLLM
prompt = TextPrompt(prompt=input_text)
# Build sampling params from OpenAI-style request
sampling_params = build_sampling_params_openai(
request, self.default_sampling_params
)
dp_rank = request.get("dp_rank", None)
openai_request_id = request.get("id") or request.get("request_id", request_id)
previous_text = ""
async with self._abort_monitor(context, request_id):
try:
gen = self.engine_client.generate(
prompt,
sampling_params,
request_id,
data_parallel_rank=dp_rank,
)
async for res in gen:
if not res.outputs:
yield {
"id": openai_request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": "unknown",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"finish_reason": "error",
}
],
}
break
output = res.outputs[0]
# Calculate the delta text (new text since last chunk)
delta_text = output.text[len(previous_text) :]
previous_text = output.text
choice_data = {
"index": 0,
"delta": {
"role": "assistant",
"content": delta_text,
},
"finish_reason": output.finish_reason,
}
chunk = {
"id": openai_request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": "unknown",
"choices": [choice_data],
}
yield chunk
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(
......@@ -877,6 +1031,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
enable_multimodal: bool = False,
generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False,
):
super().__init__(
runtime,
......@@ -887,6 +1042,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
enable_multimodal,
generate_endpoint,
config,
use_vllm_tokenizer,
)
async def generate(self, request, context):
......@@ -894,6 +1050,17 @@ class PrefillWorkerHandler(BaseWorkerHandler):
request_id = context.id()
logger.debug(f"Prefill Request ID: {request_id}")
if self.use_vllm_tokenizer:
# Text-in-text-out mode: use InputParamManager
async for chunk in self._generate_text_mode(request, context, request_id):
yield chunk
else:
# Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
async def _generate_token_mode(self, request, context, request_id):
"""Generate prefill using internal protocol format (token-in-token-out)."""
# Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request)
......@@ -997,3 +1164,77 @@ class PrefillWorkerHandler(BaseWorkerHandler):
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
async def _generate_text_mode(self, request, context, request_id):
"""Generate prefill using OpenAI-compatible format (text-in-text-out)."""
# Get text input using InputParamManager
input_text = self.input_param_manager.get_input_param(
request, use_tokenizer=True
)
# Build prompt for vLLM
prompt = TextPrompt(prompt=input_text)
# Build sampling params from OpenAI-style request
sampling_params = build_sampling_params_openai(
request, self.default_sampling_params
)
sampling_params.detokenize = False # Prefill doesn't need detokenization
# Configure for prefill-only mode with remote decode
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
sampling_params_defaults = {
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
}
# Add only missing keys
for k, v in sampling_params_defaults.items():
sampling_params.extra_args["kv_transfer_params"].setdefault(k, v)
# Override for prefill: only generate 1 token
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id, is_prefill=True):
try:
gen = self.engine_client.generate(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank
)
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
try:
async for res in gen:
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
token_ids = res.outputs[0].token_ids if res.outputs else []
output: Dict[str, Any] = {
"token_ids": list(token_ids),
"disaggregated_params": (
{"kv_transfer_params": res.kv_transfer_params}
if res.kv_transfer_params
else None
),
"completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res
),
}
yield output
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
......@@ -8,11 +8,15 @@ This module defines the default health check payload for vLLM backends.
"""
import logging
from typing import TYPE_CHECKING, Optional
from dynamo.health_check import HealthCheckPayload
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from vllm.v1.engine.async_llm import AsyncLLM
def _get_bos_token_id_from_engine(engine_client) -> int:
"""
......@@ -45,6 +49,36 @@ def _get_bos_token_id_from_engine(engine_client) -> int:
return 1
def _make_default_payload(
engine_client: Optional["AsyncLLM"], use_text_input: bool
) -> dict:
sampling_options = {
"temperature": 0.0,
}
stop_conditions = {
"max_tokens": 1,
"stop": None,
"stop_token_ids": None,
"include_stop_str_in_output": False,
"ignore_eos": False,
}
if use_text_input:
return {
"prompt": "Test",
**sampling_options,
**stop_conditions,
}
else:
bos_token_id = _get_bos_token_id_from_engine(engine_client)
return {
"token_ids": [bos_token_id],
"sampling_options": sampling_options,
"stop_conditions": stop_conditions,
}
class VllmHealthCheckPayload(HealthCheckPayload):
"""
vLLM-specific health check payload.
......@@ -52,31 +86,18 @@ class VllmHealthCheckPayload(HealthCheckPayload):
Provides vLLM defaults and inherits environment override support from base class.
"""
def __init__(self, engine_client=None):
def __init__(self, engine_client=None, use_text_input: bool = False):
"""
Initialize vLLM health check payload with vLLM-specific defaults.
Args:
engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
use_text_input: If True, use text-based input (prompt field) instead of token_ids.
This should match the use_vllm_tokenizer config setting.
"""
bos_token_id = _get_bos_token_id_from_engine(engine_client)
# Set vLLM default payload - minimal request that completes quickly
# The handler expects token_ids, sampling_options, and stop_conditions
self.default_payload = {
"token_ids": [bos_token_id],
"sampling_options": {
"temperature": 0.0,
},
"stop_conditions": {
"max_tokens": 1,
"stop": None,
"stop_token_ids": None,
"include_stop_str_in_output": False,
"ignore_eos": False,
},
}
self.default_payload = _make_default_payload(engine_client, use_text_input)
super().__init__()
......@@ -87,7 +108,7 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload):
The prefill handler expects PreprocessedRequest format with sampling_options and stop_conditions.
"""
def __init__(self, engine_client=None):
def __init__(self, engine_client=None, use_text_input: bool = False):
"""
Initialize vLLM prefill health check payload with proper PreprocessedRequest structure.
......@@ -95,23 +116,5 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload):
engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
"""
bos_token_id = _get_bos_token_id_from_engine(engine_client)
# Prefill handler expects PreprocessedRequest format: token_ids, sampling_options, stop_conditions
# The handler will override max_tokens/min_tokens to 1 and add do_remote_decode
self.default_payload = {
"token_ids": [bos_token_id],
"sampling_options": {
"temperature": 0.0,
"top_p": 1.0,
"top_k": -1,
},
"stop_conditions": {
"stop": None,
"stop_token_ids": None,
"include_stop_str_in_output": False,
"ignore_eos": False,
"min_tokens": 0,
},
}
self.default_payload = _make_default_payload(engine_client, use_text_input)
super().__init__()
......@@ -384,6 +384,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
)
handler.add_temp_dir(prometheus_temp_dir)
......@@ -418,8 +419,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
# Register prefill model with ModelType.Prefill
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
model_input = (
ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
)
await register_vllm_model(
ModelInput.Tokens,
model_input,
ModelType.Prefill,
generate_endpoint,
config,
......@@ -428,7 +432,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
migration_limit=0, # Prefill doesn't support migration
)
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()
health_check_payload = VllmPrefillHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict()
try:
logger.debug("Starting serve_endpoint for prefill worker")
......@@ -497,6 +503,7 @@ async def init(runtime: DistributedRuntime, config: Config):
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
)
handler.add_temp_dir(prometheus_temp_dir)
......@@ -536,6 +543,10 @@ async def init(runtime: DistributedRuntime, config: Config):
f"Registering model with endpoint types: {config.dyn_endpoint_types}"
)
model_input = (
ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
)
# Warn if custom template provided but chat endpoint not enabled
if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types:
logger.warning(
......@@ -544,7 +555,7 @@ async def init(runtime: DistributedRuntime, config: Config):
)
await register_vllm_model(
ModelInput.Tokens,
model_input,
model_type,
generate_endpoint,
config,
......@@ -553,7 +564,9 @@ async def init(runtime: DistributedRuntime, config: Config):
migration_limit=config.migration_limit,
)
health_check_payload = VllmHealthCheckPayload(engine_client).to_dict()
health_check_payload = VllmHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict()
try:
logger.debug("Starting serve_endpoint for decode worker")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment