Unverified Commit 63e7b7da authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

chore: add mypy to planner and mocker (#6862)

parent 6216ae55
......@@ -90,7 +90,7 @@ def resolve_planner_profile_data(
)
def create_temp_engine_args_file(args) -> Path:
def create_temp_engine_args_file(args: argparse.Namespace) -> Path:
"""
Create a temporary JSON file with MockEngineArgs from CLI arguments.
Returns the path to the temporary file.
......@@ -146,7 +146,7 @@ def create_temp_engine_args_file(args) -> Path:
return temp_path
def validate_worker_type_args(args):
def validate_worker_type_args(args: argparse.Namespace) -> None:
"""
Resolve disaggregation mode from --disaggregation-mode or legacy boolean flags.
Raises ValueError if validation fails.
......@@ -199,7 +199,7 @@ def parse_bootstrap_ports(ports_str: str | None) -> list[int]:
return [int(p.strip()) for p in ports_str.split(",")]
def parse_args():
def parse_args() -> argparse.Namespace:
"""Parse command-line arguments for the Dynamo mocker engine.
Returns:
......
......@@ -4,6 +4,7 @@
# Usage: `python -m dynamo.mocker --model-path /data/models/Qwen3-0.6B`
# Now supports vLLM-style individual arguments for MockEngineArgs
import argparse
import asyncio
import json
import logging
......@@ -135,7 +136,7 @@ def compute_stagger_delay(num_workers: int, stagger_delay: float) -> float:
return 0.2
async def launch_workers(args, extra_engine_args_path):
async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path):
"""Launch mocker worker(s) with isolated DistributedRuntime instances.
Each worker gets its own DistributedRuntime, which means:
......@@ -185,7 +186,9 @@ async def launch_workers(args, extra_engine_args_path):
runtimes.append(runtime)
# Determine which engine args file to use
worker_engine_args_path: Path | str
if needs_per_worker_args:
assert base_engine_args is not None
worker_args = base_engine_args.copy()
if args.bootstrap_ports_list:
worker_args["bootstrap_port"] = args.bootstrap_ports_list[worker_id]
......@@ -195,9 +198,9 @@ async def launch_workers(args, extra_engine_args_path):
]
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
json.dump(worker_args, f)
worker_engine_args_path = Path(f.name)
) as tmp:
json.dump(worker_args, tmp)
worker_engine_args_path = Path(tmp.name)
per_worker_temp_files.append(worker_engine_args_path)
logger.debug(f"Worker {worker_id}: per-worker args {worker_args}")
else:
......@@ -209,7 +212,7 @@ async def launch_workers(args, extra_engine_args_path):
model_path=args.model_path,
model_name=args.model_name,
endpoint_id=args.endpoint,
extra_engine_args=worker_engine_args_path,
extra_engine_args=str(worker_engine_args_path),
is_prefill=args.is_prefill_worker,
)
......
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any
from transformers import AutoConfig
......@@ -43,7 +44,7 @@ def _normalize_dtype_str(dtype) -> str:
return s
def get_kv_cache_dtype_bytes(config, kv_cache_dtype: str = "auto") -> int:
def get_kv_cache_dtype_bytes(config: Any, kv_cache_dtype: str = "auto") -> int:
"""Get the byte size per element for KV cache based on dtype.
When kv_cache_dtype is "auto", uses the model's dtype from config.
......
......@@ -15,15 +15,21 @@
from abc import ABC, abstractmethod
from dynamo.planner.defaults import SubComponentType
# TODO: add ability to scale component to X replicas
class PlannerConnector(ABC):
@abstractmethod
async def add_component(self, component_name):
async def add_component(
self, sub_component_type: SubComponentType, blocking: bool = True
) -> None:
"""Add a component to the planner"""
pass
@abstractmethod
async def remove_component(self, component_name):
async def remove_component(
self, sub_component_type: SubComponentType, blocking: bool = True
) -> None:
"""Remove a component from the planner"""
pass
......@@ -6,6 +6,7 @@
import asyncio
import logging
from dynamo._core import Client
from dynamo.planner.defaults import SubComponentType
from dynamo.planner.scale_protocol import ScaleRequest, ScaleResponse
from dynamo.runtime import DistributedRuntime
......@@ -29,7 +30,7 @@ class RemotePlannerClient:
self.central_component = central_component
self.connection_timeout = connection_timeout
self.max_retries = max_retries
self._client = None
self._client: Client | None = None
async def _ensure_client(self):
"""Lazy initialization of endpoint client with retry mechanism"""
......@@ -39,7 +40,7 @@ class RemotePlannerClient:
)
# Retry logic with exponential backoff
last_error = None
last_error: Exception | None = None
for attempt in range(self.max_retries):
try:
logger.info(
......@@ -101,6 +102,7 @@ class RemotePlannerClient:
# Send request via the runtime client's generate method (the correct API for
# calling any dynamo endpoint, regardless of its registered name)
request_json = request.model_dump_json()
assert self._client is not None
stream = await self._client.generate(request_json)
response_data = None
......
......@@ -40,9 +40,7 @@ class AggPlanner:
# Engine metrics from agg workers are labeled "decode" by the router
ENGINE_WORKER_TYPE = "decode"
def __init__(
self, runtime: Optional[DistributedRuntime], config: PlannerConfig
) -> None:
def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
self.config = config
self.shared_state = PlannerSharedState()
......
......@@ -19,6 +19,7 @@ import warnings
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from enum import Enum
from typing import Any
import numpy as np
import pandas as pd
......@@ -55,19 +56,19 @@ for _name in (
class BasePredictor(ABC):
"""Base class for all load predictors"""
def __init__(self, minimum_data_points=5):
def __init__(self, minimum_data_points: int = 5) -> None:
self.minimum_data_points = minimum_data_points
self.data_buffer = []
self.data_buffer: list[Any] = []
# Even if we preload historical data, we still want to ignore the initial
# post-deployment idle period (a run of zeros) until we see the first
# non-zero datapoint from live traffic.
self._seen_nonzero_since_idle_reset = False
def reset_idle_skip(self):
def reset_idle_skip(self) -> None:
"""Reset idle-period skipping state (e.g., after warmup, before live)."""
self._seen_nonzero_since_idle_reset = False
def add_data_point(self, value):
def add_data_point(self, value: float) -> None:
"""Add new data point to the buffer"""
if math.isnan(value):
value = 0
......@@ -82,14 +83,14 @@ class BasePredictor(ABC):
self.data_buffer.append(value)
def get_last_value(self):
def get_last_value(self) -> float:
"""Get the last value from the buffer"""
if not self.data_buffer:
return 0
return self.data_buffer[-1]
@abstractmethod
def predict_next(self):
def predict_next(self) -> float:
"""Predict the next value"""
pass
......@@ -99,10 +100,10 @@ class ConstantPredictor(BasePredictor):
Assume load is constant and predict the next load to be the same as most recent load
"""
def __init__(self, _config: PlannerConfig):
def __init__(self, _config: PlannerConfig) -> None:
super().__init__(minimum_data_points=1)
def predict_next(self):
def predict_next(self) -> float:
return self.get_last_value()
......@@ -112,7 +113,7 @@ class ARIMAPredictor(BasePredictor):
RAW = "raw"
LOG1P = "log1p"
def __init__(self, config: PlannerConfig):
def __init__(self, config: PlannerConfig) -> None:
super().__init__(minimum_data_points=5)
self.model = None
# Keep raw values so we can fit in raw space first, then fallback to log1p space.
......@@ -125,7 +126,7 @@ class ARIMAPredictor(BasePredictor):
)
self._mode: ARIMAPredictor.Mode = self._requested_mode
def get_last_value(self):
def get_last_value(self) -> float:
"""Return last value in original scale."""
if self._raw_buffer:
return float(self._raw_buffer[-1])
......@@ -133,7 +134,7 @@ class ARIMAPredictor(BasePredictor):
return 0
return float(self.data_buffer[-1])
def add_data_point(self, value):
def add_data_point(self, value: float) -> None:
prev_len = len(self.data_buffer)
# Use raw value for idle skipping in BasePredictor. We may transform later.
super().add_data_point(value)
......@@ -145,7 +146,7 @@ class ARIMAPredictor(BasePredictor):
if self._mode == ARIMAPredictor.Mode.LOG1P:
self.data_buffer[-1] = math.log1p(raw)
def predict_next(self):
def predict_next(self) -> float:
"""Predict the next value(s)"""
if len(self._raw_buffer) < self.minimum_data_points:
return self.get_last_value()
......@@ -234,6 +235,7 @@ class ARIMAPredictor(BasePredictor):
self._pending_raw_updates = []
# Make prediction
assert self.model is not None
forecast = float(self.model.predict(n_periods=1)[0])
if self._mode == ARIMAPredictor.Mode.LOG1P:
return max(0.0, math.expm1(forecast))
......@@ -247,7 +249,7 @@ class ARIMAPredictor(BasePredictor):
# Time-series forecasting model from Meta
class ProphetPredictor(BasePredictor):
def __init__(self, config: PlannerConfig):
def __init__(self, config: PlannerConfig) -> None:
super().__init__(minimum_data_points=5)
self._use_log1p = config.load_predictor_log1p
self.window_size = config.prophet_window_size
......@@ -257,7 +259,7 @@ class ProphetPredictor(BasePredictor):
self.data_buffer = [] # Override to store dicts instead of values
self._seen_nonzero_since_idle_reset = False
def add_data_point(self, value):
def add_data_point(self, value: float) -> None:
"""Add new data point to the buffer"""
# Use proper datetime for Prophet
timestamp = self.start_date + timedelta(seconds=self.curr_step * self.step_size)
......@@ -279,14 +281,14 @@ class ProphetPredictor(BasePredictor):
if len(self.data_buffer) > self.window_size:
self.data_buffer = self.data_buffer[-self.window_size :]
def get_last_value(self):
def get_last_value(self) -> float:
"""Get the last value from the buffer"""
if not self.data_buffer:
return 0
y = float(self.data_buffer[-1]["y"])
return max(0.0, math.expm1(y)) if self._use_log1p else y
def predict_next(self):
def predict_next(self) -> float:
"""Predict the next value"""
if len(self.data_buffer) < self.minimum_data_points:
return self.get_last_value()
......@@ -322,7 +324,7 @@ class KalmanPredictor(BasePredictor):
forecasting in bursty systems.
"""
def __init__(self, config: PlannerConfig):
def __init__(self, config: PlannerConfig) -> None:
super().__init__(minimum_data_points=config.kalman_min_points)
self._use_log1p = config.load_predictor_log1p
q_level = config.kalman_q_level
......@@ -348,7 +350,7 @@ class KalmanPredictor(BasePredictor):
self._has_cached_pred = False
self._cached_pred: float = 0.0
def add_data_point(self, value):
def add_data_point(self, value: float) -> None:
prev_len = len(self.data_buffer)
super().add_data_point(value)
if len(self.data_buffer) == prev_len:
......@@ -367,7 +369,7 @@ class KalmanPredictor(BasePredictor):
# Consumed this step; clear cached forecast for next interval.
self._has_cached_pred = False
def predict_next(self):
def predict_next(self) -> float:
if not self._initialized:
return self.get_last_value()
if self._has_cached_pred:
......
......@@ -248,7 +248,7 @@ class BasePlanner:
def __init__(
self,
runtime: Optional[DistributedRuntime],
runtime: DistributedRuntime,
config: PlannerConfig,
dryrun: bool = False,
shared_state: Optional[PlannerSharedState] = None,
......@@ -389,6 +389,7 @@ class BasePlanner:
self.config.backend
].decode_worker_k8s_name
self.prometheus_metrics: PlannerPrometheusMetrics | None = None
if not self.dryrun:
self.prefill_client = None
self.workers_client = None
......@@ -665,7 +666,7 @@ class BasePlanner:
self.isl_predictor.add_data_point(metrics.isl)
self.osl_predictor.add_data_point(metrics.osl)
def predict_load(self):
def predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
try:
# predict the next load
next_num_req = self.num_req_predictor.predict_next()
......@@ -948,7 +949,7 @@ class BasePlanner:
logger.info(f"Detected model name from deployment: {model_name}")
self.model_name = model_name.lower()
else:
model_name = getattr(self.config, "model_name", None)
model_name = getattr(self.config, "model_name", "")
if not model_name:
raise ValueError(
"Model name is required in no-operation mode. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment