Unverified Commit fe10dbfd authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: vLLM pre/postprocessing in-framework (#4529)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
parent 01819b87
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
class InputParamManager:
    """Pick the engine input (text prompt or token ids) out of a request dict.

    Args:
        tokenizer: Object exposing ``apply_chat_template``; may be ``None``.
            It is only consulted for requests that carry ``"messages"``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def get_input_param(self, request: dict, use_tokenizer: bool):
        """Get the input parameter for the request.

        Args:
            request: Request dict. In text mode the keys ``"messages"``,
                ``"prompt"`` and ``"text"`` are checked in that order; in
                token mode only ``"token_ids"`` is read.
            use_tokenizer: ``True`` for text mode, ``False`` for token mode.

        Returns:
            In text mode, the chat-templated ``"messages"`` or the raw
            ``"prompt"`` / ``"text"`` value. In token mode, the value of
            ``"token_ids"`` (``None`` when the key is absent).

        Raises:
            ValueError: In text mode, when the request has ``"messages"`` but
                no tokenizer is available, or when none of the supported
                input keys is present.
        """
        if not use_tokenizer:
            return request.get("token_ids")

        if "messages" in request:
            # Only chat templating needs the tokenizer; "prompt"/"text"
            # requests are passed through untouched and must keep working
            # when no tokenizer is attached.
            if self.tokenizer is None:
                raise ValueError("Tokenizer is not available")
            return self.tokenizer.apply_chat_template(
                request["messages"], tokenize=False, add_generation_prompt=True
            )
        if "prompt" in request:
            return request["prompt"]
        if "text" in request:
            return request["text"]
        raise ValueError("No input parameter found in request")
...@@ -70,7 +70,7 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = { ...@@ -70,7 +70,7 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"flags": ["--use-sglang-tokenizer"], "flags": ["--use-sglang-tokenizer"],
"action": "store_true", "action": "store_true",
"default": False, "default": False,
"help": "Use SGLang's tokenizer. This will skip tokenization of the input and output and only v1/chat/completions will be available when using the dynamo frontend. Cannot be used with --custom-jinja-template.", "help": "Use SGLang's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend. Cannot be used with --custom-jinja-template.",
}, },
"multimodal-processor": { "multimodal-processor": {
"flags": ["--multimodal-processor"], "flags": ["--multimodal-processor"],
......
...@@ -53,7 +53,9 @@ class SglangHealthCheckPayload(HealthCheckPayload): ...@@ -53,7 +53,9 @@ class SglangHealthCheckPayload(HealthCheckPayload):
Provides SGLang defaults and inherits environment override support from base class. Provides SGLang defaults and inherits environment override support from base class.
""" """
def __init__(self, engine: Optional[sgl.Engine] = None) -> None: def __init__(
self, engine: Optional[sgl.Engine] = None, use_text_input: bool = False
) -> None:
"""Initialize SGLang health check payload with model-specific BOS token. """Initialize SGLang health check payload with model-specific BOS token.
Args: Args:
...@@ -62,7 +64,6 @@ class SglangHealthCheckPayload(HealthCheckPayload): ...@@ -62,7 +64,6 @@ class SglangHealthCheckPayload(HealthCheckPayload):
bos_token_id = _get_bos_token_id_from_engine(engine) bos_token_id = _get_bos_token_id_from_engine(engine)
self.default_payload = { self.default_payload = {
"token_ids": [bos_token_id],
"stop_conditions": { "stop_conditions": {
"max_tokens": 1, # Generate only 1 token "max_tokens": 1, # Generate only 1 token
"ignore_eos": False, "ignore_eos": False,
...@@ -75,6 +76,12 @@ class SglangHealthCheckPayload(HealthCheckPayload): ...@@ -75,6 +76,12 @@ class SglangHealthCheckPayload(HealthCheckPayload):
"eos_token_ids": [], "eos_token_ids": [],
"annotations": [], "annotations": [],
} }
if use_text_input:
self.default_payload["prompt"] = "Test"
else:
self.default_payload["token_ids"] = [bos_token_id]
super().__init__() super().__init__()
...@@ -84,7 +91,9 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -84,7 +91,9 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
The prefill handler expects a wrapped structure with 'request' and 'sampling_params'. The prefill handler expects a wrapped structure with 'request' and 'sampling_params'.
""" """
def __init__(self, engine: Optional[sgl.Engine] = None) -> None: def __init__(
self, engine: Optional[sgl.Engine] = None, use_text_input: bool = False
) -> None:
"""Initialize SGLang prefill health check payload with proper wrapped structure. """Initialize SGLang prefill health check payload with proper wrapped structure.
Args: Args:
...@@ -93,9 +102,7 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -93,9 +102,7 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
bos_token_id = _get_bos_token_id_from_engine(engine) bos_token_id = _get_bos_token_id_from_engine(engine)
self.default_payload = { self.default_payload = {
"request": { "request": {},
"token_ids": [bos_token_id],
},
"sampling_params": { "sampling_params": {
"max_new_tokens": 1, # Generate only 1 token "max_new_tokens": 1, # Generate only 1 token
"temperature": 0.0, "temperature": 0.0,
...@@ -104,4 +111,10 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -104,4 +111,10 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload):
"ignore_eos": False, "ignore_eos": False,
}, },
} }
if use_text_input:
self.default_payload["request"]["prompt"] = "Test" # type: ignore
else:
self.default_payload["request"]["token_ids"] = [bos_token_id] # type: ignore
super().__init__() super().__init__()
...@@ -168,8 +168,10 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -168,8 +168,10 @@ async def init(runtime: DistributedRuntime, config: Config):
handler = DecodeWorkerHandler( handler = DecodeWorkerHandler(
component, engine, config, publisher, prefill_client, prefill_router_client component, engine, config, publisher, prefill_client, prefill_router_client
) )
print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload(engine).to_dict() health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict()
logging.info( logging.info(
f"Registering model with endpoint types: {dynamo_args.dyn_endpoint_types}" f"Registering model with endpoint types: {dynamo_args.dyn_endpoint_types}"
...@@ -319,7 +321,9 @@ async def init_embedding(runtime: DistributedRuntime, config: Config): ...@@ -319,7 +321,9 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
ready_event = asyncio.Event() ready_event = asyncio.Event()
handler = EmbeddingWorkerHandler(component, engine, config, publisher) handler = EmbeddingWorkerHandler(component, engine, config, publisher)
health_check_payload = SglangHealthCheckPayload(engine).to_dict() health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict()
try: try:
# Start endpoint immediately and register model concurrently # Start endpoint immediately and register model concurrently
......
...@@ -16,6 +16,7 @@ from sglang.srt.tracing import trace as sglang_trace ...@@ -16,6 +16,7 @@ from sglang.srt.tracing import trace as sglang_trace
from sglang.srt.utils import get_local_ip_auto from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Client, Component, Context from dynamo._core import Client, Component, Context
from dynamo.common.utils.input_params import InputParamManager
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -54,6 +55,12 @@ class BaseWorkerHandler(ABC): ...@@ -54,6 +55,12 @@ class BaseWorkerHandler(ABC):
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
self.enable_trace = config.server_args.enable_trace self.enable_trace = config.server_args.enable_trace
self.input_param_manager = InputParamManager(
self.engine.tokenizer_manager.tokenizer
if not self.skip_tokenizer_init
else None
)
@abstractmethod @abstractmethod
async def generate(self, request: Dict[str, Any], context: Context): async def generate(self, request: Dict[str, Any], context: Context):
"""Generate response from request. """Generate response from request.
...@@ -72,23 +79,13 @@ class BaseWorkerHandler(ABC): ...@@ -72,23 +79,13 @@ class BaseWorkerHandler(ABC):
pass pass
def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]: def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Get the appropriate input parameter for SGLang engine. request_input = self.input_param_manager.get_input_param(
request, use_tokenizer=not self.skip_tokenizer_init
Args:
request: Request dict with token_ids or messages.
Returns:
Dict with either input_ids or prompt for engine.
"""
if self.skip_tokenizer_init:
return {"input_ids": request["token_ids"]}
else:
# use sglang's chat templating itself but leave tokenization to the
# interal engine's TokenizerManager
prompt = self.engine.tokenizer_manager.tokenizer.apply_chat_template(
request["messages"], tokenize=False, add_generation_prompt=True
) )
return {"prompt": prompt}
return {
"prompt" if isinstance(request_input, str) else "input_ids": request_input
}
@staticmethod @staticmethod
def _generate_bootstrap_room() -> int: def _generate_bootstrap_room() -> int:
......
...@@ -69,6 +69,9 @@ class Config: ...@@ -69,6 +69,9 @@ class Config:
# dump config to file # dump config to file
dump_config_to: Optional[str] = None dump_config_to: Optional[str] = None
# Use vLLM's tokenizer for pre/post processing
use_vllm_tokenizer: bool = False
def has_connector(self, connector_name: str) -> bool: def has_connector(self, connector_name: str) -> bool:
""" """
Check if a specific connector is enabled. Check if a specific connector is enabled.
...@@ -201,6 +204,12 @@ def parse_args() -> Config: ...@@ -201,6 +204,12 @@ def parse_args() -> Config:
default=os.environ.get("DYN_REQUEST_PLANE", "nats"), default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
) )
parser.add_argument(
"--use-vllm-tokenizer",
action="store_true",
default=False,
help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.",
)
add_config_dump_args(parser) add_config_dump_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
...@@ -303,6 +312,7 @@ def parse_args() -> Config: ...@@ -303,6 +312,7 @@ def parse_args() -> Config:
config.mm_prompt_template = args.mm_prompt_template config.mm_prompt_template = args.mm_prompt_template
config.store_kv = args.store_kv config.store_kv = args.store_kv
config.request_plane = args.request_plane config.request_plane = args.request_plane
config.use_vllm_tokenizer = args.use_vllm_tokenizer
# Validate custom Jinja template file exists if provided # Validate custom Jinja template file exists if provided
if config.custom_jinja_template is not None: if config.custom_jinja_template is not None:
......
...@@ -5,16 +5,18 @@ import asyncio ...@@ -5,16 +5,18 @@ import asyncio
import logging import logging
import os import os
import tempfile import tempfile
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final from typing import Any, AsyncGenerator, Dict, Final
from vllm.inputs import TokensPrompt from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.common.utils.input_params import InputParamManager
from dynamo.llm import ( from dynamo.llm import (
ModelInput, ModelInput,
ModelType, ModelType,
...@@ -70,7 +72,7 @@ def build_sampling_params( ...@@ -70,7 +72,7 @@ def build_sampling_params(
model_max_len: int | None = None, model_max_len: int | None = None,
) -> SamplingParams: ) -> SamplingParams:
""" """
Build SamplingParams from a PreprocessedRequest. Build SamplingParams from a PreprocessedRequest (internal protocol format).
Args: Args:
request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions', request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions',
...@@ -164,6 +166,61 @@ def build_sampling_params( ...@@ -164,6 +166,61 @@ def build_sampling_params(
return sampling_params return sampling_params
def build_sampling_params_openai(
    request: Dict[str, Any],
    default_sampling_params: Dict[str, Any],
) -> SamplingParams:
    """
    Build SamplingParams from an OpenAI-compatible request format.

    Args:
        request: The OpenAI-style request dict with parameters like temperature, max_tokens, etc.
            Keys that are absent or set to None leave the corresponding default untouched.
        default_sampling_params: Default sampling parameters to initialize with

    Returns:
        SamplingParams configured from the request, with ``detokenize`` forced to True.
    """
    sampling_params = SamplingParams(**default_sampling_params)
    sampling_params.detokenize = True

    # Request keys that map one-to-one onto a SamplingParams attribute of the
    # same name. The hasattr guard skips names this SamplingParams does not have.
    passthrough_fields = (
        "temperature",
        "top_p",
        "presence_penalty",
        "frequency_penalty",
        "seed",
        "top_k",
        "repetition_penalty",
        "min_p",
        "length_penalty",
        "use_beam_search",
    )
    for field in passthrough_fields:
        value = request.get(field)
        if value is not None and hasattr(sampling_params, field):
            setattr(sampling_params, field, value)

    # Handle max_tokens
    max_tokens = request.get("max_tokens")
    if max_tokens is not None:
        sampling_params.max_tokens = max_tokens

    # Handle stop sequences. OpenAI allows a single string as well as a list.
    # The value is assigned after construction, so nothing wraps a bare string
    # for us; left as a string it would be iterated character by character.
    stop = request.get("stop")
    if stop is not None:
        sampling_params.stop = [stop] if isinstance(stop, str) else list(stop)

    # Handle ignore_eos (custom extension)
    ignore_eos = request.get("ignore_eos")
    if ignore_eos is not None:
        sampling_params.ignore_eos = ignore_eos

    # Handle min_tokens (custom extension)
    min_tokens = request.get("min_tokens")
    if min_tokens is not None:
        sampling_params.min_tokens = min_tokens

    return sampling_params
class BaseWorkerHandler(ABC): class BaseWorkerHandler(ABC):
""" """
Request handler for the generate and clear_kv_blocks endpoints. Request handler for the generate and clear_kv_blocks endpoints.
...@@ -179,6 +236,7 @@ class BaseWorkerHandler(ABC): ...@@ -179,6 +236,7 @@ class BaseWorkerHandler(ABC):
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None, generate_endpoint=None,
config=None, config=None,
use_vllm_tokenizer: bool = False,
): ):
self.runtime = runtime self.runtime = runtime
self.component = component self.component = component
...@@ -196,6 +254,14 @@ class BaseWorkerHandler(ABC): ...@@ -196,6 +254,14 @@ class BaseWorkerHandler(ABC):
self.lora_id_for_name: dict[str, int] = {} self.lora_id_for_name: dict[str, int] = {}
self.lora_name_to_path: dict[str, str] = {} self.lora_name_to_path: dict[str, str] = {}
self.use_vllm_tokenizer = use_vllm_tokenizer
# Initialize InputParamManager for text-in-text-out mode
tokenizer = None
if use_vllm_tokenizer and hasattr(engine, "tokenizer"):
tokenizer = engine.tokenizer
self.input_param_manager = InputParamManager(tokenizer)
@abstractmethod @abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]: async def generate(self, request, context) -> AsyncGenerator[dict, None]:
raise NotImplementedError raise NotImplementedError
...@@ -775,6 +841,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -775,6 +841,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None, generate_endpoint=None,
config=None, config=None,
use_vllm_tokenizer: bool = False,
): ):
super().__init__( super().__init__(
runtime, runtime,
...@@ -785,6 +852,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -785,6 +852,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
enable_multimodal, enable_multimodal,
generate_endpoint, generate_endpoint,
config, config,
use_vllm_tokenizer,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -792,6 +860,17 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -792,6 +860,17 @@ class DecodeWorkerHandler(BaseWorkerHandler):
request_id = context.id() request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}") logger.debug(f"Decode Request ID: {request_id}")
if self.use_vllm_tokenizer:
# Text-in-text-out mode: use InputParamManager and OpenAI-compatible format
async for chunk in self._generate_text_mode(request, context, request_id):
yield chunk
else:
# Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
async def _generate_token_mode(self, request, context, request_id):
"""Generate tokens using internal protocol format (token-in-token-out)."""
# Extract and decode multimodal data if present # Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request) multi_modal_data = await self._extract_multimodal_data(request)
...@@ -865,6 +944,81 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -865,6 +944,81 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self.runtime.shutdown() self.runtime.shutdown()
os._exit(1) os._exit(1)
async def _generate_text_mode(self, request, context, request_id):
    """Generate text using OpenAI-compatible format (text-in-text-out).

    Resolves the text prompt through ``self.input_param_manager``, builds
    sampling params from the OpenAI-style request, and yields one
    ``chat.completion.chunk`` dict per engine output.

    Args:
        request: OpenAI-style request dict. ``dp_rank``, ``id`` and
            ``request_id`` are read here in addition to the input and
            sampling fields.
        context: Request context handed to ``self._abort_monitor``.
        request_id: Engine-side request id; also the fallback chunk ``id``.

    Yields:
        Chunk dicts with keys ``id``, ``created``, ``object``, ``model`` and
        ``choices``. If the engine returns a result without outputs, a single
        chunk with ``finish_reason == "error"`` is yielded and the stream ends.
    """
    # Get text input using InputParamManager
    input_text = self.input_param_manager.get_input_param(
        request, use_tokenizer=True
    )

    # Build prompt for vLLM
    prompt = TextPrompt(prompt=input_text)

    # Build sampling params from OpenAI-style request
    sampling_params = build_sampling_params_openai(
        request, self.default_sampling_params
    )

    dp_rank = request.get("dp_rank", None)
    openai_request_id = request.get("id") or request.get("request_id", request_id)

    # All chunks of one streamed completion share a single creation time,
    # so take the timestamp once instead of once per chunk.
    created = int(time.time())

    def make_chunk(content, finish_reason):
        """Wrap one content delta in the chat.completion.chunk envelope."""
        return {
            "id": openai_request_id,
            "created": created,
            "object": "chat.completion.chunk",
            "model": "unknown",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": content},
                    "finish_reason": finish_reason,
                }
            ],
        }

    previous_text = ""

    async with self._abort_monitor(context, request_id):
        try:
            gen = self.engine_client.generate(
                prompt,
                sampling_params,
                request_id,
                data_parallel_rank=dp_rank,
            )

            async for res in gen:
                if not res.outputs:
                    yield make_chunk("", "error")
                    break

                output = res.outputs[0]
                # Calculate the delta text (new text since last chunk).
                # NOTE(review): assumes output.text is cumulative across
                # iterations - confirm against the configured output kind.
                delta_text = output.text[len(previous_text) :]
                previous_text = output.text

                yield make_chunk(delta_text, output.finish_reason)

        except EngineDeadError as e:
            logger.error(f"vLLM EngineDeadError: {e}")
            logger.warning("Initiating Dynamo Runtime shutdown.")
            self.runtime.shutdown()
            os._exit(1)
class PrefillWorkerHandler(BaseWorkerHandler): class PrefillWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
...@@ -877,6 +1031,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -877,6 +1031,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None, generate_endpoint=None,
config=None, config=None,
use_vllm_tokenizer: bool = False,
): ):
super().__init__( super().__init__(
runtime, runtime,
...@@ -887,6 +1042,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -887,6 +1042,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
enable_multimodal, enable_multimodal,
generate_endpoint, generate_endpoint,
config, config,
use_vllm_tokenizer,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -894,6 +1050,17 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -894,6 +1050,17 @@ class PrefillWorkerHandler(BaseWorkerHandler):
request_id = context.id() request_id = context.id()
logger.debug(f"Prefill Request ID: {request_id}") logger.debug(f"Prefill Request ID: {request_id}")
if self.use_vllm_tokenizer:
# Text-in-text-out mode: use InputParamManager
async for chunk in self._generate_text_mode(request, context, request_id):
yield chunk
else:
# Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
async def _generate_token_mode(self, request, context, request_id):
"""Generate prefill using internal protocol format (token-in-token-out)."""
# Extract and decode multimodal data if present # Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request) multi_modal_data = await self._extract_multimodal_data(request)
...@@ -997,3 +1164,77 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -997,3 +1164,77 @@ class PrefillWorkerHandler(BaseWorkerHandler):
raise GeneratorExit( raise GeneratorExit(
"Prefill engine was shut down during token generation" "Prefill engine was shut down during token generation"
) from None ) from None
async def _generate_text_mode(self, request, context, request_id):
    """Generate prefill using OpenAI-compatible format (text-in-text-out).

    Resolves the text prompt through ``self.input_param_manager``, builds
    sampling params from the OpenAI-style request, then forces a one-token,
    non-detokenized generation flagged for remote decode.

    Args:
        request: OpenAI-style request dict; ``dp_rank`` is also read here.
        context: Request context handed to ``self._abort_monitor``.
        request_id: Engine-side request id.

    Yields:
        Dicts with ``token_ids``, ``disaggregated_params`` (the engine's
        ``kv_transfer_params`` wrapped in a dict, or ``None`` when falsy) and
        ``completion_usage``.

    Raises:
        GeneratorExit: If the task is cancelled while iterating engine output.
    """
    # Get text input using InputParamManager
    input_text = self.input_param_manager.get_input_param(
        request, use_tokenizer=True
    )

    # Build prompt for vLLM
    prompt = TextPrompt(prompt=input_text)

    # Build sampling params from OpenAI-style request
    sampling_params = build_sampling_params_openai(
        request, self.default_sampling_params
    )
    sampling_params.detokenize = False  # Prefill doesn't need detokenization

    # Configure for prefill-only mode with remote decode. The dict is always
    # created fresh here, so the defaults are written out directly.
    if sampling_params.extra_args is None:
        sampling_params.extra_args = {}
    sampling_params.extra_args["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
    }

    # Override for prefill: only generate 1 token
    sampling_params.max_tokens = 1
    sampling_params.min_tokens = 1

    dp_rank = request.get("dp_rank", None)

    async with self._abort_monitor(context, request_id, is_prefill=True):
        try:
            gen = self.engine_client.generate(
                prompt, sampling_params, request_id, data_parallel_rank=dp_rank
            )

            # Iterate inside the same try: an engine failure can surface while
            # results are consumed, not only when the generator is created.
            async for res in gen:
                logger.debug(f"kv transfer params: {res.kv_transfer_params}")

                token_ids = res.outputs[0].token_ids if res.outputs else []

                output: Dict[str, Any] = {
                    "token_ids": list(token_ids),
                    "disaggregated_params": (
                        {"kv_transfer_params": res.kv_transfer_params}
                        if res.kv_transfer_params
                        else None
                    ),
                    "completion_usage": BaseWorkerHandler._build_completion_usage(
                        request_output=res
                    ),
                }

                yield output
        except EngineDeadError as e:
            logger.error(f"vLLM EngineDeadError: {e}")
            logger.warning("Initiating Dynamo Runtime shutdown.")
            self.runtime.shutdown()
            os._exit(1)
        except asyncio.CancelledError:
            # raise the error because we cannot migrate prefill requests
            raise GeneratorExit(
                "Prefill engine was shut down during token generation"
            ) from None
...@@ -8,11 +8,15 @@ This module defines the default health check payload for vLLM backends. ...@@ -8,11 +8,15 @@ This module defines the default health check payload for vLLM backends.
""" """
import logging import logging
from typing import TYPE_CHECKING, Optional
from dynamo.health_check import HealthCheckPayload from dynamo.health_check import HealthCheckPayload
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from vllm.v1.engine.async_llm import AsyncLLM
def _get_bos_token_id_from_engine(engine_client) -> int: def _get_bos_token_id_from_engine(engine_client) -> int:
""" """
...@@ -45,6 +49,36 @@ def _get_bos_token_id_from_engine(engine_client) -> int: ...@@ -45,6 +49,36 @@ def _get_bos_token_id_from_engine(engine_client) -> int:
return 1 return 1
def _make_default_payload(
engine_client: Optional["AsyncLLM"], use_text_input: bool
) -> dict:
sampling_options = {
"temperature": 0.0,
}
stop_conditions = {
"max_tokens": 1,
"stop": None,
"stop_token_ids": None,
"include_stop_str_in_output": False,
"ignore_eos": False,
}
if use_text_input:
return {
"prompt": "Test",
**sampling_options,
**stop_conditions,
}
else:
bos_token_id = _get_bos_token_id_from_engine(engine_client)
return {
"token_ids": [bos_token_id],
"sampling_options": sampling_options,
"stop_conditions": stop_conditions,
}
class VllmHealthCheckPayload(HealthCheckPayload): class VllmHealthCheckPayload(HealthCheckPayload):
""" """
vLLM-specific health check payload. vLLM-specific health check payload.
...@@ -52,31 +86,18 @@ class VllmHealthCheckPayload(HealthCheckPayload): ...@@ -52,31 +86,18 @@ class VllmHealthCheckPayload(HealthCheckPayload):
Provides vLLM defaults and inherits environment override support from base class. Provides vLLM defaults and inherits environment override support from base class.
""" """
def __init__(self, engine_client=None): def __init__(self, engine_client=None, use_text_input: bool = False):
""" """
Initialize vLLM health check payload with vLLM-specific defaults. Initialize vLLM health check payload with vLLM-specific defaults.
Args: Args:
engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from. engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from.
If provided, will attempt to use the model's actual BOS token. If provided, will attempt to use the model's actual BOS token.
use_text_input: If True, use text-based input (prompt field) instead of token_ids.
This should match the use_vllm_tokenizer config setting.
""" """
bos_token_id = _get_bos_token_id_from_engine(engine_client)
# Set vLLM default payload - minimal request that completes quickly self.default_payload = _make_default_payload(engine_client, use_text_input)
# The handler expects token_ids, sampling_options, and stop_conditions
self.default_payload = {
"token_ids": [bos_token_id],
"sampling_options": {
"temperature": 0.0,
},
"stop_conditions": {
"max_tokens": 1,
"stop": None,
"stop_token_ids": None,
"include_stop_str_in_output": False,
"ignore_eos": False,
},
}
super().__init__() super().__init__()
...@@ -87,7 +108,7 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -87,7 +108,7 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload):
The prefill handler expects PreprocessedRequest format with sampling_options and stop_conditions. The prefill handler expects PreprocessedRequest format with sampling_options and stop_conditions.
""" """
def __init__(self, engine_client=None): def __init__(self, engine_client=None, use_text_input: bool = False):
""" """
Initialize vLLM prefill health check payload with proper PreprocessedRequest structure. Initialize vLLM prefill health check payload with proper PreprocessedRequest structure.
...@@ -95,23 +116,5 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -95,23 +116,5 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload):
engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from. engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from.
If provided, will attempt to use the model's actual BOS token. If provided, will attempt to use the model's actual BOS token.
""" """
bos_token_id = _get_bos_token_id_from_engine(engine_client) self.default_payload = _make_default_payload(engine_client, use_text_input)
# Prefill handler expects PreprocessedRequest format: token_ids, sampling_options, stop_conditions
# The handler will override max_tokens/min_tokens to 1 and add do_remote_decode
self.default_payload = {
"token_ids": [bos_token_id],
"sampling_options": {
"temperature": 0.0,
"top_p": 1.0,
"top_k": -1,
},
"stop_conditions": {
"stop": None,
"stop_token_ids": None,
"include_stop_str_in_output": False,
"ignore_eos": False,
"min_tokens": 0,
},
}
super().__init__() super().__init__()
...@@ -384,6 +384,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -384,6 +384,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
config=config, config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
...@@ -418,8 +419,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -418,8 +419,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
# Register prefill model with ModelType.Prefill # Register prefill model with ModelType.Prefill
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
model_input = (
ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
)
await register_vllm_model( await register_vllm_model(
ModelInput.Tokens, model_input,
ModelType.Prefill, ModelType.Prefill,
generate_endpoint, generate_endpoint,
config, config,
...@@ -428,7 +432,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -428,7 +432,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
migration_limit=0, # Prefill doesn't support migration migration_limit=0, # Prefill doesn't support migration
) )
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict() health_check_payload = VllmPrefillHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict()
try: try:
logger.debug("Starting serve_endpoint for prefill worker") logger.debug("Starting serve_endpoint for prefill worker")
...@@ -497,6 +503,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -497,6 +503,7 @@ async def init(runtime: DistributedRuntime, config: Config):
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
config=config, config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
...@@ -536,6 +543,10 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -536,6 +543,10 @@ async def init(runtime: DistributedRuntime, config: Config):
f"Registering model with endpoint types: {config.dyn_endpoint_types}" f"Registering model with endpoint types: {config.dyn_endpoint_types}"
) )
model_input = (
ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
)
# Warn if custom template provided but chat endpoint not enabled # Warn if custom template provided but chat endpoint not enabled
if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types: if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types:
logger.warning( logger.warning(
...@@ -544,7 +555,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -544,7 +555,7 @@ async def init(runtime: DistributedRuntime, config: Config):
) )
await register_vllm_model( await register_vllm_model(
ModelInput.Tokens, model_input,
model_type, model_type,
generate_endpoint, generate_endpoint,
config, config,
...@@ -553,7 +564,9 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -553,7 +564,9 @@ async def init(runtime: DistributedRuntime, config: Config):
migration_limit=config.migration_limit, migration_limit=config.migration_limit,
) )
health_check_payload = VllmHealthCheckPayload(engine_client).to_dict() health_check_payload = VllmHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict()
try: try:
logger.debug("Starting serve_endpoint for decode worker") logger.debug("Starting serve_endpoint for decode worker")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment