Unverified Commit 2712426f authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: enable mypy in pre-merge (#6732)

parent e5e118a1
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import asyncio import asyncio
import logging import logging
import time import time
from typing import Optional
from dynamo.planner import SubComponentType, TargetReplica from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.utils.decode_planner import DecodePlanner from dynamo.planner.utils.decode_planner import DecodePlanner
...@@ -24,9 +23,7 @@ logger = logging.getLogger(__name__) ...@@ -24,9 +23,7 @@ logger = logging.getLogger(__name__)
class DisaggPlanner: class DisaggPlanner:
def __init__( def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
self, runtime: Optional[DistributedRuntime], config: PlannerConfig
) -> None:
self.config = config self.config = config
self.shared_state = PlannerSharedState() self.shared_state = PlannerSharedState()
prometheus_metrics = PlannerPrometheusMetrics() prometheus_metrics = PlannerPrometheusMetrics()
...@@ -89,13 +86,12 @@ class DisaggPlanner: ...@@ -89,13 +86,12 @@ class DisaggPlanner:
logger.info(f"Detected model name from deployment: {model_name}") logger.info(f"Detected model name from deployment: {model_name}")
model_name = model_name.lower() model_name = model_name.lower()
else: else:
model_name = getattr(self.config, "model_name", None) if not self.config.model_name:
if not model_name:
raise ValueError( raise ValueError(
"Model name is required in no-operation mode. " "Model name is required in no-operation mode. "
"Please set model_name in the config." "Please set model_name in the config."
) )
model_name = model_name.lower() model_name = self.config.model_name.lower()
self.prefill_planner.model_name = model_name self.prefill_planner.model_name = model_name
self.decode_planner.model_name = model_name self.decode_planner.model_name = model_name
......
...@@ -127,6 +127,13 @@ def run_sla_planner_dryrun( ...@@ -127,6 +127,13 @@ def run_sla_planner_dryrun(
time_series.append(time_series[-1] + interval) time_series.append(time_series[-1] + interval)
_est_rr, _est_isl, _est_osl = predictor_planner.predict_load() _est_rr, _est_isl, _est_osl = predictor_planner.predict_load()
# predict_load() returns Optional[float] values; in dryrun mode with
# pre-loaded data the predictors always return valid floats.
assert (
_est_rr is not None and _est_isl is not None and _est_osl is not None
), "predict_load() returned None in dryrun mode"
est_rr.append(_est_rr) est_rr.append(_est_rr)
est_isl.append(_est_isl) est_isl.append(_est_isl)
est_osl.append(_est_osl) est_osl.append(_est_osl)
...@@ -145,10 +152,12 @@ def run_sla_planner_dryrun( ...@@ -145,10 +152,12 @@ def run_sla_planner_dryrun(
if prefill_planner is not None and decode_planner is not None: if prefill_planner is not None and decode_planner is not None:
_num_p, _num_d = _apply_global_gpu_budget(_num_p, _num_d, config) _num_p, _num_d = _apply_global_gpu_budget(_num_p, _num_d, config)
elif prefill_planner is not None: elif prefill_planner is not None:
assert config.prefill_engine_num_gpu is not None
_num_p = _apply_component_gpu_budget( _num_p = _apply_component_gpu_budget(
_num_p, config.prefill_engine_num_gpu, config _num_p, config.prefill_engine_num_gpu, config
) )
elif decode_planner is not None: elif decode_planner is not None:
assert config.decode_engine_num_gpu is not None
_num_d = _apply_component_gpu_budget( _num_d = _apply_component_gpu_budget(
_num_d, config.decode_engine_num_gpu, config _num_d, config.decode_engine_num_gpu, config
) )
......
...@@ -19,7 +19,7 @@ import warnings ...@@ -19,7 +19,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import Any from typing import Any, Callable
import numpy as np import numpy as np
import pandas as pd import pandas as pd
...@@ -389,7 +389,7 @@ class KalmanPredictor(BasePredictor): ...@@ -389,7 +389,7 @@ class KalmanPredictor(BasePredictor):
) )
LOAD_PREDICTORS = { LOAD_PREDICTORS: dict[str, Callable[[PlannerConfig], BasePredictor]] = {
"constant": ConstantPredictor, "constant": ConstantPredictor,
"arima": ARIMAPredictor, "arima": ARIMAPredictor,
"kalman": KalmanPredictor, "kalman": KalmanPredictor,
......
...@@ -151,6 +151,8 @@ class DecodeInterpolator: ...@@ -151,6 +151,8 @@ class DecodeInterpolator:
self.resolution = resolution self.resolution = resolution
self.xi = np.linspace(0, 1, resolution) self.xi = np.linspace(0, 1, resolution)
self.yi = np.linspace(0, max(self.y_context_length), resolution) self.yi = np.linspace(0, max(self.y_context_length), resolution)
self.X: np.ndarray
self.Y: np.ndarray
self.X, self.Y = np.meshgrid(self.xi, self.yi) self.X, self.Y = np.meshgrid(self.xi, self.yi)
# Lazy import scipy only when interpolation is actually needed # Lazy import scipy only when interpolation is actually needed
......
...@@ -6,7 +6,7 @@ import logging ...@@ -6,7 +6,7 @@ import logging
import math import math
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional, Union
from prometheus_client import Gauge, start_http_server from prometheus_client import Gauge, start_http_server
...@@ -36,6 +36,9 @@ from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_moonc ...@@ -36,6 +36,9 @@ from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_moonc
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
# Union of all connector types used by the planner
ConnectorType = Union[GlobalPlannerConnector, KubernetesConnector, VirtualConnector]
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -248,14 +251,14 @@ class BasePlanner: ...@@ -248,14 +251,14 @@ class BasePlanner:
def __init__( def __init__(
self, self,
runtime: DistributedRuntime, runtime: Optional[DistributedRuntime],
config: PlannerConfig, config: PlannerConfig,
dryrun: bool = False, dryrun: bool = False,
shared_state: Optional[PlannerSharedState] = None, shared_state: Optional[PlannerSharedState] = None,
prometheus_metrics: Optional[PlannerPrometheusMetrics] = None, prometheus_metrics: Optional[PlannerPrometheusMetrics] = None,
prometheus_traffic_client: Optional[PrometheusAPIClient] = None, prometheus_traffic_client: Optional[PrometheusAPIClient] = None,
prometheus_engine_client: Optional[DirectRouterMetricsClient] = None, prometheus_engine_client: Optional[DirectRouterMetricsClient] = None,
connector=None, connector: Optional[ConnectorType] = None,
start_prometheus_server: bool = True, start_prometheus_server: bool = True,
component_type: Optional[SubComponentType] = None, component_type: Optional[SubComponentType] = None,
): ):
...@@ -272,11 +275,13 @@ class BasePlanner: ...@@ -272,11 +275,13 @@ class BasePlanner:
if not self.dryrun: if not self.dryrun:
self.runtime = runtime self.runtime = runtime
self.namespace = config.namespace self.namespace = config.namespace
self.connector: ConnectorType
if not config.no_operation: if not config.no_operation:
# Initialize connector based on environment # Initialize connector based on environment
if config.environment == "global-planner": if config.environment == "global-planner":
assert config.global_planner_namespace is not None assert config.global_planner_namespace is not None
assert runtime is not None
self.connector = GlobalPlannerConnector( self.connector = GlobalPlannerConnector(
runtime, runtime,
self.namespace, self.namespace,
...@@ -289,6 +294,7 @@ class BasePlanner: ...@@ -289,6 +294,7 @@ class BasePlanner:
self.namespace, self.model_name self.namespace, self.model_name
) )
elif config.environment == "virtual": elif config.environment == "virtual":
assert runtime is not None
self.connector = VirtualConnector( self.connector = VirtualConnector(
runtime, runtime,
self.namespace, self.namespace,
...@@ -430,11 +436,12 @@ class BasePlanner: ...@@ -430,11 +436,12 @@ class BasePlanner:
self.prometheus_engine_client = prometheus_engine_client self.prometheus_engine_client = prometheus_engine_client
else: else:
# Auto-discover frontend metrics URL in Kubernetes mode # Auto-discover frontend metrics URL in Kubernetes mode
connector = getattr(self, "connector", None)
if not config.load_router_metrics_url and isinstance( if not config.load_router_metrics_url and isinstance(
getattr(self, "connector", None), KubernetesConnector connector, KubernetesConnector
): ):
config.load_router_metrics_url = ( config.load_router_metrics_url = (
self.connector.get_frontend_metrics_url() connector.get_frontend_metrics_url()
) )
if not config.load_router_metrics_url: if not config.load_router_metrics_url:
raise ValueError( raise ValueError(
...@@ -447,6 +454,9 @@ class BasePlanner: ...@@ -447,6 +454,9 @@ class BasePlanner:
f"Auto-discovered frontend metrics URL: {config.load_router_metrics_url}" f"Auto-discovered frontend metrics URL: {config.load_router_metrics_url}"
) )
assert (
config.load_router_metrics_url is not None
), "load_router_metrics_url must be set when load-based scaling is enabled"
self.prometheus_engine_client = DirectRouterMetricsClient( self.prometheus_engine_client = DirectRouterMetricsClient(
config.load_router_metrics_url, config.namespace config.load_router_metrics_url, config.namespace
) )
...@@ -494,6 +504,7 @@ class BasePlanner: ...@@ -494,6 +504,7 @@ class BasePlanner:
async def _get_or_create_client(self, component_name: str, endpoint_name: str): async def _get_or_create_client(self, component_name: str, endpoint_name: str):
"""Create a client for the given component and endpoint, with a brief sleep for state sync.""" """Create a client for the given component and endpoint, with a brief sleep for state sync."""
assert self.runtime is not None, "Runtime is not initialized"
client = await self.runtime.endpoint( client = await self.runtime.endpoint(
f"{self.namespace}.{component_name}.{endpoint_name}" f"{self.namespace}.{component_name}.{endpoint_name}"
).client() ).client()
...@@ -604,41 +615,46 @@ class BasePlanner: ...@@ -604,41 +615,46 @@ class BasePlanner:
) )
# Prometheus returns seconds, convert to milliseconds # Prometheus returns seconds, convert to milliseconds
assert (
self.model_name is not None
), "model_name must be set before observing traffic stats"
interval_str = f"{self.config.throughput_adjustment_interval}s"
self.last_metrics.ttft = ( self.last_metrics.ttft = (
self.prometheus_traffic_client.get_avg_time_to_first_token( self.prometheus_traffic_client.get_avg_time_to_first_token(
f"{self.config.throughput_adjustment_interval}s", interval_str,
self.model_name, self.model_name,
) )
* 1000 * 1000
) )
self.last_metrics.itl = ( self.last_metrics.itl = (
self.prometheus_traffic_client.get_avg_inter_token_latency( self.prometheus_traffic_client.get_avg_inter_token_latency(
f"{self.config.throughput_adjustment_interval}s", interval_str,
self.model_name, self.model_name,
) )
* 1000 * 1000
) )
self.last_metrics.num_req = ( self.last_metrics.num_req = (
self.prometheus_traffic_client.get_avg_request_count( self.prometheus_traffic_client.get_avg_request_count(
f"{self.config.throughput_adjustment_interval}s", interval_str,
self.model_name, self.model_name,
) )
) )
self.last_metrics.request_duration = ( self.last_metrics.request_duration = (
self.prometheus_traffic_client.get_avg_request_duration( self.prometheus_traffic_client.get_avg_request_duration(
f"{self.config.throughput_adjustment_interval}s", interval_str,
self.model_name, self.model_name,
) )
) )
self.last_metrics.isl = ( self.last_metrics.isl = (
self.prometheus_traffic_client.get_avg_input_sequence_tokens( self.prometheus_traffic_client.get_avg_input_sequence_tokens(
f"{self.config.throughput_adjustment_interval}s", interval_str,
self.model_name, self.model_name,
) )
) )
self.last_metrics.osl = ( self.last_metrics.osl = (
self.prometheus_traffic_client.get_avg_output_sequence_tokens( self.prometheus_traffic_client.get_avg_output_sequence_tokens(
f"{self.config.throughput_adjustment_interval}s", interval_str,
self.model_name, self.model_name,
) )
) )
...@@ -666,8 +682,11 @@ class BasePlanner: ...@@ -666,8 +682,11 @@ class BasePlanner:
self.update_predictors_from_metrics(self.last_metrics) self.update_predictors_from_metrics(self.last_metrics)
def update_predictors_from_metrics(self, metrics: Metrics) -> None: def update_predictors_from_metrics(self, metrics: Metrics) -> None:
if metrics.num_req is not None:
self.num_req_predictor.add_data_point(metrics.num_req) self.num_req_predictor.add_data_point(metrics.num_req)
if metrics.isl is not None:
self.isl_predictor.add_data_point(metrics.isl) self.isl_predictor.add_data_point(metrics.isl)
if metrics.osl is not None:
self.osl_predictor.add_data_point(metrics.osl) self.osl_predictor.add_data_point(metrics.osl)
def predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]: def predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
......
...@@ -94,6 +94,7 @@ class PrefillPlanner(BasePlanner): ...@@ -94,6 +94,7 @@ class PrefillPlanner(BasePlanner):
return None return None
def _update_correction_factor(self) -> bool: def _update_correction_factor(self) -> bool:
assert self.last_metrics.isl is not None and self.last_metrics.ttft is not None
expect_ttft = self.prefill_interpolator.interpolate_ttft(self.last_metrics.isl) expect_ttft = self.prefill_interpolator.interpolate_ttft(self.last_metrics.isl)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft self.p_correction_factor = self.last_metrics.ttft / expect_ttft
logger.info(f"Correction factor (prefill TTFT): {self.p_correction_factor:.3f}") logger.info(f"Correction factor (prefill TTFT): {self.p_correction_factor:.3f}")
...@@ -117,6 +118,7 @@ class PrefillPlanner(BasePlanner): ...@@ -117,6 +118,7 @@ class PrefillPlanner(BasePlanner):
"(no throughput satisfies TTFT target), falling back to min_endpoint" "(no throughput satisfies TTFT target), falling back to min_endpoint"
) )
return self.config.min_endpoint return self.config.min_endpoint
assert self.config.prefill_engine_num_gpu is not None
next_num_p = math.ceil( next_num_p = math.ceil(
pred_prefill_throughput pred_prefill_throughput
/ p_thpt_per_gpu / p_thpt_per_gpu
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Dynamo SGLang wrapper configuration ArgGroup.""" """Dynamo SGLang wrapper configuration ArgGroup."""
import argparse import argparse
from typing import Optional, Union from typing import Optional
from dynamo.common.configuration.arg_group import ArgGroup from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase from dynamo.common.configuration.config_base import ConfigBase
...@@ -117,7 +117,7 @@ class DynamoSGLangConfig(ConfigBase): ...@@ -117,7 +117,7 @@ class DynamoSGLangConfig(ConfigBase):
multimodal_processor: bool multimodal_processor: bool
multimodal_encode_worker: bool multimodal_encode_worker: bool
multimodal_worker: bool multimodal_worker: bool
embedding_transfer_mode: Union[str, EmbeddingTransferMode] embedding_transfer_mode: EmbeddingTransferMode
embedding_worker: bool embedding_worker: bool
image_diffusion_worker: bool image_diffusion_worker: bool
...@@ -127,10 +127,11 @@ class DynamoSGLangConfig(ConfigBase): ...@@ -127,10 +127,11 @@ class DynamoSGLangConfig(ConfigBase):
video_generation_worker: bool video_generation_worker: bool
def validate(self) -> None: def validate(self) -> None:
if isinstance(self.embedding_transfer_mode, str): if not isinstance(self.embedding_transfer_mode, EmbeddingTransferMode):
self.embedding_transfer_mode = EmbeddingTransferMode( self.embedding_transfer_mode = EmbeddingTransferMode(
self.embedding_transfer_mode str(self.embedding_transfer_mode)
) )
if (self.disagg_config is not None) ^ (self.disagg_config_key is not None): if (self.disagg_config is not None) ^ (self.disagg_config_key is not None):
raise ValueError( raise ValueError(
"Both 'disagg_config' and 'disagg_config_key' must be provided together." "Both 'disagg_config' and 'disagg_config_key' must be provided together."
......
...@@ -8,13 +8,23 @@ import random ...@@ -8,13 +8,23 @@ import random
import socket import socket
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Tuple from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
Generic,
Optional,
Tuple,
TypeVar,
)
import sglang as sgl import sglang as sgl
from sglang.srt.utils import get_local_ip_auto from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -72,7 +82,11 @@ class SGLangEngineQuiesceController: ...@@ -72,7 +82,11 @@ class SGLangEngineQuiesceController:
self._is_quiesced = False self._is_quiesced = False
class BaseGenerativeHandler(ABC): RequestT = TypeVar("RequestT")
ResponseT = TypeVar("ResponseT")
class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]):
"""Minimal base class for all generative handlers (LLM, diffusion, etc.). """Minimal base class for all generative handlers (LLM, diffusion, etc.).
Provides common infrastructure for: Provides common infrastructure for:
...@@ -95,27 +109,24 @@ class BaseGenerativeHandler(ABC): ...@@ -95,27 +109,24 @@ class BaseGenerativeHandler(ABC):
self.config = config self.config = config
# Set up metrics and KV publishers # Set up metrics and KV publishers
self.metrics_publisher: Optional[WorkerMetricsPublisher] = None
self.kv_publisher: Optional[KvEventPublisher] = None
if publisher is not None: if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher self.kv_publisher = publisher.kv_publisher
else:
self.metrics_publisher = None
self.kv_publisher = None
@abstractmethod @abstractmethod
async def generate( def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]:
self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate response from request. """Generate response from request.
Args: Args:
request: Request dict with input and parameters. request: Request with input and parameters.
context: Context object for cancellation handling. context: Context object for cancellation handling.
Yields: Yields:
Response data (format varies by handler implementation). Response data (format varies by handler implementation).
""" """
pass ...
def cleanup(self) -> None: def cleanup(self) -> None:
"""Cleanup resources. Override in subclasses as needed.""" """Cleanup resources. Override in subclasses as needed."""
...@@ -137,7 +148,7 @@ class BaseGenerativeHandler(ABC): ...@@ -137,7 +148,7 @@ class BaseGenerativeHandler(ABC):
return {"traceparent": f"00-{trace_id}-{span_id}-01"} return {"traceparent": f"00-{trace_id}-{span_id}-01"}
class BaseWorkerHandler(BaseGenerativeHandler): class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
"""Abstract base class for SGLang LLM worker handlers. """Abstract base class for SGLang LLM worker handlers.
Extends BaseGenerativeHandler with LLM-specific functionality: Extends BaseGenerativeHandler with LLM-specific functionality:
...@@ -175,9 +186,6 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -175,9 +186,6 @@ class BaseWorkerHandler(BaseGenerativeHandler):
if publisher is not None: if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher self.kv_publisher = publisher.kv_publisher
else:
self.metrics_publisher = None
self.kv_publisher = None
self.serving_mode = config.serving_mode self.serving_mode = config.serving_mode
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
self.enable_trace = config.server_args.enable_trace self.enable_trace = config.server_args.enable_trace
...@@ -454,17 +462,17 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -454,17 +462,17 @@ class BaseWorkerHandler(BaseGenerativeHandler):
) )
@abstractmethod @abstractmethod
async def generate(self, request: Dict[str, Any], context: Context): def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]:
"""Generate response from request. """Generate response from request.
Args: Args:
request: Request dict with input and parameters. request: Request with input and parameters.
context: Context object for cancellation handling. context: Context object for cancellation handling.
Yields: Yields:
Response data (format varies by handler implementation). Response data (format varies by handler implementation).
""" """
pass ...
def cleanup(self) -> None: def cleanup(self) -> None:
"""Cleanup resources. Override in subclasses as needed.""" """Cleanup resources. Override in subclasses as needed."""
......
...@@ -24,7 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -24,7 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self, self,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: DynamoSglangPublisher, publisher: Optional[DynamoSglangPublisher] = None,
generate_endpoint=None, generate_endpoint=None,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
) -> None: ) -> None:
...@@ -230,7 +230,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -230,7 +230,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# This lets SGLang proceed to the second token generation, which will # This lets SGLang proceed to the second token generation, which will
# async context switch and allow the abort monitor to signal cancellation. # async context switch and allow the abort monitor to signal cancellation.
# The loop should exit by itself when context.is_stopped() returns True. # The loop should exit by itself when context.is_stopped() returns True.
out = {} out: dict[str, Any] = {}
finish_reason = res["meta_info"]["finish_reason"] finish_reason = res["meta_info"]["finish_reason"]
if finish_reason: if finish_reason:
out["finish_reason"] = normalize_finish_reason( out["finish_reason"] = normalize_finish_reason(
......
...@@ -21,7 +21,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler): ...@@ -21,7 +21,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
self, self,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: DynamoSglangPublisher = None, publisher: Optional[DynamoSglangPublisher] = None,
generate_endpoint=None, generate_endpoint=None,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
) -> None: ) -> None:
......
...@@ -38,7 +38,7 @@ except ImportError as e: ...@@ -38,7 +38,7 @@ except ImportError as e:
DEVICE = "cpu" DEVICE = "cpu"
class MultimodalEncodeWorkerHandler(BaseWorkerHandler): class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
""" """
Handler for multimodal encode worker component that processes images/videos Handler for multimodal encode worker component that processes images/videos
and forwards them to the downstream worker. and forwards them to the downstream worker.
...@@ -84,12 +84,19 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -84,12 +84,19 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
if image_token_str == "<|vision_start|><|image_pad|><|vision_end|>": if image_token_str == "<|vision_start|><|image_pad|><|vision_end|>":
# These are likely the individual special tokens for Qwen2.5-VL # These are likely the individual special tokens for Qwen2.5-VL
image_pad_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>") image_pad_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")
assert isinstance(
image_pad_id, int
), f"Expected int token id, got {type(image_pad_id)}"
# Use the image_pad token as the main image token # Use the image_pad token as the main image token
self.image_token_id = image_pad_id self.image_token_id: int = image_pad_id
else: else:
# Fallback for other models # Fallback for other models
self.image_token_id = self.tokenizer.convert_tokens_to_ids(image_token_str) token_id = self.tokenizer.convert_tokens_to_ids(image_token_str)
assert isinstance(
token_id, int
), f"Expected int token id, got {type(token_id)}"
self.image_token_id = token_id
self.min_workers = 1 self.min_workers = 1
...@@ -230,10 +237,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -230,10 +237,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
zip(multimodal_groups, image_grid_thw_list) zip(multimodal_groups, image_grid_thw_list)
): ):
mm_group.image_grid_thw = image_grid_thw mm_group.image_grid_thw = image_grid_thw
if mm_group.multimodal_input is not None:
mm_group.multimodal_input.image_url = None mm_group.multimodal_input.image_url = None
# Store shared tensor transfer metadata at request level. # Store shared tensor transfer metadata at request level.
request.embeddings_shape = tuple(precomputed_embeddings.shape) request.embeddings_shape = tuple(precomputed_embeddings.shape) # type: ignore[assignment]
request.transfer_payload = None request.transfer_payload = None
search_start = 0 search_start = 0
......
...@@ -6,7 +6,7 @@ import json ...@@ -6,7 +6,7 @@ import json
import logging import logging
import time import time
import uuid import uuid
from typing import Any, Dict, Optional from typing import Any, AsyncGenerator, Dict, Optional
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -20,6 +20,7 @@ from dynamo.sglang.protocol import ( ...@@ -20,6 +20,7 @@ from dynamo.sglang.protocol import (
MultiModalGroup, MultiModalGroup,
MultiModalInput, MultiModalInput,
MultiModalRequest, MultiModalRequest,
PreprocessedRequest,
SglangMultimodalRequest, SglangMultimodalRequest,
) )
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
...@@ -27,7 +28,7 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler ...@@ -27,7 +28,7 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MultimodalProcessorHandler(BaseWorkerHandler): class MultimodalProcessorHandler(BaseWorkerHandler[MultiModalRequest, Dict[str, Any]]):
""" """
Handler for multimodal processor component that processes multimodal requests Handler for multimodal processor component that processes multimodal requests
and forwards them to the encode worker. and forwards them to the encode worker.
...@@ -56,7 +57,9 @@ class MultimodalProcessorHandler(BaseWorkerHandler): ...@@ -56,7 +57,9 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
def cleanup(self): def cleanup(self):
pass pass
async def generate(self, raw_request: MultiModalRequest, context: Context): async def generate(
self, raw_request: MultiModalRequest, context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
""" """
Process multimodal request and forward to encode worker. Process multimodal request and forward to encode worker.
...@@ -119,7 +122,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler): ...@@ -119,7 +122,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
) )
worker_request = SglangMultimodalRequest( worker_request = SglangMultimodalRequest(
request=sglang_request, request=PreprocessedRequest(**sglang_request),
multimodal_inputs=multimodal_groups, multimodal_inputs=multimodal_groups,
) )
......
...@@ -256,7 +256,7 @@ async def _build_mm_items( ...@@ -256,7 +256,7 @@ async def _build_mm_items(
return mm_items, embeddings, tensor_id return mm_items, embeddings, tensor_id
class MultimodalWorkerHandler(BaseWorkerHandler): class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
""" """
Multimodal worker handler for LLM inference with multimodal data. Multimodal worker handler for LLM inference with multimodal data.
Handles both aggregated and disaggregated modes. Handles both aggregated and disaggregated modes.
...@@ -490,7 +490,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -490,7 +490,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
logger.info("Multimodal worker engine shutdown") logger.info("Multimodal worker engine shutdown")
class MultimodalPrefillWorkerHandler(BaseWorkerHandler): class MultimodalPrefillWorkerHandler(
BaseWorkerHandler[DisaggSglangMultimodalRequest, str]
):
""" """
Multimodal prefill worker handler for disaggregated inference Multimodal prefill worker handler for disaggregated inference
Processes multimodal inputs and coordinates with decode worker. Processes multimodal inputs and coordinates with decode worker.
......
...@@ -103,16 +103,22 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler): ...@@ -103,16 +103,22 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
) )
# Parse size # Parse size
assert req.size is not None, "Size is required"
width, height = self._parse_size(req.size) width, height = self._parse_size(req.size)
# Calculate num_frames if not explicitly provided # Calculate num_frames if not explicitly provided
num_frames = nvext.num_frames num_frames = nvext.num_frames
assert nvext.fps is not None, "FPS is required"
if num_frames is None: if num_frames is None:
assert req.seconds is not None, "Seconds is required"
num_frames = nvext.fps * req.seconds num_frames = nvext.fps * req.seconds
# Generate video # Generate video
context_id = context.id() context_id = context.id()
assert context_id is not None assert context_id is not None
assert (
nvext.num_inference_steps is not None
), "Num inference steps is required"
video_bytes = await self._generate_video( video_bytes = await self._generate_video(
prompt=req.prompt, prompt=req.prompt,
width=width, width=width,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
from typing import List, Optional from typing import List, Optional, Sequence
import torch import torch
from tensorrt_llm.sampling_params import LogitsProcessor from tensorrt_llm.sampling_params import LogitsProcessor
...@@ -70,7 +70,7 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor): ...@@ -70,7 +70,7 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
def create_trtllm_adapters( def create_trtllm_adapters(
processors: List[BaseLogitsProcessor], processors: Sequence[BaseLogitsProcessor],
) -> List[TrtllmDynamoLogitsAdapter]: ) -> List[TrtllmDynamoLogitsAdapter]:
""" """
Create TensorRT-LLM compatible adapters from Dynamo logits processors. Create TensorRT-LLM compatible adapters from Dynamo logits processors.
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
import logging import logging
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Optional from typing import Optional, Union
import torch
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
...@@ -40,7 +42,7 @@ class AggregatedHandler(HandlerBase): ...@@ -40,7 +42,7 @@ class AggregatedHandler(HandlerBase):
"""Generate response, optionally using remote encoder for multimodal.""" """Generate response, optionally using remote encoder for multimodal."""
logging.debug(f"AggregatedHandler Request ID: {context.id()}") logging.debug(f"AggregatedHandler Request ID: {context.id()}")
embeddings = None embeddings: Optional[Union[torch.Tensor, dict]] = None
ep_disaggregated_params = None ep_disaggregated_params = None
if self.multimodal_processor and self.encode_client: if self.multimodal_processor and self.encode_client:
messages = request.get("extra_args", {}).get( messages = request.get("extra_args", {}).get(
...@@ -57,7 +59,7 @@ class AggregatedHandler(HandlerBase): ...@@ -57,7 +59,7 @@ class AggregatedHandler(HandlerBase):
self._encoder_cache, self._encoder_cache,
) )
if isinstance(result, list): if isinstance(result, list):
embeddings = result embeddings = result # type: ignore[assignment]
else: else:
ep_disaggregated_params = result ep_disaggregated_params = result
......
...@@ -31,7 +31,7 @@ from tensorrt_llm.llmapi.llm import SamplingParams ...@@ -31,7 +31,7 @@ from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams from tensorrt_llm.sampling_params import GuidedDecodingParams
from tensorrt_llm.scheduling_params import SchedulingParams from tensorrt_llm.scheduling_params import SchedulingParams
from dynamo._core import Context from dynamo._core import Client, Context
from dynamo.common.utils.otel_tracing import build_trace_headers from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector from dynamo.nixl_connect import Connector
...@@ -65,9 +65,9 @@ class RequestHandlerConfig: ...@@ -65,9 +65,9 @@ class RequestHandlerConfig:
engine: TensorRTLLMEngine engine: TensorRTLLMEngine
default_sampling_params: SamplingParams default_sampling_params: SamplingParams
publisher: Publisher publisher: Optional[Publisher]
disaggregation_mode: DisaggregationMode disaggregation_mode: DisaggregationMode
encode_client: Optional[object] = None encode_client: Optional[Client] = None
multimodal_processor: Optional[ multimodal_processor: Optional[
MultimodalRequestProcessor MultimodalRequestProcessor
] = None # for multimodal support ] = None # for multimodal support
...@@ -558,11 +558,11 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -558,11 +558,11 @@ class HandlerBase(BaseGenerativeHandler):
# PREFILL/ENCODE/AGGREGATED: Process multimodal content if available # PREFILL/ENCODE/AGGREGATED: Process multimodal content if available
if self.multimodal_processor: if self.multimodal_processor:
processed_input = await self.multimodal_processor.process_openai_request( mm_result = await self.multimodal_processor.process_openai_request(
request, embeddings, ep_disaggregated_params request, embeddings, ep_disaggregated_params
) )
if processed_input: if mm_result:
return processed_input return mm_result
# If multimodal processing returned None but request has multimodal data, # If multimodal processing returned None but request has multimodal data,
# this is an error (not a text-only request). Raise instead of falling back. # this is an error (not a text-only request). Raise instead of falling back.
......
...@@ -111,6 +111,8 @@ class PrefillHandler(HandlerBase): ...@@ -111,6 +111,8 @@ class PrefillHandler(HandlerBase):
Encoder's embeddings tensor to be used by the prefill worker Encoder's embeddings tensor to be used by the prefill worker
""" """
# Get response with shape info and readable metadata # Get response with shape info and readable metadata
if self.encode_client is None:
raise RuntimeError("Encode client is not configured.")
encode_response = None encode_response = None
async for res in await self.encode_client.round_robin(request): async for res in await self.encode_client.round_robin(request):
encode_response = res.data() encode_response = res.data()
...@@ -119,6 +121,8 @@ class PrefillHandler(HandlerBase): ...@@ -119,6 +121,8 @@ class PrefillHandler(HandlerBase):
if not encode_response: if not encode_response:
raise RuntimeError("Did not receive a response from the encode worker.") raise RuntimeError("Did not receive a response from the encode worker.")
if self.connector is None:
raise RuntimeError("Connector is not configured.")
# Use utility function to handle NIXL reading and reconstruction # Use utility function to handle NIXL reading and reconstruction
return await EncodeHelper.read_embeddings_from_encode_response( return await EncodeHelper.read_embeddings_from_encode_response(
encode_response, self.connector encode_response, self.connector
......
...@@ -61,6 +61,9 @@ async def init_video_diffusion_worker( ...@@ -61,6 +61,9 @@ async def init_video_diffusion_worker(
else [] else []
) )
if not config.endpoint:
raise ValueError("endpoint must be configured for video diffusion worker")
# Build DiffusionConfig from the main Config # Build DiffusionConfig from the main Config
diffusion_config = DiffusionConfig( diffusion_config = DiffusionConfig(
namespace=config.namespace, namespace=config.namespace,
......
...@@ -13,7 +13,7 @@ import time ...@@ -13,7 +13,7 @@ import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Final from typing import Any, AsyncIterator, Dict, Final, Generic, TypeVar
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -23,6 +23,7 @@ from vllm.outputs import RequestOutput ...@@ -23,6 +23,7 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
from dynamo._core import Context
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
...@@ -325,7 +326,11 @@ def get_dp_range_for_worker(vllm_config: VllmConfig) -> tuple[int, int]: ...@@ -325,7 +326,11 @@ def get_dp_range_for_worker(vllm_config: VllmConfig) -> tuple[int, int]:
) )
class BaseWorkerHandler(ABC): RequestT = TypeVar("RequestT")
ResponseT = TypeVar("ResponseT")
class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
""" """
Request handler for the generate and clear_kv_blocks endpoints. Request handler for the generate and clear_kv_blocks endpoints.
""" """
...@@ -459,7 +464,7 @@ class BaseWorkerHandler(ABC): ...@@ -459,7 +464,7 @@ class BaseWorkerHandler(ABC):
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@abstractmethod @abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]: def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]:
raise NotImplementedError raise NotImplementedError
async def _monitor_abort(self, context, request_id, is_prefill): async def _monitor_abort(self, context, request_id, is_prefill):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment