Unverified Commit 4f99451b authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(frontend): Use vllm for pre and post processing (#5544)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 93a27308
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
import os import os
import pathlib import pathlib
import signal import signal
import sys
import uvloop import uvloop
...@@ -30,8 +31,6 @@ from dynamo.llm import ( ...@@ -30,8 +31,6 @@ from dynamo.llm import (
EngineType, EngineType,
EntrypointArgs, EntrypointArgs,
KvRouterConfig, KvRouterConfig,
ModelDeploymentCard,
PythonAsyncEngine,
RouterConfig, RouterConfig,
RouterMode, RouterMode,
make_engine, make_engine,
...@@ -48,19 +47,18 @@ configure_dynamo_logging() ...@@ -48,19 +47,18 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def _dummy_generator(request): def setup_engine_factory(
"""Minimal generator that yields nothing. Work in progress.""" runtime: DistributedRuntime,
return router_config: RouterConfig,
yield # Makes this an async generator flags: argparse.Namespace,
): # Returns EngineFactory:
async def engine_factory(mdc: ModelDeploymentCard) -> PythonAsyncEngine:
""" """
Called by Rust when a model is discovered. When using vllm pre and post processor, create the EngineFactory that
creates the engines that run requests.
""" """
loop = asyncio.get_running_loop() from .vllm_processor import EngineFactory
logger.info(f"Engine_factory called with MDC: {mdc.to_json_str()[:100]}...")
return PythonAsyncEngine(_dummy_generator, loop) return EngineFactory(runtime, router_config, flags)
def validate_model_name(value): def validate_model_name(value):
...@@ -87,10 +85,36 @@ def parse_args(): ...@@ -87,10 +85,36 @@ def parse_args():
Returns: Returns:
argparse.Namespace: Parsed command-line arguments. argparse.Namespace: Parsed command-line arguments.
""" """
parser = argparse.ArgumentParser(
description="Dynamo Frontend: HTTP+Pre-processor+Router", # We need to know before we parse the arguments
formatter_class=argparse.RawTextHelpFormatter, # To preserve multi-line help formatting full_args = " ".join(sys.argv)
is_vllm = (
"--chat-processor vllm" in full_args or "--chat-processor=vllm" in full_args
) )
if not is_vllm:
# Normal case, Dynamo processor
parser = argparse.ArgumentParser(
description="Dynamo Frontend: HTTP+Pre-processor+Router",
formatter_class=argparse.RawTextHelpFormatter, # To preserve multi-line help formatting
)
else:
# vllm processor
try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
try:
from vllm.utils.argparse_utils import FlexibleArgumentParser
except ModuleNotFoundError:
logger.exception(
"Flag '--chat-processor vllm' requires vllm be installed."
)
sys.exit(1)
parser = FlexibleArgumentParser(
description="Dynamo Frontend: HTTP+Pre-processor+Router",
)
parser.add_argument( parser.add_argument(
"--version", action="version", version=f"Dynamo Frontend {__version__}" "--version", action="version", version=f"Dynamo Frontend {__version__}"
) )
...@@ -307,11 +331,23 @@ def parse_args(): ...@@ -307,11 +331,23 @@ def parse_args():
help="Determines how events are published [nats|zmq]", help="Determines how events are published [nats|zmq]",
) )
parser.add_argument( parser.add_argument(
"--exp-python-factory", "--chat-processor",
action="store_true", dest="chat_processor",
default=False, type=str,
help="[EXPERIMENTAL] Enable Python-based engine factory. When set, engines will be created via a Python callback instead of the default Rust pipeline.", choices=["dynamo", "vllm"],
default="dynamo",
help="[EXPERIMENTAL] When set to 'vllm', use local vllm for the pre and post processor.",
) )
if is_vllm:
try:
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import FrontendArgs
parser = FrontendArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
except ModuleNotFoundError:
logger.exception("Flag '--chat-processor vllm' requires vllm be installed.")
sys.exit(1)
flags = parser.parse_args() flags = parser.parse_args()
...@@ -408,19 +444,19 @@ async def async_main(): ...@@ -408,19 +444,19 @@ async def async_main():
router_mode = RouterMode.RoundRobin router_mode = RouterMode.RoundRobin
kv_router_config = None kv_router_config = None
router_config = RouterConfig(
router_mode,
kv_router_config,
active_decode_blocks_threshold=flags.active_decode_blocks_threshold,
active_prefill_tokens_threshold=flags.active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac=flags.active_prefill_tokens_threshold_frac,
enforce_disagg=flags.enforce_disagg,
)
kwargs = { kwargs = {
"http_host": flags.http_host, "http_host": flags.http_host,
"http_port": flags.http_port, "http_port": flags.http_port,
"kv_cache_block_size": flags.kv_cache_block_size, "kv_cache_block_size": flags.kv_cache_block_size,
"router_config": RouterConfig( "router_config": router_config,
router_mode,
kv_router_config,
active_decode_blocks_threshold=flags.active_decode_blocks_threshold,
active_prefill_tokens_threshold=flags.active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac=flags.active_prefill_tokens_threshold_frac,
enforce_disagg=flags.enforce_disagg,
),
"migration_limit": flags.migration_limit,
} }
if flags.model_name: if flags.model_name:
...@@ -436,8 +472,11 @@ async def async_main(): ...@@ -436,8 +472,11 @@ async def async_main():
if flags.kserve_grpc_server and flags.grpc_metrics_port: if flags.kserve_grpc_server and flags.grpc_metrics_port:
kwargs["http_metrics_port"] = flags.grpc_metrics_port kwargs["http_metrics_port"] = flags.grpc_metrics_port
if flags.exp_python_factory: if flags.chat_processor == "vllm":
kwargs["engine_factory"] = engine_factory chat_engine_factory = setup_engine_factory(
runtime, router_config, flags
).chat_engine_factory
kwargs["chat_engine_factory"] = chat_engine_factory
e = EntrypointArgs(EngineType.Dynamic, **kwargs) e = EntrypointArgs(EngineType.Dynamic, **kwargs)
engine = await make_engine(runtime, e) engine = await make_engine(runtime, e)
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
@dataclass
class PreprocessResult:
request_for_sampling: ChatCompletionRequest
tool_parser: ToolParser | None
chat_template_kwargs: dict[str, Any]
engine_prompt: dict[str, Any]
prompt_token_ids: list[int]
def _materialize_assistant_tool_calls(
messages: Sequence[Any],
) -> list[dict[str, Any] | Any]:
# Mistral chat templating expects assistant tool_calls to be materialized
# as a concrete list of dict-like values. Our validated message models may
# still carry non-list sequence-like containers here, which can break or
# mis-render when tokenize=True is used in-template. This helper converts
# model objects to dicts and normalizes assistant.tool_calls to list when
# possible, while preserving original values if they are not iterable.
normalized: list[dict[str, Any] | Any] = []
for message in messages:
if hasattr(message, "model_dump"):
msg: dict[str, Any] | Any = message.model_dump(exclude_none=False)
else:
msg = message
if isinstance(msg, dict) and msg.get("role") == "assistant":
tool_calls = msg.get("tool_calls")
if tool_calls is not None and not isinstance(tool_calls, list):
try:
msg["tool_calls"] = list(tool_calls)
except TypeError:
# Keep original object if it is not iterable.
pass
normalized.append(msg)
return normalized
async def preprocess_chat_request(
request: dict[str, Any],
*,
tokenizer: TokenizerLike,
renderer,
tool_parser_class: type[ToolParser] | None,
) -> PreprocessResult:
request_for_sampling = ChatCompletionRequest.model_validate(request)
tool_parser: ToolParser | None = None
if tool_parser_class and request_for_sampling.tools:
if request_for_sampling.tool_choice != "none":
tool_parser = tool_parser_class(tokenizer)
request_for_sampling = tool_parser.adjust_request(request_for_sampling)
tool_dicts = (
[tool.model_dump() for tool in request_for_sampling.tools]
if request_for_sampling.tools
else None
)
chat_template_kwargs = dict(request_for_sampling.chat_template_kwargs or {})
chat_template_kwargs["reasoning_effort"] = request_for_sampling.reasoning_effort
# Mistral warns that tokenize=False is unsafe for chat templates.
is_mistral_tokenizer = (
tokenizer.__class__.__name__ == "MistralTokenizer"
or "tokenizers.mistral" in tokenizer.__class__.__module__
)
tokenize_in_template = is_mistral_tokenizer
messages_for_render = (
_materialize_assistant_tool_calls(request_for_sampling.messages)
if is_mistral_tokenizer
else request_for_sampling.messages
)
_, engine_prompt = await renderer.render_messages_async(
messages_for_render,
chat_template=request_for_sampling.chat_template,
chat_template_content_format="auto",
add_generation_prompt=request_for_sampling.add_generation_prompt,
continue_final_message=request_for_sampling.continue_final_message,
tools=tool_dicts,
documents=request_for_sampling.documents,
tokenize=tokenize_in_template,
**chat_template_kwargs,
)
if "prompt_token_ids" in engine_prompt:
tokens = list(engine_prompt["prompt_token_ids"])
else:
tokens = tokenizer.encode(
engine_prompt["prompt"],
add_special_tokens=request_for_sampling.add_special_tokens,
)
return PreprocessResult(
request_for_sampling=request_for_sampling,
tool_parser=tool_parser,
chat_template_kwargs=chat_template_kwargs,
engine_prompt=engine_prompt,
prompt_token_ids=tokens,
)
class StreamingPostProcessor:
def __init__(
self,
*,
tokenizer: TokenizerLike,
request_for_sampling: ChatCompletionRequest,
sampling_params: SamplingParams,
prompt_token_ids: Sequence[int],
tool_parser: ToolParser | None,
reasoning_parser_class: type[ReasoningParser] | None,
chat_template_kwargs: dict[str, Any],
) -> None:
self.tokenizer = tokenizer
self.request_for_sampling = request_for_sampling
self.sampling_params = sampling_params
self.tool_parser = tool_parser
self.reasoning_parser = (
reasoning_parser_class(
tokenizer,
chat_template_kwargs=chat_template_kwargs,
)
if reasoning_parser_class
else None
)
self._control_markers = tuple(
t for t in getattr(tokenizer, "all_special_tokens", ()) if t
)
self.previous_text = ""
self.previous_token_ids: list[int] = []
self.reasoning_is_done = False
self.in_progress_tool_calls: dict[int, DeltaToolCall] = {}
@staticmethod
def _merge_tool_call(
existing: DeltaToolCall | None, incoming: DeltaToolCall
) -> DeltaToolCall:
if existing is None:
if incoming.function and incoming.function.arguments is None:
incoming.function.arguments = ""
return incoming
if incoming.id and not existing.id:
existing.id = incoming.id
if incoming.type and not existing.type:
existing.type = incoming.type
if incoming.function:
if existing.function is None:
existing.function = incoming.function
if existing.function.arguments is None:
existing.function.arguments = ""
else:
if incoming.function.name and not existing.function.name:
existing.function.name = incoming.function.name
if incoming.function.arguments:
if existing.function.arguments is None:
existing.function.arguments = ""
existing.function.arguments += incoming.function.arguments
return existing
def _is_control_only_content(self, content: str | None) -> bool:
if not content:
return True
stripped = content
for marker in self._control_markers:
stripped = stripped.replace(marker, "")
return stripped.strip() == ""
def process_output(self, output: Any) -> dict[str, Any] | None:
delta_token_ids = list(output.token_ids or [])
# vLLM output_processor already applies stop-token/stop-string trimming
# to text. Re-detokenizing from token_ids can reintroduce stop markers.
delta_text = output.text or ""
current_text = self.previous_text + delta_text
current_token_ids = self.previous_token_ids + delta_token_ids
delta_message: DeltaMessage | None = DeltaMessage(content=delta_text)
if not self.reasoning_is_done and self.reasoning_parser:
delta_message = self.reasoning_parser.extract_reasoning_streaming(
self.previous_text,
current_text,
delta_text,
self.previous_token_ids,
current_token_ids,
delta_token_ids,
)
should_parse_tools = (
self.tool_parser is not None
and self.request_for_sampling.tool_choice != "none"
)
if should_parse_tools:
no_prev_reasoning = (
delta_message and delta_message.content and not delta_message.reasoning
)
if self.reasoning_is_done or no_prev_reasoning:
delta_message = self.tool_parser.extract_tool_calls_streaming(
previous_text=self.previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=self.previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=self.request_for_sampling,
)
if (
not self.reasoning_is_done
and self.reasoning_parser
and self.reasoning_parser.is_reasoning_end_streaming(
current_token_ids, delta_token_ids
)
):
self.reasoning_is_done = True
self.previous_text = ""
self.previous_token_ids = []
current_text = ""
current_token_ids = []
choice = None
if delta_message is None:
if self.in_progress_tool_calls:
choice = {
"index": output.index,
"delta": {
"role": "assistant",
"tool_calls": [
tool_call.model_dump(exclude_none=True)
for _, tool_call in sorted(
self.in_progress_tool_calls.items()
)
],
},
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
self.in_progress_tool_calls.clear()
elif output.finish_reason:
choice = {
"index": output.index,
"delta": {},
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
elif delta_message.tool_calls:
for tool_delta in delta_message.tool_calls:
existing = self.in_progress_tool_calls.get(tool_delta.index)
merged = self._merge_tool_call(existing, tool_delta)
self.in_progress_tool_calls[tool_delta.index] = merged
elif delta_message.content or delta_message.reasoning:
delta: dict[str, Any] = {"role": "assistant"}
content = delta_message.content
if self.in_progress_tool_calls and self._is_control_only_content(content):
content = None
if content:
delta["content"] = content
if delta_message.reasoning:
delta["reasoning_content"] = delta_message.reasoning
if self.in_progress_tool_calls:
delta["tool_calls"] = [
tool_call.model_dump(exclude_none=True)
for _, tool_call in sorted(self.in_progress_tool_calls.items())
]
self.in_progress_tool_calls.clear()
if len(delta) > 1:
choice = {
"index": output.index,
"delta": delta,
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
elif self.in_progress_tool_calls:
choice = {
"index": output.index,
"delta": {
"role": "assistant",
"tool_calls": [
tool_call.model_dump(exclude_none=True)
for _, tool_call in sorted(self.in_progress_tool_calls.items())
],
},
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
self.in_progress_tool_calls.clear()
elif output.finish_reason:
choice = {
"index": output.index,
"delta": {},
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
self.previous_text = current_text
self.previous_token_ids = current_token_ids
return choice
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Use vllm for input and output processing
#
import asyncio
import logging
import os
import time
import uuid
from argparse import Namespace
from collections.abc import AsyncGenerator
from typing import Any
from vllm.config import CacheConfig, LoadConfig, ModelConfig, VllmConfig
from vllm.inputs.data import TokensPrompt
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor, OutputProcessorOutput
from dynamo.llm import (
KvPushRouter,
ModelCardInstanceId,
ModelDeploymentCard,
PythonAsyncEngine,
RouterConfig,
RouterMode,
fetch_llm,
)
from dynamo.runtime import DistributedRuntime
from .prepost import StreamingPostProcessor, preprocess_chat_request
logger = logging.getLogger(__name__)
_MASK_64_BITS = (1 << 64) - 1
_FINISH_REASON_MAP: dict[str, FinishReason] = {
"eos": FinishReason.STOP,
"stop": FinishReason.STOP,
"length": FinishReason.LENGTH,
"error": FinishReason.ERROR,
"cancelled": FinishReason.ABORT,
"content_filter": FinishReason.STOP,
}
def random_uuid() -> str:
return f"{uuid.uuid4().int & _MASK_64_BITS:016x}" # 16 hex chars
def map_finish_reason(raw_reason: str | None) -> FinishReason | None:
if raw_reason is None:
return None
if raw_reason.startswith("error"):
return FinishReason.ERROR
if raw_reason.startswith("abort"):
return FinishReason.ABORT
if raw_reason.startswith("content_filter"):
logger.info("Router finish_reason indicates content filtering: %s", raw_reason)
raw_reason = "content_filter"
mapped = _FINISH_REASON_MAP.get(raw_reason)
if mapped is None:
logger.warning("Unknown finish_reason from router: %s", raw_reason)
return mapped
class VllmProcessor:
def __init__(
self,
tokenizer: TokenizerLike,
input_processor: InputProcessor,
router, # Client or KvPushRouter
output_processor: OutputProcessor,
tool_parser_class: type[ToolParser] | None,
reasoning_parser_class: type[ReasoningParser] | None,
):
self.tokenizer = tokenizer
self.input_processor = input_processor
self.router = router
self.is_kv_router = isinstance(router, KvPushRouter)
self.output_processor = output_processor
self.tool_parser_class = tool_parser_class
self.reasoning_parser_class = reasoning_parser_class
# Ideally we would map NVCreateChatCompletionRequest into Python so it can be type checked, but
# it has a lot of fields.
# request: dynamo.NVCreateChatCompletionRequest
async def generator(
self, request: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
"""
Run a single request through the engine. Does pre and post processing on this machine, delegates
model inference to a worker using the router.
"""
# ** VllmProcessor.generator called: {'messages': [{'role': 'user', 'content': 'What is the capital of Tuvalu?'}], 'model': '/home/grahamk/llms/Qwen3-0.6B', 'max_completion_tokens': 1000, 'stream': False}
pre = await preprocess_chat_request(
request,
tokenizer=self.tokenizer,
renderer=self.input_processor.renderer,
tool_parser_class=self.tool_parser_class,
)
request_for_sampling = pre.request_for_sampling
tool_parser = pre.tool_parser
chat_template_kwargs = pre.chat_template_kwargs
engine_prompt = pre.engine_prompt
tokens = pre.prompt_token_ids
if request_for_sampling.max_completion_tokens is not None:
max_tokens = request_for_sampling.max_completion_tokens
elif request_for_sampling.max_tokens is not None:
max_tokens = request_for_sampling.max_tokens
else:
# This should mean model max - prompt len.
max_tokens = None
sampling_params = SamplingParams(
output_kind=RequestOutputKind.DELTA,
max_tokens=max_tokens,
)
# generation_config.json
for k, v in self.input_processor.generation_config_fields.items():
if hasattr(sampling_params, k):
setattr(sampling_params, k, v)
# User request: copy fields supported by both request schema and
# SamplingParams, excluding fields handled separately below.
sampling_fields = (
set(getattr(SamplingParams, "__annotations__", ()))
& set(type(request_for_sampling).model_fields)
) - {"max_tokens", "logprobs", "output_kind"}
for k in sorted(sampling_fields):
v = getattr(request_for_sampling, k, None)
if v is not None:
setattr(sampling_params, k, v)
logprobs = request_for_sampling.logprobs
top_logprobs = request_for_sampling.top_logprobs
if logprobs is True:
sampling_params.logprobs = top_logprobs or 1
elif isinstance(logprobs, int) and not isinstance(logprobs, bool):
sampling_params.logprobs = logprobs
elif top_logprobs not in (None, 0):
sampling_params.logprobs = top_logprobs
if sampling_params.logprobs is not None and sampling_params.logprobs > 0:
logger.warning(
"Logprobs requested but not supported in distributed inference mode"
)
request_id = random_uuid()
# This calls update_from_generation_config and update_from_tokenizer on SamplingParams
prompt_inputs = TokensPrompt(prompt_token_ids=tokens)
if "multi_modal_data" in engine_prompt:
prompt_inputs["multi_modal_data"] = engine_prompt["multi_modal_data"]
if "multi_modal_uuids" in engine_prompt:
prompt_inputs["multi_modal_uuids"] = engine_prompt["multi_modal_uuids"]
if request_for_sampling.cache_salt is not None:
prompt_inputs["cache_salt"] = request_for_sampling.cache_salt
if request_for_sampling.mm_processor_kwargs is not None:
prompt_inputs[
"mm_processor_kwargs"
] = request_for_sampling.mm_processor_kwargs
vllm_preproc: EngineCoreRequest = self.input_processor.process_inputs(
request_id,
prompt_inputs,
sampling_params,
# arrival_time: float | None = None,
# lora_request: LoRARequest | None = None,
# tokenization_kwargs: dict[str, Any] | None = None,
# trace_headers: Mapping[str, str] | None = None,
# priority: int = 0,
# data_parallel_rank: int | None = None,
)
InputProcessor.assign_request_id(vllm_preproc)
# Processed: EngineCoreRequest(request_id='a2b76a85cd65e151', prompt_token_ids=[3838, 374, 279, 6722, 315, 28649, 25510, 30], mm_features=None, sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[151643], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=16, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args=None), pooling_params=None, eos_token_id=151645, arrival_time=1769036937.9417946, lora_request=None, cache_salt=None, data_parallel_rank=None, prompt_embeds=None, client_index=0, current_wave=0, priority=0, trace_headers=None)
# Convert to a Python object that has fields that match our PreprocessedRequest
sp = vllm_preproc.sampling_params
if sp.n != 1:
logger.error("Unsupported SamplingParams.n=%d, only n=1 is supported", sp.n)
yield {
"error": {
"message": (
f"Unsupported value: 'n={sp.n}'. "
"This endpoint currently supports only n=1."
),
"type": "invalid_request_error",
"param": "n",
"code": "unsupported_value",
}
}
return
prompt = None
self.output_processor.add_request(
vllm_preproc,
prompt,
# parent_req: ParentRequest | None = None,
# request_index: int = 0,
# queue: RequestOutputCollector | None = None,
)
dynamo_preproc = {
"model": request["model"],
"token_ids": tokens,
# protocols.common.StopConditions
"stop_conditions": {
"max_tokens": sp.max_tokens,
"stop": sp.stop,
"stop_token_ids": sp.stop_token_ids,
"min_tokens": sp.min_tokens,
"ignore_eos": sp.ignore_eos,
},
# protocols.common.SamplingOptions
# Is there a better way than typing it out like this?
"sampling_options": {
"n": sp.n,
"presence_penalty": sp.presence_penalty,
"frequency_penalty": sp.frequency_penalty,
"repetition_penalty": sp.repetition_penalty,
"temperature": sp.temperature,
"top_p": sp.top_p,
"top_k": sp.top_k,
"min_p": sp.min_p,
"seed": sp.seed,
},
# protocols.common.OutputOptions
"output_options": {
"logprobs": sp.logprobs,
"prompt_logprobs": sp.prompt_logprobs,
"skip_special_tokens": sp.skip_special_tokens,
},
"eos_token_ids": [vllm_preproc.eos_token_id]
if vllm_preproc.eos_token_id is not None
else [],
"annotations": [],
# "prompt_embeds": vllm_preproc.prompt_embeds,
}
post = StreamingPostProcessor(
tokenizer=self.tokenizer,
request_for_sampling=request_for_sampling,
sampling_params=sampling_params,
prompt_token_ids=tokens,
tool_parser=tool_parser,
reasoning_parser_class=self.reasoning_parser_class,
chat_template_kwargs=chat_template_kwargs,
)
# dynamo_response: Annotated
try:
# Dynamo Router. This goes to the backend, waits, gets the streaming response, returns it.
# Stream is AsyncResponseStream
if self.is_kv_router:
dynamo_stream = await self.router.generate(
token_ids=tokens,
model=dynamo_preproc["model"],
stop_conditions=dynamo_preproc["stop_conditions"],
sampling_options=dynamo_preproc["sampling_options"],
output_options=dynamo_preproc["output_options"],
)
else:
# Round robin or random, depending on cmd line flag
dynamo_stream = await self.router.generate(dynamo_preproc)
async for dynamo_response in dynamo_stream:
# dynamo_response looks like this for regular router:
# Stream got: Annotated(data={'token_ids': [7281]}, event=None, comment=[], id=None)
# For KV router is is only the inner map: {'token_ids': [7281]}
if self.is_kv_router:
engine_response = dynamo_response
else:
engine_response = dynamo_response.data()
# engine_response:
# Normal: {'token_ids': [151658]}
# Last: {'token_ids': [151645], 'finish_reason': 'stop', 'completion_usage': {'prompt_tokens': 190, 'completion_tokens': 168, 'total_tokens': 358, 'prompt_tokens_details': {'cached_tokens': 176}}}
if engine_response is None or "token_ids" not in engine_response:
logger.error("No outputs from engine for request %s", request_id)
yield {
"id": request_id,
"choices": [
{
"index": 0,
"delta": {},
"finish_reason": "error",
}
],
"created": int(time.time()),
"model": request["model"],
"object": "chat.completion.chunk",
}
break
raw_finish_reason = engine_response.get("finish_reason")
finish_reason = map_finish_reason(raw_finish_reason)
stop_reason = engine_response.get("stop_reason")
vllm_response = EngineCoreOutput(
request_id=vllm_preproc.request_id,
new_token_ids=engine_response["token_ids"],
finish_reason=finish_reason,
stop_reason=stop_reason,
# new_logprobs=new_logprobs,
# new_prompt_logprobs_tensors=prompt_logprobs_tensors,
# pooling_output=pooler_output,
# events=request.take_events(),
# kv_transfer_params=kv_transfer_params,
# trace_headers=request.trace_headers,
# num_cached_tokens=request.num_cached_tokens,
# num_nans_in_logits=request.num_nans_in_logits,
)
# Let vllm handle all post-processing
vllm_out: OutputProcessorOutput = self.output_processor.process_outputs(
[vllm_response]
)
if vllm_out.reqs_to_abort:
# Router has no abort API; we cannot propagate aborts.
pass
# vllm
# RequestOutput: OutputProcessorOutput(request_outputs=[RequestOutput(request_id=9dbe240d8de78db3, prompt='What is the capital of Tuvalu?', prompt_token_ids=[3838, 374, 279, 6722, 315, 28649, 25510, 30], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' The', token_ids=[576], cumulative_logprob=None, logprobs=None, finish_reason=None, stop_reason=None)], finished=False, metrics=RequestStateStats(num_generation_tokens=0, arrival_time=1769118902.2172132, queued_ts=0.0, scheduled_ts=0.0, first_token_ts=0.0, last_token_ts=0.0, first_token_latency=0.0, is_corrupted=False), lora_request=None, num_cached_tokens=0, multi_modal_placeholders={})], reqs_to_abort=[])
# Vec<ChatChoiceStream>
choices = []
if not vllm_out.request_outputs:
continue
for output in vllm_out.request_outputs[0].outputs:
choice = post.process_output(output)
if choice:
choices.append(choice)
if choices:
# dynamo_out: NvCreateChatCompletionStreamResponse
dynamo_out = {
"id": request_id,
"choices": choices,
"created": int(time.time()),
"model": request["model"],
"object": "chat.completion.chunk",
}
if usage := engine_response.get("completion_usage"):
# The engine only includes this on the last response
dynamo_out["usage"] = usage
# Rust handles HTTP / Server Sent Events back to user
yield dynamo_out
finally:
if vllm_preproc.request_id in self.output_processor.request_states:
self.output_processor.abort_requests(
[vllm_preproc.request_id], internal=True
)
class EngineFactory:
def __init__(
self,
runtime: DistributedRuntime,
router_config: RouterConfig,
flags: Namespace,
):
self.runtime = runtime
self.router_config = router_config
self.flags = flags
async def chat_engine_factory(
self,
instance_id: ModelCardInstanceId,
mdc: ModelDeploymentCard,
) -> PythonAsyncEngine:
"""
Called by Rust when a model is discovered.
"""
model_type = mdc.model_type()
if not model_type.supports_chat():
raise RuntimeError(
f"model type {model_type} is not supported by this factory"
)
loop = asyncio.get_running_loop()
source_path = mdc.source_path()
if not os.path.exists(source_path):
await fetch_llm(source_path, ignore_weights=True)
tokenizer_mode = getattr(self.flags, "tokenizer_mode", None) or "auto"
config_format = getattr(self.flags, "config_format", None) or "auto"
load_format = getattr(self.flags, "load_format", None) or "dummy"
model_config = ModelConfig(
model=source_path,
tokenizer_mode=tokenizer_mode,
config_format=config_format,
)
vllm_config = VllmConfig(
model_config=model_config,
load_config=LoadConfig(load_format=load_format),
cache_config=CacheConfig(),
# scheduler_config=SchedulerConfig(),
)
input_processor = InputProcessor(vllm_config)
tokenizer = input_processor.get_tokenizer()
output_processor = OutputProcessor(
tokenizer,
log_stats=False,
stream_interval=1,
)
tool_parser_name = self.flags.tool_call_parser or mdc.runtime_config().get(
"tool_call_parser"
)
if tool_parser_name:
tool_parser_class = ToolParserManager.get_tool_parser(tool_parser_name)
else:
tool_parser_class = None
reasoning_parser_name = self.flags.reasoning_parser or mdc.runtime_config().get(
"reasoning_parser"
)
if reasoning_parser_name:
reasoning_parser_class = ReasoningParserManager.get_reasoning_parser(
reasoning_parser_name
)
else:
reasoning_parser_class = None
(namespace_name, component_name, endpoint_name) = instance_id.triple()
generate_endpoint = (
self.runtime.namespace(namespace_name)
.component(component_name)
.endpoint(endpoint_name)
)
if self.router_config.router_mode == RouterMode.KV:
router = KvPushRouter(
endpoint=generate_endpoint,
block_size=self.flags.kv_cache_block_size or 16,
kv_router_config=self.router_config.kv_router_config,
)
else:
router = await generate_endpoint.client2(self.router_config.router_mode)
gen = VllmProcessor(
tokenizer,
input_processor,
router,
output_processor,
tool_parser_class,
reasoning_parser_class,
)
return PythonAsyncEngine(gen.generator, loop)
...@@ -148,7 +148,7 @@ async fn engine_for( ...@@ -148,7 +148,7 @@ async fn engine_for(
// Auto-discover backends // Auto-discover backends
Ok(EngineConfig::Dynamic { Ok(EngineConfig::Dynamic {
model: Box::new(local_model), model: Box::new(local_model),
engine_factory: None, chat_engine_factory: None,
}) })
} }
Output::Echo => Ok(EngineConfig::InProcessText { Output::Echo => Ok(EngineConfig::InProcessText {
......
...@@ -149,6 +149,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -149,6 +149,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Namespace>()?; m.add_class::<Namespace>()?;
m.add_class::<Component>()?; m.add_class::<Component>()?;
m.add_class::<Endpoint>()?; m.add_class::<Endpoint>()?;
m.add_class::<ModelCardInstanceId>()?;
m.add_class::<Client>()?; m.add_class::<Client>()?;
m.add_class::<AsyncResponseStream>()?; m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::entrypoint::EntrypointArgs>()?; m.add_class::<llm::entrypoint::EntrypointArgs>()?;
...@@ -481,6 +482,12 @@ struct Endpoint { ...@@ -481,6 +482,12 @@ struct Endpoint {
event_loop: PyObject, event_loop: PyObject,
} }
#[pyclass]
#[derive(Clone)]
struct ModelCardInstanceId {
inner: rs::discovery::ModelCardInstanceId,
}
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
struct Client { struct Client {
...@@ -521,6 +528,10 @@ impl ModelType { ...@@ -521,6 +528,10 @@ impl ModelType {
inner: llm_rs::model_type::ModelType::Images, inner: llm_rs::model_type::ModelType::Images,
}; };
fn supports_chat(&self) -> bool {
self.inner.supports_chat()
}
fn __or__(&self, other: &Self) -> Self { fn __or__(&self, other: &Self) -> Self {
ModelType { ModelType {
inner: self.inner | other.inner, inner: self.inner | other.inner,
...@@ -825,13 +836,17 @@ impl Endpoint { ...@@ -825,13 +836,17 @@ impl Endpoint {
} }
fn client<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> { fn client<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
self.client2(py, RouterMode::RoundRobin)
}
fn client2<'p>(&self, py: Python<'p>, router_mode: RouterMode) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone(); let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let client = inner.client().await.map_err(to_pyerr)?; let client = inner.client().await.map_err(to_pyerr)?;
let push_router = rs::pipeline::PushRouter::< let push_router = rs::pipeline::PushRouter::<
serde_json::Value, serde_json::Value,
RsAnnotated<serde_json::Value>, RsAnnotated<serde_json::Value>,
>::from_client(client, Default::default()) >::from_client(client, router_mode.into())
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(Client { Ok(Client {
...@@ -892,6 +907,19 @@ impl Namespace { ...@@ -892,6 +907,19 @@ impl Namespace {
} }
} }
#[pymethods]
impl ModelCardInstanceId {
// (namespace, component, endpoint)
// TODO: Can these be borrowed as &str?
fn triple(&self) -> (String, String, String) {
(
self.inner.namespace.clone(),
self.inner.component.clone(),
self.inner.endpoint.clone(),
)
}
}
#[pymethods] #[pymethods]
impl Client { impl Client {
/// Get list of current instances. /// Get list of current instances.
......
...@@ -11,8 +11,8 @@ use pyo3::{exceptions::PyException, prelude::*}; ...@@ -11,8 +11,8 @@ use pyo3::{exceptions::PyException, prelude::*};
use pyo3_async_runtimes::TaskLocals; use pyo3_async_runtimes::TaskLocals;
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig; use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
use dynamo_llm::entrypoint::ChatEngineFactoryCallback;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig; use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
use dynamo_llm::entrypoint::EngineFactoryCallback;
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig; use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::entrypoint::input::Input; use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig; use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
...@@ -21,6 +21,7 @@ use dynamo_llm::local_model::{LocalModel, LocalModelBuilder}; ...@@ -21,6 +21,7 @@ use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_llm::mocker::protocols::MockEngineArgs; use dynamo_llm::mocker::protocols::MockEngineArgs;
use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard; use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine; use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
use super::model_card::ModelDeploymentCard; use super::model_card::ModelDeploymentCard;
...@@ -91,8 +92,12 @@ impl KvRouterConfig { ...@@ -91,8 +92,12 @@ impl KvRouterConfig {
#[pyclass] #[pyclass]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct RouterConfig { pub struct RouterConfig {
router_mode: RouterMode, #[pyo3(get, set)]
kv_router_config: KvRouterConfig, pub router_mode: RouterMode,
#[pyo3(get, set)]
pub kv_router_config: KvRouterConfig,
/// Threshold for active decode blocks utilization (0.0-1.0) /// Threshold for active decode blocks utilization (0.0-1.0)
active_decode_blocks_threshold: Option<f64>, active_decode_blocks_threshold: Option<f64>,
/// Threshold for active prefill tokens utilization (literal token count) /// Threshold for active prefill tokens utilization (literal token count)
...@@ -175,14 +180,14 @@ pub(crate) struct EntrypointArgs { ...@@ -175,14 +180,14 @@ pub(crate) struct EntrypointArgs {
namespace: Option<String>, namespace: Option<String>,
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
engine_factory: Option<PyEngineFactory>, chat_engine_factory: Option<PyEngineFactory>,
} }
#[pymethods] #[pymethods]
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, is_prefill=false, migration_limit=0, engine_factory=None))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
pub fn new( pub fn new(
py: Python<'_>, py: Python<'_>,
engine_type: EngineType, engine_type: EngineType,
...@@ -202,7 +207,7 @@ impl EntrypointArgs { ...@@ -202,7 +207,7 @@ impl EntrypointArgs {
namespace: Option<String>, namespace: Option<String>,
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
engine_factory: Option<PyObject>, chat_engine_factory: Option<PyObject>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from); let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
if (tls_cert_path.is_some() && tls_key_path.is_none()) if (tls_cert_path.is_some() && tls_key_path.is_none())
...@@ -213,12 +218,12 @@ impl EntrypointArgs { ...@@ -213,12 +218,12 @@ impl EntrypointArgs {
)); ));
} }
// Capture TaskLocals at registration time for the engine factory callback // Capture TaskLocals at registration time for the chat engine factory callback
let engine_factory = engine_factory let chat_engine_factory = chat_engine_factory
.map(|callback| { .map(|callback| {
let locals = pyo3_async_runtimes::tokio::get_current_locals(py).map_err(|e| { let locals = pyo3_async_runtimes::tokio::get_current_locals(py).map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!( pyo3::exceptions::PyRuntimeError::new_err(format!(
"Failed to get TaskLocals for engine_factory: {}", "Failed to get TaskLocals for chat_engine_factory: {}",
e e
)) ))
})?; })?;
...@@ -247,7 +252,7 @@ impl EntrypointArgs { ...@@ -247,7 +252,7 @@ impl EntrypointArgs {
namespace, namespace,
is_prefill, is_prefill,
migration_limit, migration_limit,
engine_factory, chat_engine_factory,
}) })
} }
} }
...@@ -310,13 +315,15 @@ pub fn make_engine<'p>( ...@@ -310,13 +315,15 @@ pub fn make_engine<'p>(
}) })
} }
/// Convert a PyEngineFactory to a Rust EngineFactoryCallback /// Convert a PyEngineFactory to a Rust ChatEngineFactoryCallback
fn py_engine_factory_to_callback(factory: PyEngineFactory) -> EngineFactoryCallback { fn py_engine_factory_to_callback(factory: PyEngineFactory) -> ChatEngineFactoryCallback {
let callback = factory.callback; let callback = factory.callback;
let locals = factory.locals; let locals = factory.locals;
Arc::new( Arc::new(
move |card: RsModelDeploymentCard| -> Pin< move |instance_id: RsModelCardInstanceId,
card: RsModelDeploymentCard|
-> Pin<
Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>, Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
> { > {
let callback = callback.clone(); let callback = callback.clone();
...@@ -325,27 +332,29 @@ fn py_engine_factory_to_callback(factory: PyEngineFactory) -> EngineFactoryCallb ...@@ -325,27 +332,29 @@ fn py_engine_factory_to_callback(factory: PyEngineFactory) -> EngineFactoryCallb
Box::pin(async move { Box::pin(async move {
// Acquire GIL to call Python callback and convert coroutine to future // Acquire GIL to call Python callback and convert coroutine to future
let py_future = Python::with_gil(|py| { let py_future = Python::with_gil(|py| {
let py_instance_id =
Py::new(py, crate::ModelCardInstanceId { inner: instance_id }).map_err(
|e| anyhow::anyhow!("Failed to create Python ModelCardInstanceId: {e}"),
)?;
// Create Python ModelDeploymentCard wrapper // Create Python ModelDeploymentCard wrapper
let py_card = ModelDeploymentCard { inner: card }; let py_card = ModelDeploymentCard { inner: card };
let py_card_obj = Py::new(py, py_card) let py_card_obj = Py::new(py, py_card)
.map_err(|e| anyhow::anyhow!("Failed to create Python MDC: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to create Python MDC: {e}"))?;
// Call Python async function to get a coroutine // Call Python async function to get a coroutine
let coroutine = callback let coroutine = callback
.call1(py, (py_card_obj,)) .call1(py, (py_instance_id, py_card_obj))
.map_err(|e| anyhow::anyhow!("Failed to call engine_factory: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to call chat_engine_factory: {e}"))?;
// Use the TaskLocals captured at registration time // Use the TaskLocals captured at registration time
pyo3_async_runtimes::into_future_with_locals(&locals, coroutine.into_bound(py)) pyo3_async_runtimes::into_future_with_locals(&locals, coroutine.into_bound(py))
.map_err(|e| { .map_err(|e| anyhow::anyhow!("Failed to convert coroutine to future: {e}"))
anyhow::anyhow!("Failed to convert coroutine to future: {}", e)
})
})?; })?;
// Await the Python coroutine (GIL is released during await) // Await the Python coroutine (GIL is released during await)
let py_result = py_future let py_result = py_future
.await .await
.map_err(|e| anyhow::anyhow!("engine_factory callback failed: {}", e))?; .map_err(|e| anyhow::anyhow!("chat_engine_factory callback failed: {}", e))?;
// Extract PythonAsyncEngine from the Python result and wrap in Arc // Extract PythonAsyncEngine from the Python result and wrap in Arc
let engine: OpenAIChatCompletionsStreamingEngine = Python::with_gil(|py| { let engine: OpenAIChatCompletionsStreamingEngine = Python::with_gil(|py| {
...@@ -375,11 +384,11 @@ async fn select_engine( ...@@ -375,11 +384,11 @@ async fn select_engine(
} }
} }
EngineType::Dynamic => { EngineType::Dynamic => {
// Convert Python engine factory to Rust callback // Convert Python chat engine factory to Rust callback
let engine_factory = args.engine_factory.map(py_engine_factory_to_callback); let chat_engine_factory = args.chat_engine_factory.map(py_engine_factory_to_callback);
RsEngineConfig::Dynamic { RsEngineConfig::Dynamic {
model: Box::new(local_model), model: Box::new(local_model),
engine_factory, chat_engine_factory,
} }
} }
EngineType::Mocker => { EngineType::Mocker => {
......
...@@ -10,8 +10,6 @@ pub(crate) struct ModelDeploymentCard { ...@@ -10,8 +10,6 @@ pub(crate) struct ModelDeploymentCard {
pub(crate) inner: RsModelDeploymentCard, pub(crate) inner: RsModelDeploymentCard,
} }
impl ModelDeploymentCard {}
#[pymethods] #[pymethods]
impl ModelDeploymentCard { impl ModelDeploymentCard {
// Previously called "from_local_path" // Previously called "from_local_path"
...@@ -32,4 +30,23 @@ impl ModelDeploymentCard { ...@@ -32,4 +30,23 @@ impl ModelDeploymentCard {
let json = self.inner.to_json().map_err(to_pyerr)?; let json = self.inner.to_json().map_err(to_pyerr)?;
Ok(json) Ok(json)
} }
fn source_path(&self) -> &str {
self.inner.source_path()
}
fn name(&self) -> &str {
self.inner.name()
}
fn model_type(&self) -> ModelType {
ModelType {
inner: self.inner.model_type,
}
}
fn runtime_config(&self, py: Python<'_>) -> PyResult<PyObject> {
let rc = pythonize::pythonize(py, &self.inner.runtime_config).map_err(to_pyerr)?;
Ok(rc.unbind())
}
} }
...@@ -166,7 +166,14 @@ class Endpoint: ...@@ -166,7 +166,14 @@ class Endpoint:
async def client(self) -> Client: async def client(self) -> Client:
""" """
Create a `Client` capable of calling served instances of this endpoint Create a `Client` capable of calling served instances of this endpoint using round-robin routing.
"""
...
async def client2(self, router_mode: RouterMode) -> Client:
"""
Create a `Client` capable of calling served instances of this endpoint, using a specific
router mode (random, round-robin, kv).
""" """
... ...
...@@ -251,6 +258,18 @@ class Client: ...@@ -251,6 +258,18 @@ class Client:
... ...
class ModelCardInstanceId:
"""
Unique identifier for a worker instance: namespace, component, endpoint and instance_id.
The instance_id is not currently exposed in the Python bindings.
"""
def triple(self) -> Tuple[str, str, str]:
"""
Triple of namespace, component and endpoint this worker is serving.
"""
...
def compute_block_hash_for_seq_py( def compute_block_hash_for_seq_py(
tokens: List[int], tokens: List[int],
kv_block_size: int, kv_block_size: int,
...@@ -1504,7 +1523,7 @@ class EntrypointArgs: ...@@ -1504,7 +1523,7 @@ class EntrypointArgs:
namespace: Optional[str] = None, namespace: Optional[str] = None,
is_prefill: bool = False, is_prefill: bool = False,
migration_limit: int = 0, migration_limit: int = 0,
engine_factory: Optional[Callable] = None, chat_engine_factory: Optional[Callable] = None,
) -> None: ) -> None:
""" """
Create EntrypointArgs. Create EntrypointArgs.
...@@ -1527,7 +1546,7 @@ class EntrypointArgs: ...@@ -1527,7 +1546,7 @@ class EntrypointArgs:
namespace: Dynamo namespace for model discovery scoping namespace: Dynamo namespace for model discovery scoping
is_prefill: Whether this is a prefill worker is_prefill: Whether this is a prefill worker
migration_limit: Maximum number of request migrations (0=disabled) migration_limit: Maximum number of request migrations (0=disabled)
engine_factory: Optional Python engine factory callback chat_engine_factory: Optional Python chat completions engine factory callback
""" """
... ...
...@@ -1583,4 +1602,5 @@ __all__ = [ ...@@ -1583,4 +1602,5 @@ __all__ = [
"ModelDeploymentCard", "ModelDeploymentCard",
"PythonAsyncEngine", "PythonAsyncEngine",
"prometheus_names", "prometheus_names",
"ModelCardInstanceId",
] ]
...@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig ...@@ -18,6 +18,7 @@ from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import LoRADownloader as LoRADownloader from dynamo._core import LoRADownloader as LoRADownloader
from dynamo._core import MediaDecoder as MediaDecoder from dynamo._core import MediaDecoder as MediaDecoder
from dynamo._core import MediaFetcher as MediaFetcher from dynamo._core import MediaFetcher as MediaFetcher
from dynamo._core import ModelCardInstanceId as ModelCardInstanceId
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import ModelInput as ModelInput from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
......
...@@ -25,7 +25,7 @@ use dynamo_runtime::{ ...@@ -25,7 +25,7 @@ use dynamo_runtime::{
use crate::{ use crate::{
backend::Backend, backend::Backend,
discovery::WORKER_TYPE_DECODE, discovery::WORKER_TYPE_DECODE,
entrypoint::{self, EngineFactoryCallback, RouterConfig}, entrypoint::{self, ChatEngineFactoryCallback, RouterConfig},
http::service::metrics::Metrics, http::service::metrics::Metrics,
kv_router::PrefillRouter, kv_router::PrefillRouter,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
...@@ -61,7 +61,7 @@ pub struct ModelWatcher { ...@@ -61,7 +61,7 @@ pub struct ModelWatcher {
migration_limit: u32, migration_limit: u32,
notify_on_model: Notify, notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>, model_update_tx: Option<Sender<ModelUpdate>>,
engine_factory: Option<EngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
registering_models: DashSet<String>, registering_models: DashSet<String>,
} }
...@@ -81,7 +81,7 @@ impl ModelWatcher { ...@@ -81,7 +81,7 @@ impl ModelWatcher {
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
engine_factory: Option<EngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> ModelWatcher { ) -> ModelWatcher {
Self { Self {
...@@ -91,7 +91,7 @@ impl ModelWatcher { ...@@ -91,7 +91,7 @@ impl ModelWatcher {
migration_limit, migration_limit,
notify_on_model: Notify::new(), notify_on_model: Notify::new(),
model_update_tx: None, model_update_tx: None,
engine_factory, chat_engine_factory,
metrics, metrics,
registering_models: DashSet::new(), registering_models: DashSet::new(),
} }
...@@ -435,7 +435,15 @@ impl ModelWatcher { ...@@ -435,7 +435,15 @@ impl ModelWatcher {
// handle Chat or Completions requests, so handle whatever the model supports. // handle Chat or Completions requests, so handle whatever the model supports.
let endpoint = component.endpoint(&mcid.endpoint); let endpoint = component.endpoint(&mcid.endpoint);
let kv_chooser = if self.router_config.router_mode == RouterMode::KV { // Create the KV router whenever any local routed pipeline will be built.
// The chat factory builds its own router, but completions currently always
// uses the local routed pipeline and therefore still needs a chooser.
let needs_local_chat_pipeline =
card.model_type.supports_chat() && self.chat_engine_factory.is_none();
let needs_local_completions_pipeline = card.model_type.supports_completions();
let kv_chooser = if self.router_config.router_mode == RouterMode::KV
&& (needs_local_chat_pipeline || needs_local_completions_pipeline)
{
Some( Some(
self.manager self.manager
.kv_chooser_for( .kv_chooser_for(
...@@ -487,11 +495,17 @@ impl ModelWatcher { ...@@ -487,11 +495,17 @@ impl ModelWatcher {
// Add chat engine only if the model supports chat // Add chat engine only if the model supports chat
if card.model_type.supports_chat() { if card.model_type.supports_chat() {
// Work in progress. This will allow creating a chat_engine from Python. let factory_engine = if let Some(ref factory) = self.chat_engine_factory {
let chat_engine = if let Some(ref factory) = self.engine_factory { match factory(mcid.clone(), card.clone()).await {
factory(card.clone()) Ok(engine) => Some(engine),
.await Err(err) => return Err(err).context("python chat_engine_factory"),
.context("python engine_factory")? }
} else {
None
};
let chat_engine = if let Some(engine) = factory_engine {
engine
} else { } else {
entrypoint::build_routed_pipeline::< entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
...@@ -518,7 +532,7 @@ impl ModelWatcher { ...@@ -518,7 +532,7 @@ impl ModelWatcher {
tracing::info!("Chat completions is ready"); tracing::info!("Chat completions is ready");
} }
// Add completions engine only if the model supports completions // Add completions engine only if the model supports completions.
if card.model_type.supports_completions() { if card.model_type.supports_completions() {
let formatter = PromptFormatter::no_op(); let formatter = PromptFormatter::no_op();
let PromptFormatter::OAI(formatter) = formatter; let PromptFormatter::OAI(formatter) = formatter;
......
...@@ -12,7 +12,7 @@ use std::future::Future; ...@@ -12,7 +12,7 @@ use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::{discovery::ModelCardInstanceId, pipeline::RouterMode};
use crate::{ use crate::{
backend::ExecutionContext, discovery::LoadThresholdConfig, engines::StreamingEngine, backend::ExecutionContext, discovery::LoadThresholdConfig, engines::StreamingEngine,
...@@ -20,9 +20,10 @@ use crate::{ ...@@ -20,9 +20,10 @@ use crate::{
types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine, types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine,
}; };
/// Callback type for engine factory (async) /// Callback type for chat engine factory (async)
pub type EngineFactoryCallback = Arc< pub type ChatEngineFactoryCallback = Arc<
dyn Fn( dyn Fn(
ModelCardInstanceId,
ModelDeploymentCard, ModelDeploymentCard,
) -> Pin< ) -> Pin<
Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>, Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
...@@ -65,7 +66,7 @@ pub enum EngineConfig { ...@@ -65,7 +66,7 @@ pub enum EngineConfig {
/// Remote networked engines that we discover via etcd /// Remote networked engines that we discover via etcd
Dynamic { Dynamic {
model: Box<LocalModel>, model: Box<LocalModel>,
engine_factory: Option<EngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
}, },
/// A Text engine receives text, does it's own tokenization and prompt formatting. /// A Text engine receives text, does it's own tokenization and prompt formatting.
...@@ -92,9 +93,12 @@ impl EngineConfig { ...@@ -92,9 +93,12 @@ impl EngineConfig {
} }
} }
pub fn engine_factory(&self) -> Option<&EngineFactoryCallback> { pub fn chat_engine_factory(&self) -> Option<&ChatEngineFactoryCallback> {
match self { match self {
EngineConfig::Dynamic { engine_factory, .. } => engine_factory.as_ref(), EngineConfig::Dynamic {
chat_engine_factory,
..
} => chat_engine_factory.as_ref(),
_ => None, _ => None,
} }
} }
......
...@@ -7,7 +7,7 @@ use crate::{ ...@@ -7,7 +7,7 @@ use crate::{
discovery::{ModelManager, ModelUpdate, ModelWatcher}, discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType, endpoint_type::EndpointType,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{EngineConfig, EngineFactoryCallback, RouterConfig, input::common}, entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common},
http::service::service_v2::{self, HttpService}, http::service::service_v2::{self, HttpService},
namespace::is_global_namespace, namespace::is_global_namespace,
types::openai::{ types::openai::{
...@@ -54,7 +54,7 @@ pub async fn run( ...@@ -54,7 +54,7 @@ pub async fn run(
let http_service = match engine_config { let http_service = match engine_config {
EngineConfig::Dynamic { EngineConfig::Dynamic {
ref model, ref model,
ref engine_factory, ref chat_engine_factory,
} => { } => {
// This allows the /health endpoint to query store for active instances // This allows the /health endpoint to query store for active instances
http_service_builder = http_service_builder.store(distributed_runtime.store().clone()); http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
...@@ -79,7 +79,7 @@ pub async fn run( ...@@ -79,7 +79,7 @@ pub async fn run(
target_namespace, target_namespace,
Arc::new(http_service.clone()), Arc::new(http_service.clone()),
http_service.state().metrics_clone(), http_service.state().metrics_clone(),
engine_factory.clone(), chat_engine_factory.clone(),
) )
.await?; .await?;
http_service http_service
...@@ -157,14 +157,14 @@ async fn run_watcher( ...@@ -157,14 +157,14 @@ async fn run_watcher(
target_namespace: Option<String>, target_namespace: Option<String>,
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
engine_factory: Option<EngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new( let mut watch_obj = ModelWatcher::new(
runtime.clone(), runtime.clone(),
model_manager, model_manager,
router_config, router_config,
migration_limit, migration_limit,
engine_factory, chat_engine_factory,
metrics.clone(), metrics.clone(),
); );
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
......
...@@ -1051,6 +1051,7 @@ impl ...@@ -1051,6 +1051,7 @@ impl
let (mut common_request, annotations) = self let (mut common_request, annotations) = self
.preprocess_request(&request, tracker.as_deref()) .preprocess_request(&request, tracker.as_deref())
.await?; .await?;
tracing::trace!(request = ?common_request, "Pre-processed request");
// Attach the timing tracker to the request so downstream components can record metrics // Attach the timing tracker to the request so downstream components can record metrics
common_request.tracker = tracker; common_request.tracker = tracker;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment