Unverified commit 2128d48a, authored by jh-nv and committed by GitHub
Browse files

chore: add mypy typing to frontend (#6861)

parent 93529753
...@@ -22,7 +22,7 @@ import os ...@@ -22,7 +22,7 @@ import os
import signal import signal
import sys import sys
from argparse import Namespace from argparse import Namespace
from typing import Optional from typing import TYPE_CHECKING, Any, Optional
import uvloop import uvloop
...@@ -41,6 +41,9 @@ from dynamo.runtime.logging import configure_dynamo_logging ...@@ -41,6 +41,9 @@ from dynamo.runtime.logging import configure_dynamo_logging
from .frontend_args import FrontendArgGroup, FrontendConfig from .frontend_args import FrontendArgGroup, FrontendConfig
if TYPE_CHECKING:
from .vllm_processor import EngineFactory
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -50,7 +53,7 @@ def setup_engine_factory( ...@@ -50,7 +53,7 @@ def setup_engine_factory(
router_config: RouterConfig, router_config: RouterConfig,
config: FrontendConfig, config: FrontendConfig,
vllm_flags: Namespace, vllm_flags: Namespace,
): ) -> "EngineFactory":
""" """
When using vllm pre and post processor, create the EngineFactory that When using vllm pre and post processor, create the EngineFactory that
creates the engines that run requests. creates the engines that run requests.
...@@ -196,7 +199,7 @@ async def async_main(): ...@@ -196,7 +199,7 @@ async def async_main():
active_prefill_tokens_threshold_frac=config.active_prefill_tokens_threshold_frac, active_prefill_tokens_threshold_frac=config.active_prefill_tokens_threshold_frac,
decode_fallback=config.decode_fallback, decode_fallback=config.decode_fallback,
) )
kwargs = { kwargs: dict[str, Any] = {
"http_host": config.http_host, "http_host": config.http_host,
"http_port": config.http_port, "http_port": config.http_port,
"kv_cache_block_size": config.kv_cache_block_size, "kv_cache_block_size": config.kv_cache_block_size,
...@@ -245,7 +248,7 @@ async def async_main(): ...@@ -245,7 +248,7 @@ async def async_main():
pass pass
async def graceful_shutdown(runtime): async def graceful_shutdown(runtime: DistributedRuntime) -> None:
"""Handle graceful shutdown of the distributed runtime. """Handle graceful shutdown of the distributed runtime.
Args: Args:
...@@ -254,7 +257,7 @@ async def graceful_shutdown(runtime): ...@@ -254,7 +257,7 @@ async def graceful_shutdown(runtime):
runtime.shutdown() runtime.shutdown()
def main(): def main() -> None:
"""Entry point for the Dynamo frontend CLI.""" """Entry point for the Dynamo frontend CLI."""
uvloop.run(async_main()) uvloop.run(async_main())
......
...@@ -401,10 +401,10 @@ class StreamingPostProcessor: ...@@ -401,10 +401,10 @@ class StreamingPostProcessor:
# vLLM output_processor already applies stop-token/stop-string trimming # vLLM output_processor already applies stop-token/stop-string trimming
# to text. Re-detokenizing from token_ids can reintroduce stop markers. # to text. Re-detokenizing from token_ids can reintroduce stop markers.
delta_text = output.text or "" delta_text = output.text or ""
delta: dict[str, Any] = {}
if self._fast_plain_text: if self._fast_plain_text:
if delta_text: if delta_text:
delta: dict[str, Any] = { delta = {
"role": "assistant", "role": "assistant",
"content": delta_text, "content": delta_text,
} }
...@@ -542,7 +542,7 @@ class StreamingPostProcessor: ...@@ -542,7 +542,7 @@ class StreamingPostProcessor:
# to drain the buffer. # to drain the buffer.
choice = self._emit_tool_calls_choice(output) choice = self._emit_tool_calls_choice(output)
elif delta_message.content or delta_message.reasoning: elif delta_message.content or delta_message.reasoning:
delta: dict[str, Any] = {"role": "assistant"} delta = {"role": "assistant"}
content = delta_message.content content = delta_message.content
if self.in_progress_tool_calls and self._is_control_only_content(content): if self.in_progress_tool_calls and self._is_control_only_content(content):
content = None content = None
......
...@@ -37,7 +37,7 @@ from dynamo.llm import ( ...@@ -37,7 +37,7 @@ from dynamo.llm import (
RouterMode, RouterMode,
fetch_model, fetch_model,
) )
from dynamo.runtime import DistributedRuntime from dynamo.runtime import Client, DistributedRuntime
from .prepost import ( from .prepost import (
StreamingPostProcessor, StreamingPostProcessor,
...@@ -147,6 +147,7 @@ def _preprocess_worker( ...@@ -147,6 +147,7 @@ def _preprocess_worker(
model_name: str, model_name: str,
) -> PreprocessWorkerResult: ) -> PreprocessWorkerResult:
"""Preprocess a request in a worker process and return a picklable result.""" """Preprocess a request in a worker process and return a picklable result."""
assert _w_input_processor is not None
pre = preprocess_chat_request_sync( pre = preprocess_chat_request_sync(
request, request,
tokenizer=_w_tokenizer, tokenizer=_w_tokenizer,
...@@ -271,7 +272,7 @@ class VllmProcessor: ...@@ -271,7 +272,7 @@ class VllmProcessor:
self, self,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
input_processor: InputProcessor, input_processor: InputProcessor,
router, # Client or KvRouter router: Any, # Client or KvRouter
output_processor: OutputProcessor, output_processor: OutputProcessor,
tool_parser_class: type[ToolParser] | None, tool_parser_class: type[ToolParser] | None,
reasoning_parser_class: type[ReasoningParser] | None, reasoning_parser_class: type[ReasoningParser] | None,
...@@ -644,7 +645,9 @@ class VllmProcessor: ...@@ -644,7 +645,9 @@ class VllmProcessor:
# --- Phase 1: Preprocess (semaphore held) --- # --- Phase 1: Preprocess (semaphore held) ---
try: try:
assert self._worker_semaphore is not None
async with self._worker_semaphore: async with self._worker_semaphore:
assert self.preprocess_pool is not None
future = self.preprocess_pool.submit( future = self.preprocess_pool.submit(
_preprocess_worker, request, request_id, request["model"] _preprocess_worker, request, request_id, request["model"]
) )
...@@ -793,7 +796,7 @@ class EngineFactory: ...@@ -793,7 +796,7 @@ class EngineFactory:
generate_endpoint = self.runtime.endpoint( generate_endpoint = self.runtime.endpoint(
f"{namespace_name}.{component_name}.{endpoint_name}" f"{namespace_name}.{component_name}.{endpoint_name}"
) )
router: Client | KvRouter
if self.router_config.router_mode == RouterMode.KV: if self.router_config.router_mode == RouterMode.KV:
router = KvRouter( router = KvRouter(
endpoint=generate_endpoint, endpoint=generate_endpoint,
......
...@@ -412,7 +412,17 @@ class ModelDeploymentCard: ...@@ -412,7 +412,17 @@ class ModelDeploymentCard:
"""Deserialize a model deployment card from a JSON string.""" """Deserialize a model deployment card from a JSON string."""
... ...
... def model_type(self) -> ModelType:
"""Return the model type of this deployment card."""
...
def source_path(self) -> str:
"""Return the source path of this deployment card."""
...
def runtime_config(self) -> Any:
"""Return the runtime configuration as a dict."""
...
class ModelRuntimeConfig: class ModelRuntimeConfig:
""" """
...@@ -927,7 +937,10 @@ class ModelType: ...@@ -927,7 +937,10 @@ class ModelType:
Images: ModelType Images: ModelType
Audios: ModelType Audios: ModelType
Videos: ModelType Videos: ModelType
...
def supports_chat(self) -> bool:
"""Return True if this model type supports chat."""
...
class RouterMode: class RouterMode:
"""Router mode for load balancing requests across workers""" """Router mode for load balancing requests across workers"""
...@@ -939,6 +952,8 @@ class RouterMode: ...@@ -939,6 +952,8 @@ class RouterMode:
class RouterConfig: class RouterConfig:
"""How to route the request""" """How to route the request"""
router_mode: RouterMode
kv_router_config: KvRouterConfig
def __init__( def __init__(
self, self,
...@@ -982,6 +997,7 @@ class KvRouterConfig: ...@@ -982,6 +997,7 @@ class KvRouterConfig:
router_prune_target_ratio: float = 0.8, router_prune_target_ratio: float = 0.8,
router_queue_threshold: Optional[float] = None, router_queue_threshold: Optional[float] = None,
router_event_threads: int = 4, router_event_threads: int = 4,
router_enable_cache_control: bool = False,
) -> None: ) -> None:
""" """
Create a KV router configuration. Create a KV router configuration.
...@@ -1012,6 +1028,8 @@ class KvRouterConfig: ...@@ -1012,6 +1028,8 @@ class KvRouterConfig:
If None, queueing is disabled and all requests go directly to the scheduler. If None, queueing is disabled and all requests go directly to the scheduler.
router_event_threads: Number of event processing threads (default: 4). router_event_threads: Number of event processing threads (default: 4).
When > 1, uses a concurrent radix tree with a thread pool. When > 1, uses a concurrent radix tree with a thread pool.
router_enable_cache_control: Enable cache control (PIN with TTL) via the worker's
cache_control service mesh endpoint (default: False).
""" """
... ...
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment