Unverified Commit 6398d453 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

chore: add mypy typing to sglang (#6859)

parent 22f1ab15
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Dynamo SGLang wrapper configuration ArgGroup.""" """Dynamo SGLang wrapper configuration ArgGroup."""
import argparse
from typing import Optional from typing import Optional
from dynamo.common.configuration.arg_group import ArgGroup from dynamo.common.configuration.arg_group import ArgGroup
...@@ -17,7 +18,7 @@ class DynamoSGLangArgGroup(ArgGroup): ...@@ -17,7 +18,7 @@ class DynamoSGLangArgGroup(ArgGroup):
name = "dynamo-sglang" name = "dynamo-sglang"
def add_arguments(self, parser) -> None: def add_arguments(self, parser: argparse.ArgumentParser) -> None:
"""Add Dynamo SGLang arguments to parser.""" """Add Dynamo SGLang arguments to parser."""
parser.add_argument( parser.add_argument(
......
...@@ -36,7 +36,7 @@ async def init_llm_diffusion( ...@@ -36,7 +36,7 @@ async def init_llm_diffusion(
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ) -> None:
"""Initialize diffusion language model worker component""" """Initialize diffusion language model worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -120,7 +120,7 @@ async def init_image_diffusion( ...@@ -120,7 +120,7 @@ async def init_image_diffusion(
config: Config, config: Config,
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ) -> None:
"""Initialize image diffusion worker component""" """Initialize image diffusion worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -206,7 +206,7 @@ async def init_video_diffusion( ...@@ -206,7 +206,7 @@ async def init_video_diffusion(
config: Config, config: Config,
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ) -> None:
"""Initialize video generation worker component""" """Initialize video generation worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
......
...@@ -22,7 +22,7 @@ async def init_embedding( ...@@ -22,7 +22,7 @@ async def init_embedding(
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ) -> None:
"""Initialize embedding worker component""" """Initialize embedding worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
......
...@@ -62,7 +62,7 @@ async def init_decode( ...@@ -62,7 +62,7 @@ async def init_decode(
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
checkpoint_restore_engine: Optional[sgl.Engine] = None, checkpoint_restore_engine: Optional[sgl.Engine] = None,
): ) -> None:
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
if server_args.node_rank >= 1: if server_args.node_rank >= 1:
...@@ -152,7 +152,7 @@ async def init_prefill( ...@@ -152,7 +152,7 @@ async def init_prefill(
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
checkpoint_restore_engine: Optional[sgl.Engine] = None, checkpoint_restore_engine: Optional[sgl.Engine] = None,
): ) -> None:
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
if server_args.node_rank >= 1: if server_args.node_rank >= 1:
......
...@@ -31,7 +31,7 @@ async def init_multimodal_processor( ...@@ -31,7 +31,7 @@ async def init_multimodal_processor(
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ) -> None:
"""Initialize multimodal processor component""" """Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
...@@ -86,7 +86,7 @@ async def init_multimodal_encode_worker( ...@@ -86,7 +86,7 @@ async def init_multimodal_encode_worker(
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ) -> None:
"""Initialize multimodal encode worker component""" """Initialize multimodal encode worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -130,7 +130,7 @@ async def init_multimodal_worker( ...@@ -130,7 +130,7 @@ async def init_multimodal_worker(
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ) -> None:
"""Initialize multimodal worker component. """Initialize multimodal worker component.
This worker is always an internal component that should not register with This worker is always an internal component that should not register with
...@@ -185,7 +185,7 @@ async def init_multimodal_prefill_worker( ...@@ -185,7 +185,7 @@ async def init_multimodal_prefill_worker(
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: list, shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None, run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ) -> None:
"""Initialize multimodal prefill worker component""" """Initialize multimodal prefill worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
......
...@@ -2,13 +2,16 @@ ...@@ -2,13 +2,16 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
from typing import Any, Dict, Tuple
from sglang.srt.parser.conversation import chat_templates from sglang.srt.parser.conversation import chat_templates
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def multimodal_request_to_sglang(raw_request, tokenizer, chat_template): def multimodal_request_to_sglang(
raw_request: Any, tokenizer: Any, chat_template: str
) -> Dict[str, Any]:
conv = chat_templates[chat_template].copy() conv = chat_templates[chat_template].copy()
conv.messages = [] conv.messages = []
...@@ -48,7 +51,7 @@ def multimodal_request_to_sglang(raw_request, tokenizer, chat_template): ...@@ -48,7 +51,7 @@ def multimodal_request_to_sglang(raw_request, tokenizer, chat_template):
return sglang_request return sglang_request
def detokenize_sglang_response(response_data, tokenizer): def detokenize_sglang_response(response_data: Any, tokenizer: Any) -> str:
""" """
Detokenize SGLang response token IDs to text. Detokenize SGLang response token IDs to text.
...@@ -106,7 +109,9 @@ def detokenize_sglang_response(response_data, tokenizer): ...@@ -106,7 +109,9 @@ def detokenize_sglang_response(response_data, tokenizer):
return f"[Detokenization error: {e}]" return f"[Detokenization error: {e}]"
def process_sglang_stream_response(response_data, tokenizer, accumulated_text=""): def process_sglang_stream_response(
response_data: Any, tokenizer: Any, accumulated_text: str = ""
) -> Tuple[str, str, bool]:
""" """
Process a single SGLang streaming response with efficient detokenization. Process a single SGLang streaming response with efficient detokenization.
......
...@@ -94,13 +94,14 @@ class DynamoSglangPublisher: ...@@ -94,13 +94,14 @@ class DynamoSglangPublisher:
# Non-leader nodes don't receive scheduler metrics via this socket - they only # Non-leader nodes don't receive scheduler metrics via this socket - they only
# need KV event publishing which is set up separately in init_kv_event_publish() # need KV event publishing which is set up separately in init_kv_event_publish()
node_rank = getattr(self.server_args, "node_rank", 0) or 0 node_rank = getattr(self.server_args, "node_rank", 0) or 0
self._ctx: zmq.asyncio.Context | None = None
if node_rank == 0: if node_rank == 0:
self._ctx = zmq.asyncio.Context() # type: ignore self._ctx = zmq.asyncio.Context()
self._sock = get_zmq_socket( self._sock = get_zmq_socket(
self._ctx, self._ctx,
zmq.PULL, zmq.PULL,
self.engine.port_args.metrics_ipc_name, self.engine.port_args.metrics_ipc_name,
True, # type: ignore True,
) )
else: else:
self._ctx = None self._ctx = None
......
...@@ -21,8 +21,8 @@ async def _register_model_with_runtime_config( ...@@ -21,8 +21,8 @@ async def _register_model_with_runtime_config(
endpoint: Endpoint, endpoint: Endpoint,
server_args: ServerArgs, server_args: ServerArgs,
dynamo_args: DynamoConfig, dynamo_args: DynamoConfig,
input_type: Optional[ModelInput] = ModelInput.Tokens, input_type: ModelInput = ModelInput.Tokens,
output_type: Optional[ModelType] = ModelType.Chat | ModelType.Completions, output_type: ModelType = ModelType.Chat | ModelType.Completions,
) -> bool: ) -> bool:
"""Register LLM with the Dynamo runtime. """Register LLM with the Dynamo runtime.
...@@ -38,7 +38,6 @@ async def _register_model_with_runtime_config( ...@@ -38,7 +38,6 @@ async def _register_model_with_runtime_config(
True if registration succeeded, False otherwise. True if registration succeeded, False otherwise.
""" """
runtime_config = await _get_runtime_config(engine, server_args, dynamo_args) runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)
input_type = input_type
if not server_args.skip_tokenizer_init: if not server_args.skip_tokenizer_init:
logging.warning( logging.warning(
...@@ -134,6 +133,7 @@ def _get_bootstrap_info_for_config( ...@@ -134,6 +133,7 @@ def _get_bootstrap_info_for_config(
) )
# Wrap IPv6 literal with brackets so f"{host}:{port}" stays valid. # Wrap IPv6 literal with brackets so f"{host}:{port}" stays valid.
assert isinstance(bootstrap_host, str)
if ":" in bootstrap_host and not bootstrap_host.startswith("["): if ":" in bootstrap_host and not bootstrap_host.startswith("["):
bootstrap_host = f"[{bootstrap_host}]" bootstrap_host = f"[{bootstrap_host}]"
logging.info(f"Wrapped IPv6 address with brackets: {bootstrap_host}") logging.info(f"Wrapped IPv6 address with brackets: {bootstrap_host}")
...@@ -237,8 +237,8 @@ async def register_model_with_readiness_gate( ...@@ -237,8 +237,8 @@ async def register_model_with_readiness_gate(
generate_endpoint: Endpoint, generate_endpoint: Endpoint,
server_args: ServerArgs, server_args: ServerArgs,
dynamo_args: DynamoConfig, dynamo_args: DynamoConfig,
input_type: Optional[ModelInput] = ModelInput.Tokens, input_type: ModelInput = ModelInput.Tokens,
output_type: Optional[ModelType] = ModelType.Chat | ModelType.Completions, output_type: ModelType = ModelType.Chat | ModelType.Completions,
readiness_gate: Optional[asyncio.Event] = None, readiness_gate: Optional[asyncio.Event] = None,
) -> None: ) -> None:
"""Wrapper function to register LLM with the Dynamo runtime and use optional readiness gate to signal success. """Wrapper function to register LLM with the Dynamo runtime and use optional readiness gate to signal success.
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
import asyncio import asyncio
import logging import logging
from typing import Optional from collections.abc import AsyncGenerator
from typing import Any, Dict, Optional
import sglang as sgl import sglang as sgl
...@@ -25,12 +26,14 @@ class EmbeddingWorkerHandler(BaseWorkerHandler): ...@@ -25,12 +26,14 @@ class EmbeddingWorkerHandler(BaseWorkerHandler):
super().__init__(engine, config, publisher, None, shutdown_event) super().__init__(engine, config, publisher, None, shutdown_event)
logging.info("Embedding worker handler initialized") logging.info("Embedding worker handler initialized")
def cleanup(self): def cleanup(self) -> None:
super().cleanup() super().cleanup()
self.engine.shutdown() self.engine.shutdown()
logging.info("Engine shutdown") logging.info("Engine shutdown")
async def generate(self, request: dict, context: Context): async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
""" """
Generate embeddings for the given input. Generate embeddings for the given input.
...@@ -44,6 +47,7 @@ class EmbeddingWorkerHandler(BaseWorkerHandler): ...@@ -44,6 +47,7 @@ class EmbeddingWorkerHandler(BaseWorkerHandler):
embedding_request = EmbeddingRequest(**request) embedding_request = EmbeddingRequest(**request)
# Handle different input types # Handle different input types
prompt: str | list[Any]
if isinstance(embedding_request.input, str): if isinstance(embedding_request.input, str):
prompt = embedding_request.input prompt = embedding_request.input
elif isinstance(embedding_request.input, list): elif isinstance(embedding_request.input, list):
......
...@@ -15,6 +15,7 @@ from sglang.srt.utils import get_local_ip_auto ...@@ -15,6 +15,7 @@ from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -371,7 +372,7 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -371,7 +372,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
result = {"status": "error", "message": f"Unknown action: {action}"} result = {"status": "error", "message": f"Unknown action: {action}"}
yield result yield result
def register_engine_routes(self, runtime) -> None: def register_engine_routes(self, runtime: DistributedRuntime) -> None:
"""Register all engine routes for this handler. """Register all engine routes for this handler.
Args: Args:
...@@ -483,6 +484,7 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -483,6 +484,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
bootstrap_host = get_local_ip_auto() bootstrap_host = get_local_ip_auto()
# Wrap IPv6 literal with brackets so f"{host}:{port}" stays valid. # Wrap IPv6 literal with brackets so f"{host}:{port}" stays valid.
assert isinstance(bootstrap_host, str)
if ":" in bootstrap_host and not bootstrap_host.startswith("["): if ":" in bootstrap_host and not bootstrap_host.startswith("["):
bootstrap_host = f"[{bootstrap_host}]" bootstrap_host = f"[{bootstrap_host}]"
...@@ -515,7 +517,7 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -515,7 +517,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
cancellation_future = context.async_killed_or_stopped() cancellation_future = context.async_killed_or_stopped()
# Build list of futures/tasks to wait for # Build list of futures/tasks to wait for
wait_for = [cancellation_future] wait_for: list[asyncio.Future[Any]] = [cancellation_future]
shutdown_task = None shutdown_task = None
if self.shutdown_event: if self.shutdown_event:
......
...@@ -109,13 +109,14 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler): ...@@ -109,13 +109,14 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
seed=nvext.seed, seed=nvext.seed,
) )
user_id = req.user if req.user else context.id() context_id = context.id()
assert context_id is not None
user_id = req.user or context_id
image_data = [] image_data = []
for img in images: for img in images:
# uploading or encoding the image # uploading or encoding the image
if req.response_format == "url": if req.response_format == "url":
url = await self._upload_to_fs(img, user_id, context.id()) url = await self._upload_to_fs(img, user_id, context_id)
image_data.append(ImageData(url=url)) image_data.append(ImageData(url=url))
else: else:
b64 = self._encode_base64(img) b64 = self._encode_base64(img)
......
...@@ -160,7 +160,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -160,7 +160,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
else: else:
# Extract image URLs for multimodal requests. SGLang's mm_data_processor # Extract image URLs for multimodal requests. SGLang's mm_data_processor
# handles loading/preprocessing, and the scheduler does vision encoding. # handles loading/preprocessing, and the scheduler does vision encoding.
image_data = None image_data: list[str] | None = None
image_items = request.get("multi_modal_data", {}).get("image_url") image_items = request.get("multi_modal_data", {}).get("image_url")
if image_items: if image_items:
image_data = [] image_data = []
...@@ -215,7 +215,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -215,7 +215,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
Dict with token_ids and optional finish_reason. Dict with token_ids and optional finish_reason.
""" """
# Use Future pattern for request ID - will be set when first response arrives # Use Future pattern for request ID - will be set when first response arrives
request_id_future = asyncio.Future() request_id_future: asyncio.Future[str] = asyncio.Future()
async with self._cancellation_monitor(request_id_future, context): async with self._cancellation_monitor(request_id_future, context):
async for res in stream_source: async for res in stream_source:
# Extract SGLang request ID from the first response and set the future # Extract SGLang request ID from the first response and set the future
...@@ -289,7 +289,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -289,7 +289,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
count = 0 count = 0
# Use Future pattern for request ID - will be set when first response arrives # Use Future pattern for request ID - will be set when first response arrives
request_id_future = asyncio.Future() request_id_future: asyncio.Future[str] = asyncio.Future()
async with self._cancellation_monitor(request_id_future, context): async with self._cancellation_monitor(request_id_future, context):
async for res in stream_source: async for res in stream_source:
# Extract SGLang request ID from the first response and set the future # Extract SGLang request ID from the first response and set the future
......
...@@ -36,7 +36,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -36,7 +36,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self.engine = engine self.engine = engine
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine) self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine)
super().__init__(engine, config, publisher, generate_endpoint, shutdown_event) super().__init__(engine, config, publisher, generate_endpoint, shutdown_event)
self._consume_tasks = set() self._consume_tasks: set[asyncio.Task[Any]] = set()
logging.info( logging.info(
f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}" f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
) )
...@@ -146,7 +146,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -146,7 +146,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
context: Context object for cancellation handling. context: Context object for cancellation handling.
""" """
# Use Future pattern for request ID - will be set when first response arrives # Use Future pattern for request ID - will be set when first response arrives
request_id_future = asyncio.Future() request_id_future: asyncio.Future[str] = asyncio.Future()
async with self._cancellation_monitor(request_id_future, context): async with self._cancellation_monitor(request_id_future, context):
async for res in results: async for res in results:
# Extract SGLang request ID from the first response and set the future # Extract SGLang request ID from the first response and set the future
......
...@@ -93,7 +93,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -93,7 +93,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
self.min_workers = 1 self.min_workers = 1
def cleanup(self): def cleanup(self) -> None:
pass pass
async def generate( async def generate(
...@@ -262,7 +262,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -262,7 +262,7 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
logger.error(f"Error processing request: {e}") logger.error(f"Error processing request: {e}")
raise raise
async def async_init(self, runtime: DistributedRuntime): async def async_init(self, runtime: DistributedRuntime) -> None:
logger.info("Startup started.") logger.info("Startup started.")
# Create and initialize a dynamo connector for this worker. # Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently. # We'll needs this to move data between this worker and remote workers efficiently.
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import asyncio import asyncio
import json import json
import logging import logging
from typing import AsyncIterator, Optional from typing import Any, AsyncIterator, Optional
import sglang as sgl import sglang as sgl
import torch import torch
...@@ -127,7 +127,7 @@ class EmbeddingsProcessor: ...@@ -127,7 +127,7 @@ class EmbeddingsProcessor:
""" """
precomputed = embeddings.to(MultimodalConfig.EMBEDDINGS_DTYPE) precomputed = embeddings.to(MultimodalConfig.EMBEDDINGS_DTYPE)
mm_item = {"image_grid_thw": torch.tensor(image_grid_thw)} mm_item: dict[str, Any] = {"image_grid_thw": torch.tensor(image_grid_thw)}
mm_item.update( mm_item.update(
{ {
"format": "processor_output", "format": "processor_output",
...@@ -271,7 +271,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -271,7 +271,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
self, self,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
prefill_client: Client = None, prefill_client: Client | None = None,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
): ):
super().__init__(engine, config, None, None, shutdown_event) super().__init__(engine, config, None, None, shutdown_event)
...@@ -413,6 +413,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -413,6 +413,7 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
self, request: SglangMultimodalRequest, sampling_params: dict self, request: SglangMultimodalRequest, sampling_params: dict
) -> dict: ) -> dict:
"""Get bootstrap info from prefill worker""" """Get bootstrap info from prefill worker"""
assert self.prefill_client is not None
prefill_stream = await self.prefill_client.generate( prefill_stream = await self.prefill_client.generate(
DisaggSglangMultimodalRequest( DisaggSglangMultimodalRequest(
request=request, request=request,
...@@ -491,7 +492,7 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -491,7 +492,7 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
"bootstrap_room": bootstrap_room, "bootstrap_room": bootstrap_room,
} }
yield bootstrap_info yield json.dumps(bootstrap_info)
# Process prefill generation # Process prefill generation
await self._process_prefill_generation(disagg_request, bootstrap_room) await self._process_prefill_generation(disagg_request, bootstrap_room)
......
...@@ -111,6 +111,8 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler): ...@@ -111,6 +111,8 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
num_frames = nvext.fps * req.seconds num_frames = nvext.fps * req.seconds
# Generate video # Generate video
context_id = context.id()
assert context_id is not None
video_bytes = await self._generate_video( video_bytes = await self._generate_video(
prompt=req.prompt, prompt=req.prompt,
width=width, width=width,
...@@ -120,14 +122,14 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler): ...@@ -120,14 +122,14 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
num_inference_steps=nvext.num_inference_steps, num_inference_steps=nvext.num_inference_steps,
guidance_scale=nvext.guidance_scale, guidance_scale=nvext.guidance_scale,
seed=nvext.seed, seed=nvext.seed,
request_id=context.id(), request_id=context_id,
negative_prompt=nvext.negative_prompt, negative_prompt=nvext.negative_prompt,
input_reference=req.input_reference, input_reference=req.input_reference,
) )
video_data = [] video_data = []
if req.response_format == "url": if req.response_format == "url":
url = await self._upload_to_fs(video_bytes, context.id()) url = await self._upload_to_fs(video_bytes, context_id)
video_data.append(VideoData(url=url)) video_data.append(VideoData(url=url))
else: # b64_json else: # b64_json
b64 = self._encode_base64(video_bytes) b64 = self._encode_base64(video_bytes)
...@@ -269,13 +271,13 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler): ...@@ -269,13 +271,13 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
output_buffer = io.BytesIO() output_buffer = io.BytesIO()
with imageio.get_writer( with imageio.get_writer(
output_buffer, output_buffer,
format="mp4", format="mp4", # type: ignore
fps=fps, fps=fps,
codec=codec, codec=codec,
output_params=["-pix_fmt", "yuv420p"], output_params=["-pix_fmt", "yuv420p"],
) as writer: ) as writer:
for frame in np_frames: for frame in np_frames:
writer.append_data(frame) writer.append_data(frame) # type: ignore
output_buffer.seek(0) output_buffer.seek(0)
return output_buffer.read() return output_buffer.read()
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import ( from typing import (
Any, Any,
AsyncGenerator, AsyncGenerator,
...@@ -226,6 +227,17 @@ class Client: ...@@ -226,6 +227,17 @@ class Client:
""" """
... ...
async def generate(
self,
request: JsonLike,
annotated: bool | None = True,
context: Context | None = None,
) -> AsyncIterator[JsonLike]:
"""
Generate a response from the endpoint
"""
...
class ModelCardInstanceId: class ModelCardInstanceId:
""" """
...@@ -328,7 +340,7 @@ class Context: ...@@ -328,7 +340,7 @@ class Context:
""" """
... ...
async def async_killed_or_stopped(self) -> bool: def async_killed_or_stopped(self) -> asyncio.Future[bool]:
""" """
Asynchronously wait until the context is killed or stopped. Asynchronously wait until the context is killed or stopped.
...@@ -443,6 +455,8 @@ class ModelRuntimeConfig: ...@@ -443,6 +455,8 @@ class ModelRuntimeConfig:
enable_local_indexer: bool enable_local_indexer: bool
runtime_data: dict[str, Any] runtime_data: dict[str, Any]
tensor_model_config: Any | None tensor_model_config: Any | None
data_parallel_size: int
data_parallel_start_rank: int
def __init__(self) -> None: ... def __init__(self) -> None: ...
...@@ -454,6 +468,14 @@ class ModelRuntimeConfig: ...@@ -454,6 +468,14 @@ class ModelRuntimeConfig:
"""Get an engine-specific runtime configuration value""" """Get an engine-specific runtime configuration value"""
... ...
def set_disaggregated_endpoint(
self,
bootstrap_host: str | None = None,
bootstrap_port: int | None = None,
) -> None:
"""Set the disaggregated endpoint for the model"""
...
class OverlapScores: class OverlapScores:
""" """
A collection of prefix matching scores of workers for a given token ids. A collection of prefix matching scores of workers for a given token ids.
...@@ -931,7 +953,10 @@ class KserveGrpcService: ...@@ -931,7 +953,10 @@ class KserveGrpcService:
class ModelInput: class ModelInput:
"""What type of request this model needs: Text, Tokens or Tensor""" """What type of request this model needs: Text, Tokens or Tensor"""
... Text: ModelInput
Tokens: ModelInput
Tensor: ModelInput
class ModelType: class ModelType:
"""What type of request this model needs: Chat, Completions, Embedding, Tensor, Images, Videos or Prefill""" """What type of request this model needs: Chat, Completions, Embedding, Tensor, Images, Videos or Prefill"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment