Unverified Commit 6398d453 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

chore: add mypy typing to sglang (#6859)

parent 22f1ab15
......@@ -3,6 +3,7 @@
"""Dynamo SGLang wrapper configuration ArgGroup."""
import argparse
from typing import Optional
from dynamo.common.configuration.arg_group import ArgGroup
......@@ -17,7 +18,7 @@ class DynamoSGLangArgGroup(ArgGroup):
name = "dynamo-sglang"
def add_arguments(self, parser) -> None:
def add_arguments(self, parser: argparse.ArgumentParser) -> None:
"""Add Dynamo SGLang arguments to parser."""
parser.add_argument(
......
......@@ -36,7 +36,7 @@ async def init_llm_diffusion(
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
) -> None:
"""Initialize diffusion language model worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -120,7 +120,7 @@ async def init_image_diffusion(
config: Config,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
) -> None:
"""Initialize image diffusion worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -206,7 +206,7 @@ async def init_video_diffusion(
config: Config,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
) -> None:
"""Initialize video generation worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......
......@@ -22,7 +22,7 @@ async def init_embedding(
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
) -> None:
"""Initialize embedding worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......
......@@ -62,7 +62,7 @@ async def init_decode(
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
checkpoint_restore_engine: Optional[sgl.Engine] = None,
):
) -> None:
server_args, dynamo_args = config.server_args, config.dynamo_args
if server_args.node_rank >= 1:
......@@ -152,7 +152,7 @@ async def init_prefill(
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
checkpoint_restore_engine: Optional[sgl.Engine] = None,
):
) -> None:
server_args, dynamo_args = config.server_args, config.dynamo_args
if server_args.node_rank >= 1:
......
......@@ -31,7 +31,7 @@ async def init_multimodal_processor(
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
) -> None:
"""Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
generate_endpoint = runtime.endpoint(
......@@ -86,7 +86,7 @@ async def init_multimodal_encode_worker(
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
) -> None:
"""Initialize multimodal encode worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -130,7 +130,7 @@ async def init_multimodal_worker(
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
) -> None:
"""Initialize multimodal worker component.
This worker is always an internal component that should not register with
......@@ -185,7 +185,7 @@ async def init_multimodal_prefill_worker(
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
) -> None:
"""Initialize multimodal prefill worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......
......@@ -2,13 +2,16 @@
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Dict, Tuple
from sglang.srt.parser.conversation import chat_templates
logger = logging.getLogger(__name__)
def multimodal_request_to_sglang(raw_request, tokenizer, chat_template):
def multimodal_request_to_sglang(
raw_request: Any, tokenizer: Any, chat_template: str
) -> Dict[str, Any]:
conv = chat_templates[chat_template].copy()
conv.messages = []
......@@ -48,7 +51,7 @@ def multimodal_request_to_sglang(raw_request, tokenizer, chat_template):
return sglang_request
def detokenize_sglang_response(response_data, tokenizer):
def detokenize_sglang_response(response_data: Any, tokenizer: Any) -> str:
"""
Detokenize SGLang response token IDs to text.
......@@ -106,7 +109,9 @@ def detokenize_sglang_response(response_data, tokenizer):
return f"[Detokenization error: {e}]"
def process_sglang_stream_response(response_data, tokenizer, accumulated_text=""):
def process_sglang_stream_response(
response_data: Any, tokenizer: Any, accumulated_text: str = ""
) -> Tuple[str, str, bool]:
"""
Process a single SGLang streaming response with efficient detokenization.
......
......@@ -94,13 +94,14 @@ class DynamoSglangPublisher:
# Non-leader nodes don't receive scheduler metrics via this socket - they only
# need KV event publishing which is set up separately in init_kv_event_publish()
node_rank = getattr(self.server_args, "node_rank", 0) or 0
self._ctx: zmq.asyncio.Context | None = None
if node_rank == 0:
self._ctx = zmq.asyncio.Context() # type: ignore
self._ctx = zmq.asyncio.Context()
self._sock = get_zmq_socket(
self._ctx,
zmq.PULL,
self.engine.port_args.metrics_ipc_name,
True, # type: ignore
True,
)
else:
self._ctx = None
......
......@@ -21,8 +21,8 @@ async def _register_model_with_runtime_config(
endpoint: Endpoint,
server_args: ServerArgs,
dynamo_args: DynamoConfig,
input_type: Optional[ModelInput] = ModelInput.Tokens,
output_type: Optional[ModelType] = ModelType.Chat | ModelType.Completions,
input_type: ModelInput = ModelInput.Tokens,
output_type: ModelType = ModelType.Chat | ModelType.Completions,
) -> bool:
"""Register LLM with the Dynamo runtime.
......@@ -38,7 +38,6 @@ async def _register_model_with_runtime_config(
True if registration succeeded, False otherwise.
"""
runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)
input_type = input_type
if not server_args.skip_tokenizer_init:
logging.warning(
......@@ -134,6 +133,7 @@ def _get_bootstrap_info_for_config(
)
# Wrap IPv6 literal with brackets so f"{host}:{port}" stays valid.
assert isinstance(bootstrap_host, str)
if ":" in bootstrap_host and not bootstrap_host.startswith("["):
bootstrap_host = f"[{bootstrap_host}]"
logging.info(f"Wrapped IPv6 address with brackets: {bootstrap_host}")
......@@ -237,8 +237,8 @@ async def register_model_with_readiness_gate(
generate_endpoint: Endpoint,
server_args: ServerArgs,
dynamo_args: DynamoConfig,
input_type: Optional[ModelInput] = ModelInput.Tokens,
output_type: Optional[ModelType] = ModelType.Chat | ModelType.Completions,
input_type: ModelInput = ModelInput.Tokens,
output_type: ModelType = ModelType.Chat | ModelType.Completions,
readiness_gate: Optional[asyncio.Event] = None,
) -> None:
"""Wrapper function to register LLM with the Dynamo runtime and use optional readiness gate to signal success.
......
......@@ -3,7 +3,8 @@
import asyncio
import logging
from typing import Optional
from collections.abc import AsyncGenerator
from typing import Any, Dict, Optional
import sglang as sgl
......@@ -25,12 +26,14 @@ class EmbeddingWorkerHandler(BaseWorkerHandler):
super().__init__(engine, config, publisher, None, shutdown_event)
logging.info("Embedding worker handler initialized")
def cleanup(self):
def cleanup(self) -> None:
super().cleanup()
self.engine.shutdown()
logging.info("Engine shutdown")
async def generate(self, request: dict, context: Context):
async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Generate embeddings for the given input.
......@@ -44,6 +47,7 @@ class EmbeddingWorkerHandler(BaseWorkerHandler):
embedding_request = EmbeddingRequest(**request)
# Handle different input types
prompt: str | list[Any]
if isinstance(embedding_request.input, str):
prompt = embedding_request.input
elif isinstance(embedding_request.input, list):
......
......@@ -15,6 +15,7 @@ from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Context
from dynamo.common.utils.input_params import InputParamManager
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -371,7 +372,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
result = {"status": "error", "message": f"Unknown action: {action}"}
yield result
def register_engine_routes(self, runtime) -> None:
def register_engine_routes(self, runtime: DistributedRuntime) -> None:
"""Register all engine routes for this handler.
Args:
......@@ -483,6 +484,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
bootstrap_host = get_local_ip_auto()
# Wrap IPv6 literal with brackets so f"{host}:{port}" stays valid.
assert isinstance(bootstrap_host, str)
if ":" in bootstrap_host and not bootstrap_host.startswith("["):
bootstrap_host = f"[{bootstrap_host}]"
......@@ -515,7 +517,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
cancellation_future = context.async_killed_or_stopped()
# Build list of futures/tasks to wait for
wait_for = [cancellation_future]
wait_for: list[asyncio.Future[Any]] = [cancellation_future]
shutdown_task = None
if self.shutdown_event:
......
......@@ -109,13 +109,14 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
seed=nvext.seed,
)
user_id = req.user if req.user else context.id()
context_id = context.id()
assert context_id is not None
user_id = req.user or context_id
image_data = []
for img in images:
# uploading or encoding the image
if req.response_format == "url":
url = await self._upload_to_fs(img, user_id, context.id())
url = await self._upload_to_fs(img, user_id, context_id)
image_data.append(ImageData(url=url))
else:
b64 = self._encode_base64(img)
......
......@@ -160,7 +160,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
else:
# Extract image URLs for multimodal requests. SGLang's mm_data_processor
# handles loading/preprocessing, and the scheduler does vision encoding.
image_data = None
image_data: list[str] | None = None
image_items = request.get("multi_modal_data", {}).get("image_url")
if image_items:
image_data = []
......@@ -215,7 +215,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
Dict with token_ids and optional finish_reason.
"""
# Use Future pattern for request ID - will be set when first response arrives
request_id_future = asyncio.Future()
request_id_future: asyncio.Future[str] = asyncio.Future()
async with self._cancellation_monitor(request_id_future, context):
async for res in stream_source:
# Extract SGLang request ID from the first response and set the future
......@@ -289,7 +289,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
count = 0
# Use Future pattern for request ID - will be set when first response arrives
request_id_future = asyncio.Future()
request_id_future: asyncio.Future[str] = asyncio.Future()
async with self._cancellation_monitor(request_id_future, context):
async for res in stream_source:
# Extract SGLang request ID from the first response and set the future
......
......@@ -36,7 +36,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self.engine = engine
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine)
super().__init__(engine, config, publisher, generate_endpoint, shutdown_event)
self._consume_tasks = set()
self._consume_tasks: set[asyncio.Task[Any]] = set()
logging.info(
f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
)
......@@ -146,7 +146,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
context: Context object for cancellation handling.
"""
# Use Future pattern for request ID - will be set when first response arrives
request_id_future = asyncio.Future()
request_id_future: asyncio.Future[str] = asyncio.Future()
async with self._cancellation_monitor(request_id_future, context):
async for res in results:
# Extract SGLang request ID from the first response and set the future
......
......@@ -93,7 +93,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
self.min_workers = 1
def cleanup(self):
def cleanup(self) -> None:
pass
async def generate(
......@@ -262,7 +262,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
logger.error(f"Error processing request: {e}")
raise
async def async_init(self, runtime: DistributedRuntime):
async def async_init(self, runtime: DistributedRuntime) -> None:
logger.info("Startup started.")
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
......
......@@ -4,7 +4,7 @@
import asyncio
import json
import logging
from typing import AsyncIterator, Optional
from typing import Any, AsyncIterator, Optional
import sglang as sgl
import torch
......@@ -127,7 +127,7 @@ class EmbeddingsProcessor:
"""
precomputed = embeddings.to(MultimodalConfig.EMBEDDINGS_DTYPE)
mm_item = {"image_grid_thw": torch.tensor(image_grid_thw)}
mm_item: dict[str, Any] = {"image_grid_thw": torch.tensor(image_grid_thw)}
mm_item.update(
{
"format": "processor_output",
......@@ -271,7 +271,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
self,
engine: sgl.Engine,
config: Config,
prefill_client: Client = None,
prefill_client: Client | None = None,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(engine, config, None, None, shutdown_event)
......@@ -413,6 +413,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
self, request: SglangMultimodalRequest, sampling_params: dict
) -> dict:
"""Get bootstrap info from prefill worker"""
assert self.prefill_client is not None
prefill_stream = await self.prefill_client.generate(
DisaggSglangMultimodalRequest(
request=request,
......@@ -491,7 +492,7 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
"bootstrap_room": bootstrap_room,
}
yield bootstrap_info
yield json.dumps(bootstrap_info)
# Process prefill generation
await self._process_prefill_generation(disagg_request, bootstrap_room)
......
......@@ -111,6 +111,8 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
num_frames = nvext.fps * req.seconds
# Generate video
context_id = context.id()
assert context_id is not None
video_bytes = await self._generate_video(
prompt=req.prompt,
width=width,
......@@ -120,14 +122,14 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
num_inference_steps=nvext.num_inference_steps,
guidance_scale=nvext.guidance_scale,
seed=nvext.seed,
request_id=context.id(),
request_id=context_id,
negative_prompt=nvext.negative_prompt,
input_reference=req.input_reference,
)
video_data = []
if req.response_format == "url":
url = await self._upload_to_fs(video_bytes, context.id())
url = await self._upload_to_fs(video_bytes, context_id)
video_data.append(VideoData(url=url))
else: # b64_json
b64 = self._encode_base64(video_bytes)
......@@ -269,13 +271,13 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
output_buffer = io.BytesIO()
with imageio.get_writer(
output_buffer,
format="mp4",
format="mp4", # type: ignore
fps=fps,
codec=codec,
output_params=["-pix_fmt", "yuv420p"],
) as writer:
for frame in np_frames:
writer.append_data(frame)
writer.append_data(frame) # type: ignore
output_buffer.seek(0)
return output_buffer.read()
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import (
Any,
AsyncGenerator,
......@@ -226,6 +227,17 @@ class Client:
"""
...
async def generate(
self,
request: JsonLike,
annotated: bool | None = True,
context: Context | None = None,
) -> AsyncIterator[JsonLike]:
"""
Generate a response from the endpoint
"""
...
class ModelCardInstanceId:
"""
......@@ -328,7 +340,7 @@ class Context:
"""
...
async def async_killed_or_stopped(self) -> bool:
def async_killed_or_stopped(self) -> asyncio.Future[bool]:
"""
Asynchronously wait until the context is killed or stopped.
......@@ -443,6 +455,8 @@ class ModelRuntimeConfig:
enable_local_indexer: bool
runtime_data: dict[str, Any]
tensor_model_config: Any | None
data_parallel_size: int
data_parallel_start_rank: int
def __init__(self) -> None: ...
......@@ -454,6 +468,14 @@ class ModelRuntimeConfig:
"""Get an engine-specific runtime configuration value"""
...
def set_disaggregated_endpoint(
self,
bootstrap_host: str | None = None,
bootstrap_port: int | None = None,
) -> None:
"""Set the disaggregated endpoint for the model"""
...
class OverlapScores:
"""
A collection of prefix matching scores of workers for a given token ids.
......@@ -931,7 +953,10 @@ class KserveGrpcService:
class ModelInput:
"""What type of request this model needs: Text, Tokens or Tensor"""
...
Text: ModelInput
Tokens: ModelInput
Tensor: ModelInput
class ModelType:
"""What type of request this model needs: Chat, Completions, Embedding, Tensor, Images, Videos or Prefill"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment