Unverified commit 63e7b7da, authored by jh-nv and committed by GitHub
Browse files

chore: add mypy to planner and mocker (#6862)

parent 6216ae55
...@@ -90,7 +90,7 @@ def resolve_planner_profile_data( ...@@ -90,7 +90,7 @@ def resolve_planner_profile_data(
) )
def create_temp_engine_args_file(args) -> Path: def create_temp_engine_args_file(args: argparse.Namespace) -> Path:
""" """
Create a temporary JSON file with MockEngineArgs from CLI arguments. Create a temporary JSON file with MockEngineArgs from CLI arguments.
Returns the path to the temporary file. Returns the path to the temporary file.
...@@ -146,7 +146,7 @@ def create_temp_engine_args_file(args) -> Path: ...@@ -146,7 +146,7 @@ def create_temp_engine_args_file(args) -> Path:
return temp_path return temp_path
def validate_worker_type_args(args): def validate_worker_type_args(args: argparse.Namespace) -> None:
""" """
Resolve disaggregation mode from --disaggregation-mode or legacy boolean flags. Resolve disaggregation mode from --disaggregation-mode or legacy boolean flags.
Raises ValueError if validation fails. Raises ValueError if validation fails.
...@@ -199,7 +199,7 @@ def parse_bootstrap_ports(ports_str: str | None) -> list[int]: ...@@ -199,7 +199,7 @@ def parse_bootstrap_ports(ports_str: str | None) -> list[int]:
return [int(p.strip()) for p in ports_str.split(",")] return [int(p.strip()) for p in ports_str.split(",")]
def parse_args(): def parse_args() -> argparse.Namespace:
"""Parse command-line arguments for the Dynamo mocker engine. """Parse command-line arguments for the Dynamo mocker engine.
Returns: Returns:
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Usage: `python -m dynamo.mocker --model-path /data/models/Qwen3-0.6B` # Usage: `python -m dynamo.mocker --model-path /data/models/Qwen3-0.6B`
# Now supports vLLM-style individual arguments for MockEngineArgs # Now supports vLLM-style individual arguments for MockEngineArgs
import argparse
import asyncio import asyncio
import json import json
import logging import logging
...@@ -135,7 +136,7 @@ def compute_stagger_delay(num_workers: int, stagger_delay: float) -> float: ...@@ -135,7 +136,7 @@ def compute_stagger_delay(num_workers: int, stagger_delay: float) -> float:
return 0.2 return 0.2
async def launch_workers(args, extra_engine_args_path): async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path):
"""Launch mocker worker(s) with isolated DistributedRuntime instances. """Launch mocker worker(s) with isolated DistributedRuntime instances.
Each worker gets its own DistributedRuntime, which means: Each worker gets its own DistributedRuntime, which means:
...@@ -185,7 +186,9 @@ async def launch_workers(args, extra_engine_args_path): ...@@ -185,7 +186,9 @@ async def launch_workers(args, extra_engine_args_path):
runtimes.append(runtime) runtimes.append(runtime)
# Determine which engine args file to use # Determine which engine args file to use
worker_engine_args_path: Path | str
if needs_per_worker_args: if needs_per_worker_args:
assert base_engine_args is not None
worker_args = base_engine_args.copy() worker_args = base_engine_args.copy()
if args.bootstrap_ports_list: if args.bootstrap_ports_list:
worker_args["bootstrap_port"] = args.bootstrap_ports_list[worker_id] worker_args["bootstrap_port"] = args.bootstrap_ports_list[worker_id]
...@@ -195,9 +198,9 @@ async def launch_workers(args, extra_engine_args_path): ...@@ -195,9 +198,9 @@ async def launch_workers(args, extra_engine_args_path):
] ]
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False mode="w", suffix=".json", delete=False
) as f: ) as tmp:
json.dump(worker_args, f) json.dump(worker_args, tmp)
worker_engine_args_path = Path(f.name) worker_engine_args_path = Path(tmp.name)
per_worker_temp_files.append(worker_engine_args_path) per_worker_temp_files.append(worker_engine_args_path)
logger.debug(f"Worker {worker_id}: per-worker args {worker_args}") logger.debug(f"Worker {worker_id}: per-worker args {worker_args}")
else: else:
...@@ -209,7 +212,7 @@ async def launch_workers(args, extra_engine_args_path): ...@@ -209,7 +212,7 @@ async def launch_workers(args, extra_engine_args_path):
model_path=args.model_path, model_path=args.model_path,
model_name=args.model_name, model_name=args.model_name,
endpoint_id=args.endpoint, endpoint_id=args.endpoint,
extra_engine_args=worker_engine_args_path, extra_engine_args=str(worker_engine_args_path),
is_prefill=args.is_prefill_worker, is_prefill=args.is_prefill_worker,
) )
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
from typing import Any
from transformers import AutoConfig from transformers import AutoConfig
...@@ -43,7 +44,7 @@ def _normalize_dtype_str(dtype) -> str: ...@@ -43,7 +44,7 @@ def _normalize_dtype_str(dtype) -> str:
return s return s
def get_kv_cache_dtype_bytes(config, kv_cache_dtype: str = "auto") -> int: def get_kv_cache_dtype_bytes(config: Any, kv_cache_dtype: str = "auto") -> int:
"""Get the byte size per element for KV cache based on dtype. """Get the byte size per element for KV cache based on dtype.
When kv_cache_dtype is "auto", uses the model's dtype from config. When kv_cache_dtype is "auto", uses the model's dtype from config.
......
...@@ -15,15 +15,21 @@ ...@@ -15,15 +15,21 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dynamo.planner.defaults import SubComponentType
# TODO: add ability to scale component to X replicas # TODO: add ability to scale component to X replicas
class PlannerConnector(ABC): class PlannerConnector(ABC):
@abstractmethod @abstractmethod
async def add_component(self, component_name): async def add_component(
self, sub_component_type: SubComponentType, blocking: bool = True
) -> None:
"""Add a component to the planner""" """Add a component to the planner"""
pass pass
@abstractmethod @abstractmethod
async def remove_component(self, component_name): async def remove_component(
self, sub_component_type: SubComponentType, blocking: bool = True
) -> None:
"""Remove a component from the planner""" """Remove a component from the planner"""
pass pass
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import asyncio import asyncio
import logging import logging
from dynamo._core import Client
from dynamo.planner.defaults import SubComponentType from dynamo.planner.defaults import SubComponentType
from dynamo.planner.scale_protocol import ScaleRequest, ScaleResponse from dynamo.planner.scale_protocol import ScaleRequest, ScaleResponse
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
...@@ -29,7 +30,7 @@ class RemotePlannerClient: ...@@ -29,7 +30,7 @@ class RemotePlannerClient:
self.central_component = central_component self.central_component = central_component
self.connection_timeout = connection_timeout self.connection_timeout = connection_timeout
self.max_retries = max_retries self.max_retries = max_retries
self._client = None self._client: Client | None = None
async def _ensure_client(self): async def _ensure_client(self):
"""Lazy initialization of endpoint client with retry mechanism""" """Lazy initialization of endpoint client with retry mechanism"""
...@@ -39,7 +40,7 @@ class RemotePlannerClient: ...@@ -39,7 +40,7 @@ class RemotePlannerClient:
) )
# Retry logic with exponential backoff # Retry logic with exponential backoff
last_error = None last_error: Exception | None = None
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
try: try:
logger.info( logger.info(
...@@ -101,6 +102,7 @@ class RemotePlannerClient: ...@@ -101,6 +102,7 @@ class RemotePlannerClient:
# Send request via the runtime client's generate method (the correct API for # Send request via the runtime client's generate method (the correct API for
# calling any dynamo endpoint, regardless of its registered name) # calling any dynamo endpoint, regardless of its registered name)
request_json = request.model_dump_json() request_json = request.model_dump_json()
assert self._client is not None
stream = await self._client.generate(request_json) stream = await self._client.generate(request_json)
response_data = None response_data = None
......
...@@ -40,9 +40,7 @@ class AggPlanner: ...@@ -40,9 +40,7 @@ class AggPlanner:
# Engine metrics from agg workers are labeled "decode" by the router # Engine metrics from agg workers are labeled "decode" by the router
ENGINE_WORKER_TYPE = "decode" ENGINE_WORKER_TYPE = "decode"
def __init__( def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
self, runtime: Optional[DistributedRuntime], config: PlannerConfig
) -> None:
self.config = config self.config = config
self.shared_state = PlannerSharedState() self.shared_state = PlannerSharedState()
......
...@@ -19,6 +19,7 @@ import warnings ...@@ -19,6 +19,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from typing import Any
import numpy as np import numpy as np
import pandas as pd import pandas as pd
...@@ -55,19 +56,19 @@ for _name in ( ...@@ -55,19 +56,19 @@ for _name in (
class BasePredictor(ABC): class BasePredictor(ABC):
"""Base class for all load predictors""" """Base class for all load predictors"""
def __init__(self, minimum_data_points=5): def __init__(self, minimum_data_points: int = 5) -> None:
self.minimum_data_points = minimum_data_points self.minimum_data_points = minimum_data_points
self.data_buffer = [] self.data_buffer: list[Any] = []
# Even if we preload historical data, we still want to ignore the initial # Even if we preload historical data, we still want to ignore the initial
# post-deployment idle period (a run of zeros) until we see the first # post-deployment idle period (a run of zeros) until we see the first
# non-zero datapoint from live traffic. # non-zero datapoint from live traffic.
self._seen_nonzero_since_idle_reset = False self._seen_nonzero_since_idle_reset = False
def reset_idle_skip(self): def reset_idle_skip(self) -> None:
"""Reset idle-period skipping state (e.g., after warmup, before live).""" """Reset idle-period skipping state (e.g., after warmup, before live)."""
self._seen_nonzero_since_idle_reset = False self._seen_nonzero_since_idle_reset = False
def add_data_point(self, value): def add_data_point(self, value: float) -> None:
"""Add new data point to the buffer""" """Add new data point to the buffer"""
if math.isnan(value): if math.isnan(value):
value = 0 value = 0
...@@ -82,14 +83,14 @@ class BasePredictor(ABC): ...@@ -82,14 +83,14 @@ class BasePredictor(ABC):
self.data_buffer.append(value) self.data_buffer.append(value)
def get_last_value(self): def get_last_value(self) -> float:
"""Get the last value from the buffer""" """Get the last value from the buffer"""
if not self.data_buffer: if not self.data_buffer:
return 0 return 0
return self.data_buffer[-1] return self.data_buffer[-1]
@abstractmethod @abstractmethod
def predict_next(self): def predict_next(self) -> float:
"""Predict the next value""" """Predict the next value"""
pass pass
...@@ -99,10 +100,10 @@ class ConstantPredictor(BasePredictor): ...@@ -99,10 +100,10 @@ class ConstantPredictor(BasePredictor):
Assume load is constant and predict the next load to be the same as most recent load Assume load is constant and predict the next load to be the same as most recent load
""" """
def __init__(self, _config: PlannerConfig): def __init__(self, _config: PlannerConfig) -> None:
super().__init__(minimum_data_points=1) super().__init__(minimum_data_points=1)
def predict_next(self): def predict_next(self) -> float:
return self.get_last_value() return self.get_last_value()
...@@ -112,7 +113,7 @@ class ARIMAPredictor(BasePredictor): ...@@ -112,7 +113,7 @@ class ARIMAPredictor(BasePredictor):
RAW = "raw" RAW = "raw"
LOG1P = "log1p" LOG1P = "log1p"
def __init__(self, config: PlannerConfig): def __init__(self, config: PlannerConfig) -> None:
super().__init__(minimum_data_points=5) super().__init__(minimum_data_points=5)
self.model = None self.model = None
# Keep raw values so we can fit in raw space first, then fallback to log1p space. # Keep raw values so we can fit in raw space first, then fallback to log1p space.
...@@ -125,7 +126,7 @@ class ARIMAPredictor(BasePredictor): ...@@ -125,7 +126,7 @@ class ARIMAPredictor(BasePredictor):
) )
self._mode: ARIMAPredictor.Mode = self._requested_mode self._mode: ARIMAPredictor.Mode = self._requested_mode
def get_last_value(self): def get_last_value(self) -> float:
"""Return last value in original scale.""" """Return last value in original scale."""
if self._raw_buffer: if self._raw_buffer:
return float(self._raw_buffer[-1]) return float(self._raw_buffer[-1])
...@@ -133,7 +134,7 @@ class ARIMAPredictor(BasePredictor): ...@@ -133,7 +134,7 @@ class ARIMAPredictor(BasePredictor):
return 0 return 0
return float(self.data_buffer[-1]) return float(self.data_buffer[-1])
def add_data_point(self, value): def add_data_point(self, value: float) -> None:
prev_len = len(self.data_buffer) prev_len = len(self.data_buffer)
# Use raw value for idle skipping in BasePredictor. We may transform later. # Use raw value for idle skipping in BasePredictor. We may transform later.
super().add_data_point(value) super().add_data_point(value)
...@@ -145,7 +146,7 @@ class ARIMAPredictor(BasePredictor): ...@@ -145,7 +146,7 @@ class ARIMAPredictor(BasePredictor):
if self._mode == ARIMAPredictor.Mode.LOG1P: if self._mode == ARIMAPredictor.Mode.LOG1P:
self.data_buffer[-1] = math.log1p(raw) self.data_buffer[-1] = math.log1p(raw)
def predict_next(self): def predict_next(self) -> float:
"""Predict the next value(s)""" """Predict the next value(s)"""
if len(self._raw_buffer) < self.minimum_data_points: if len(self._raw_buffer) < self.minimum_data_points:
return self.get_last_value() return self.get_last_value()
...@@ -234,6 +235,7 @@ class ARIMAPredictor(BasePredictor): ...@@ -234,6 +235,7 @@ class ARIMAPredictor(BasePredictor):
self._pending_raw_updates = [] self._pending_raw_updates = []
# Make prediction # Make prediction
assert self.model is not None
forecast = float(self.model.predict(n_periods=1)[0]) forecast = float(self.model.predict(n_periods=1)[0])
if self._mode == ARIMAPredictor.Mode.LOG1P: if self._mode == ARIMAPredictor.Mode.LOG1P:
return max(0.0, math.expm1(forecast)) return max(0.0, math.expm1(forecast))
...@@ -247,7 +249,7 @@ class ARIMAPredictor(BasePredictor): ...@@ -247,7 +249,7 @@ class ARIMAPredictor(BasePredictor):
# Time-series forecasting model from Meta # Time-series forecasting model from Meta
class ProphetPredictor(BasePredictor): class ProphetPredictor(BasePredictor):
def __init__(self, config: PlannerConfig): def __init__(self, config: PlannerConfig) -> None:
super().__init__(minimum_data_points=5) super().__init__(minimum_data_points=5)
self._use_log1p = config.load_predictor_log1p self._use_log1p = config.load_predictor_log1p
self.window_size = config.prophet_window_size self.window_size = config.prophet_window_size
...@@ -257,7 +259,7 @@ class ProphetPredictor(BasePredictor): ...@@ -257,7 +259,7 @@ class ProphetPredictor(BasePredictor):
self.data_buffer = [] # Override to store dicts instead of values self.data_buffer = [] # Override to store dicts instead of values
self._seen_nonzero_since_idle_reset = False self._seen_nonzero_since_idle_reset = False
def add_data_point(self, value): def add_data_point(self, value: float) -> None:
"""Add new data point to the buffer""" """Add new data point to the buffer"""
# Use proper datetime for Prophet # Use proper datetime for Prophet
timestamp = self.start_date + timedelta(seconds=self.curr_step * self.step_size) timestamp = self.start_date + timedelta(seconds=self.curr_step * self.step_size)
...@@ -279,14 +281,14 @@ class ProphetPredictor(BasePredictor): ...@@ -279,14 +281,14 @@ class ProphetPredictor(BasePredictor):
if len(self.data_buffer) > self.window_size: if len(self.data_buffer) > self.window_size:
self.data_buffer = self.data_buffer[-self.window_size :] self.data_buffer = self.data_buffer[-self.window_size :]
def get_last_value(self): def get_last_value(self) -> float:
"""Get the last value from the buffer""" """Get the last value from the buffer"""
if not self.data_buffer: if not self.data_buffer:
return 0 return 0
y = float(self.data_buffer[-1]["y"]) y = float(self.data_buffer[-1]["y"])
return max(0.0, math.expm1(y)) if self._use_log1p else y return max(0.0, math.expm1(y)) if self._use_log1p else y
def predict_next(self): def predict_next(self) -> float:
"""Predict the next value""" """Predict the next value"""
if len(self.data_buffer) < self.minimum_data_points: if len(self.data_buffer) < self.minimum_data_points:
return self.get_last_value() return self.get_last_value()
...@@ -322,7 +324,7 @@ class KalmanPredictor(BasePredictor): ...@@ -322,7 +324,7 @@ class KalmanPredictor(BasePredictor):
forecasting in bursty systems. forecasting in bursty systems.
""" """
def __init__(self, config: PlannerConfig): def __init__(self, config: PlannerConfig) -> None:
super().__init__(minimum_data_points=config.kalman_min_points) super().__init__(minimum_data_points=config.kalman_min_points)
self._use_log1p = config.load_predictor_log1p self._use_log1p = config.load_predictor_log1p
q_level = config.kalman_q_level q_level = config.kalman_q_level
...@@ -348,7 +350,7 @@ class KalmanPredictor(BasePredictor): ...@@ -348,7 +350,7 @@ class KalmanPredictor(BasePredictor):
self._has_cached_pred = False self._has_cached_pred = False
self._cached_pred: float = 0.0 self._cached_pred: float = 0.0
def add_data_point(self, value): def add_data_point(self, value: float) -> None:
prev_len = len(self.data_buffer) prev_len = len(self.data_buffer)
super().add_data_point(value) super().add_data_point(value)
if len(self.data_buffer) == prev_len: if len(self.data_buffer) == prev_len:
...@@ -367,7 +369,7 @@ class KalmanPredictor(BasePredictor): ...@@ -367,7 +369,7 @@ class KalmanPredictor(BasePredictor):
# Consumed this step; clear cached forecast for next interval. # Consumed this step; clear cached forecast for next interval.
self._has_cached_pred = False self._has_cached_pred = False
def predict_next(self): def predict_next(self) -> float:
if not self._initialized: if not self._initialized:
return self.get_last_value() return self.get_last_value()
if self._has_cached_pred: if self._has_cached_pred:
......
...@@ -248,7 +248,7 @@ class BasePlanner: ...@@ -248,7 +248,7 @@ class BasePlanner:
def __init__( def __init__(
self, self,
runtime: Optional[DistributedRuntime], runtime: DistributedRuntime,
config: PlannerConfig, config: PlannerConfig,
dryrun: bool = False, dryrun: bool = False,
shared_state: Optional[PlannerSharedState] = None, shared_state: Optional[PlannerSharedState] = None,
...@@ -389,6 +389,7 @@ class BasePlanner: ...@@ -389,6 +389,7 @@ class BasePlanner:
self.config.backend self.config.backend
].decode_worker_k8s_name ].decode_worker_k8s_name
self.prometheus_metrics: PlannerPrometheusMetrics | None = None
if not self.dryrun: if not self.dryrun:
self.prefill_client = None self.prefill_client = None
self.workers_client = None self.workers_client = None
...@@ -665,7 +666,7 @@ class BasePlanner: ...@@ -665,7 +666,7 @@ class BasePlanner:
self.isl_predictor.add_data_point(metrics.isl) self.isl_predictor.add_data_point(metrics.isl)
self.osl_predictor.add_data_point(metrics.osl) self.osl_predictor.add_data_point(metrics.osl)
def predict_load(self): def predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
try: try:
# predict the next load # predict the next load
next_num_req = self.num_req_predictor.predict_next() next_num_req = self.num_req_predictor.predict_next()
...@@ -948,7 +949,7 @@ class BasePlanner: ...@@ -948,7 +949,7 @@ class BasePlanner:
logger.info(f"Detected model name from deployment: {model_name}") logger.info(f"Detected model name from deployment: {model_name}")
self.model_name = model_name.lower() self.model_name = model_name.lower()
else: else:
model_name = getattr(self.config, "model_name", None) model_name = getattr(self.config, "model_name", "")
if not model_name: if not model_name:
raise ValueError( raise ValueError(
"Model name is required in no-operation mode. " "Model name is required in no-operation mode. "
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment