Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
63e7b7da
Unverified
Commit
63e7b7da
authored
Mar 04, 2026
by
jh-nv
Committed by
GitHub
Mar 04, 2026
Browse files
chore: add mypy to planner and mocker (#6862)
parent
6216ae55
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
51 additions
and
38 deletions
+51
-38
components/src/dynamo/mocker/args.py
components/src/dynamo/mocker/args.py
+3
-3
components/src/dynamo/mocker/main.py
components/src/dynamo/mocker/main.py
+8
-5
components/src/dynamo/mocker/utils/kv_cache.py
components/src/dynamo/mocker/utils/kv_cache.py
+2
-1
components/src/dynamo/planner/planner_connector.py
components/src/dynamo/planner/planner_connector.py
+8
-2
components/src/dynamo/planner/remote_planner_client.py
components/src/dynamo/planner/remote_planner_client.py
+4
-2
components/src/dynamo/planner/utils/agg_planner.py
components/src/dynamo/planner/utils/agg_planner.py
+1
-3
components/src/dynamo/planner/utils/load_predictor.py
components/src/dynamo/planner/utils/load_predictor.py
+21
-19
components/src/dynamo/planner/utils/planner_core.py
components/src/dynamo/planner/utils/planner_core.py
+4
-3
No files found.
components/src/dynamo/mocker/args.py
View file @
63e7b7da
...
...
@@ -90,7 +90,7 @@ def resolve_planner_profile_data(
)
def
create_temp_engine_args_file
(
args
)
->
Path
:
def
create_temp_engine_args_file
(
args
:
argparse
.
Namespace
)
->
Path
:
"""
Create a temporary JSON file with MockEngineArgs from CLI arguments.
Returns the path to the temporary file.
...
...
@@ -146,7 +146,7 @@ def create_temp_engine_args_file(args) -> Path:
return
temp_path
def
validate_worker_type_args
(
args
)
:
def
validate_worker_type_args
(
args
:
argparse
.
Namespace
)
->
None
:
"""
Resolve disaggregation mode from --disaggregation-mode or legacy boolean flags.
Raises ValueError if validation fails.
...
...
@@ -199,7 +199,7 @@ def parse_bootstrap_ports(ports_str: str | None) -> list[int]:
return
[
int
(
p
.
strip
())
for
p
in
ports_str
.
split
(
","
)]
def
parse_args
():
def
parse_args
()
->
argparse
.
Namespace
:
"""Parse command-line arguments for the Dynamo mocker engine.
Returns:
...
...
components/src/dynamo/mocker/main.py
View file @
63e7b7da
...
...
@@ -4,6 +4,7 @@
# Usage: `python -m dynamo.mocker --model-path /data/models/Qwen3-0.6B`
# Now supports vLLM-style individual arguments for MockEngineArgs
import
argparse
import
asyncio
import
json
import
logging
...
...
@@ -135,7 +136,7 @@ def compute_stagger_delay(num_workers: int, stagger_delay: float) -> float:
return
0.2
async
def
launch_workers
(
args
,
extra_engine_args_path
):
async
def
launch_workers
(
args
:
argparse
.
Namespace
,
extra_engine_args_path
:
Path
):
"""Launch mocker worker(s) with isolated DistributedRuntime instances.
Each worker gets its own DistributedRuntime, which means:
...
...
@@ -185,7 +186,9 @@ async def launch_workers(args, extra_engine_args_path):
runtimes
.
append
(
runtime
)
# Determine which engine args file to use
worker_engine_args_path
:
Path
|
str
if
needs_per_worker_args
:
assert
base_engine_args
is
not
None
worker_args
=
base_engine_args
.
copy
()
if
args
.
bootstrap_ports_list
:
worker_args
[
"bootstrap_port"
]
=
args
.
bootstrap_ports_list
[
worker_id
]
...
...
@@ -195,9 +198,9 @@ async def launch_workers(args, extra_engine_args_path):
]
with
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
suffix
=
".json"
,
delete
=
False
)
as
f
:
json
.
dump
(
worker_args
,
f
)
worker_engine_args_path
=
Path
(
f
.
name
)
)
as
tmp
:
json
.
dump
(
worker_args
,
tmp
)
worker_engine_args_path
=
Path
(
tmp
.
name
)
per_worker_temp_files
.
append
(
worker_engine_args_path
)
logger
.
debug
(
f
"Worker
{
worker_id
}
: per-worker args
{
worker_args
}
"
)
else
:
...
...
@@ -209,7 +212,7 @@ async def launch_workers(args, extra_engine_args_path):
model_path
=
args
.
model_path
,
model_name
=
args
.
model_name
,
endpoint_id
=
args
.
endpoint
,
extra_engine_args
=
worker_engine_args_path
,
extra_engine_args
=
str
(
worker_engine_args_path
)
,
is_prefill
=
args
.
is_prefill_worker
,
)
...
...
components/src/dynamo/mocker/utils/kv_cache.py
View file @
63e7b7da
...
...
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import
logging
from
typing
import
Any
from
transformers
import
AutoConfig
...
...
@@ -43,7 +44,7 @@ def _normalize_dtype_str(dtype) -> str:
return
s
def
get_kv_cache_dtype_bytes
(
config
,
kv_cache_dtype
:
str
=
"auto"
)
->
int
:
def
get_kv_cache_dtype_bytes
(
config
:
Any
,
kv_cache_dtype
:
str
=
"auto"
)
->
int
:
"""Get the byte size per element for KV cache based on dtype.
When kv_cache_dtype is "auto", uses the model's dtype from config.
...
...
components/src/dynamo/planner/planner_connector.py
View file @
63e7b7da
...
...
@@ -15,15 +15,21 @@
from
abc
import
ABC
,
abstractmethod
from
dynamo.planner.defaults
import
SubComponentType
# TODO: add ability to scale component to X replicas
class
PlannerConnector
(
ABC
):
@
abstractmethod
async
def
add_component
(
self
,
component_name
):
async
def
add_component
(
self
,
sub_component_type
:
SubComponentType
,
blocking
:
bool
=
True
)
->
None
:
"""Add a component to the planner"""
pass
@
abstractmethod
async
def
remove_component
(
self
,
component_name
):
async
def
remove_component
(
self
,
sub_component_type
:
SubComponentType
,
blocking
:
bool
=
True
)
->
None
:
"""Remove a component from the planner"""
pass
components/src/dynamo/planner/remote_planner_client.py
View file @
63e7b7da
...
...
@@ -6,6 +6,7 @@
import
asyncio
import
logging
from
dynamo._core
import
Client
from
dynamo.planner.defaults
import
SubComponentType
from
dynamo.planner.scale_protocol
import
ScaleRequest
,
ScaleResponse
from
dynamo.runtime
import
DistributedRuntime
...
...
@@ -29,7 +30,7 @@ class RemotePlannerClient:
self
.
central_component
=
central_component
self
.
connection_timeout
=
connection_timeout
self
.
max_retries
=
max_retries
self
.
_client
=
None
self
.
_client
:
Client
|
None
=
None
async
def
_ensure_client
(
self
):
"""Lazy initialization of endpoint client with retry mechanism"""
...
...
@@ -39,7 +40,7 @@ class RemotePlannerClient:
)
# Retry logic with exponential backoff
last_error
=
None
last_error
:
Exception
|
None
=
None
for
attempt
in
range
(
self
.
max_retries
):
try
:
logger
.
info
(
...
...
@@ -101,6 +102,7 @@ class RemotePlannerClient:
# Send request via the runtime client's generate method (the correct API for
# calling any dynamo endpoint, regardless of its registered name)
request_json
=
request
.
model_dump_json
()
assert
self
.
_client
is
not
None
stream
=
await
self
.
_client
.
generate
(
request_json
)
response_data
=
None
...
...
components/src/dynamo/planner/utils/agg_planner.py
View file @
63e7b7da
...
...
@@ -40,9 +40,7 @@ class AggPlanner:
# Engine metrics from agg workers are labeled "decode" by the router
ENGINE_WORKER_TYPE
=
"decode"
def
__init__
(
self
,
runtime
:
Optional
[
DistributedRuntime
],
config
:
PlannerConfig
)
->
None
:
def
__init__
(
self
,
runtime
:
DistributedRuntime
,
config
:
PlannerConfig
)
->
None
:
self
.
config
=
config
self
.
shared_state
=
PlannerSharedState
()
...
...
components/src/dynamo/planner/utils/load_predictor.py
View file @
63e7b7da
...
...
@@ -19,6 +19,7 @@ import warnings
from
abc
import
ABC
,
abstractmethod
from
datetime
import
datetime
,
timedelta
from
enum
import
Enum
from
typing
import
Any
import
numpy
as
np
import
pandas
as
pd
...
...
@@ -55,19 +56,19 @@ for _name in (
class
BasePredictor
(
ABC
):
"""Base class for all load predictors"""
def
__init__
(
self
,
minimum_data_points
=
5
)
:
def
__init__
(
self
,
minimum_data_points
:
int
=
5
)
->
None
:
self
.
minimum_data_points
=
minimum_data_points
self
.
data_buffer
=
[]
self
.
data_buffer
:
list
[
Any
]
=
[]
# Even if we preload historical data, we still want to ignore the initial
# post-deployment idle period (a run of zeros) until we see the first
# non-zero datapoint from live traffic.
self
.
_seen_nonzero_since_idle_reset
=
False
def
reset_idle_skip
(
self
):
def
reset_idle_skip
(
self
)
->
None
:
"""Reset idle-period skipping state (e.g., after warmup, before live)."""
self
.
_seen_nonzero_since_idle_reset
=
False
def
add_data_point
(
self
,
value
)
:
def
add_data_point
(
self
,
value
:
float
)
->
None
:
"""Add new data point to the buffer"""
if
math
.
isnan
(
value
):
value
=
0
...
...
@@ -82,14 +83,14 @@ class BasePredictor(ABC):
self
.
data_buffer
.
append
(
value
)
def
get_last_value
(
self
):
def
get_last_value
(
self
)
->
float
:
"""Get the last value from the buffer"""
if
not
self
.
data_buffer
:
return
0
return
self
.
data_buffer
[
-
1
]
@
abstractmethod
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
"""Predict the next value"""
pass
...
...
@@ -99,10 +100,10 @@ class ConstantPredictor(BasePredictor):
Assume load is constant and predict the next load to be the same as most recent load
"""
def
__init__
(
self
,
_config
:
PlannerConfig
):
def
__init__
(
self
,
_config
:
PlannerConfig
)
->
None
:
super
().
__init__
(
minimum_data_points
=
1
)
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
return
self
.
get_last_value
()
...
...
@@ -112,7 +113,7 @@ class ARIMAPredictor(BasePredictor):
RAW
=
"raw"
LOG1P
=
"log1p"
def
__init__
(
self
,
config
:
PlannerConfig
):
def
__init__
(
self
,
config
:
PlannerConfig
)
->
None
:
super
().
__init__
(
minimum_data_points
=
5
)
self
.
model
=
None
# Keep raw values so we can fit in raw space first, then fallback to log1p space.
...
...
@@ -125,7 +126,7 @@ class ARIMAPredictor(BasePredictor):
)
self
.
_mode
:
ARIMAPredictor
.
Mode
=
self
.
_requested_mode
def
get_last_value
(
self
):
def
get_last_value
(
self
)
->
float
:
"""Return last value in original scale."""
if
self
.
_raw_buffer
:
return
float
(
self
.
_raw_buffer
[
-
1
])
...
...
@@ -133,7 +134,7 @@ class ARIMAPredictor(BasePredictor):
return
0
return
float
(
self
.
data_buffer
[
-
1
])
def
add_data_point
(
self
,
value
)
:
def
add_data_point
(
self
,
value
:
float
)
->
None
:
prev_len
=
len
(
self
.
data_buffer
)
# Use raw value for idle skipping in BasePredictor. We may transform later.
super
().
add_data_point
(
value
)
...
...
@@ -145,7 +146,7 @@ class ARIMAPredictor(BasePredictor):
if
self
.
_mode
==
ARIMAPredictor
.
Mode
.
LOG1P
:
self
.
data_buffer
[
-
1
]
=
math
.
log1p
(
raw
)
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
"""Predict the next value(s)"""
if
len
(
self
.
_raw_buffer
)
<
self
.
minimum_data_points
:
return
self
.
get_last_value
()
...
...
@@ -234,6 +235,7 @@ class ARIMAPredictor(BasePredictor):
self
.
_pending_raw_updates
=
[]
# Make prediction
assert
self
.
model
is
not
None
forecast
=
float
(
self
.
model
.
predict
(
n_periods
=
1
)[
0
])
if
self
.
_mode
==
ARIMAPredictor
.
Mode
.
LOG1P
:
return
max
(
0.0
,
math
.
expm1
(
forecast
))
...
...
@@ -247,7 +249,7 @@ class ARIMAPredictor(BasePredictor):
# Time-series forecasting model from Meta
class
ProphetPredictor
(
BasePredictor
):
def
__init__
(
self
,
config
:
PlannerConfig
):
def
__init__
(
self
,
config
:
PlannerConfig
)
->
None
:
super
().
__init__
(
minimum_data_points
=
5
)
self
.
_use_log1p
=
config
.
load_predictor_log1p
self
.
window_size
=
config
.
prophet_window_size
...
...
@@ -257,7 +259,7 @@ class ProphetPredictor(BasePredictor):
self
.
data_buffer
=
[]
# Override to store dicts instead of values
self
.
_seen_nonzero_since_idle_reset
=
False
def
add_data_point
(
self
,
value
)
:
def
add_data_point
(
self
,
value
:
float
)
->
None
:
"""Add new data point to the buffer"""
# Use proper datetime for Prophet
timestamp
=
self
.
start_date
+
timedelta
(
seconds
=
self
.
curr_step
*
self
.
step_size
)
...
...
@@ -279,14 +281,14 @@ class ProphetPredictor(BasePredictor):
if
len
(
self
.
data_buffer
)
>
self
.
window_size
:
self
.
data_buffer
=
self
.
data_buffer
[
-
self
.
window_size
:]
def
get_last_value
(
self
):
def
get_last_value
(
self
)
->
float
:
"""Get the last value from the buffer"""
if
not
self
.
data_buffer
:
return
0
y
=
float
(
self
.
data_buffer
[
-
1
][
"y"
])
return
max
(
0.0
,
math
.
expm1
(
y
))
if
self
.
_use_log1p
else
y
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
"""Predict the next value"""
if
len
(
self
.
data_buffer
)
<
self
.
minimum_data_points
:
return
self
.
get_last_value
()
...
...
@@ -322,7 +324,7 @@ class KalmanPredictor(BasePredictor):
forecasting in bursty systems.
"""
def
__init__
(
self
,
config
:
PlannerConfig
):
def
__init__
(
self
,
config
:
PlannerConfig
)
->
None
:
super
().
__init__
(
minimum_data_points
=
config
.
kalman_min_points
)
self
.
_use_log1p
=
config
.
load_predictor_log1p
q_level
=
config
.
kalman_q_level
...
...
@@ -348,7 +350,7 @@ class KalmanPredictor(BasePredictor):
self
.
_has_cached_pred
=
False
self
.
_cached_pred
:
float
=
0.0
def
add_data_point
(
self
,
value
)
:
def
add_data_point
(
self
,
value
:
float
)
->
None
:
prev_len
=
len
(
self
.
data_buffer
)
super
().
add_data_point
(
value
)
if
len
(
self
.
data_buffer
)
==
prev_len
:
...
...
@@ -367,7 +369,7 @@ class KalmanPredictor(BasePredictor):
# Consumed this step; clear cached forecast for next interval.
self
.
_has_cached_pred
=
False
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
if
not
self
.
_initialized
:
return
self
.
get_last_value
()
if
self
.
_has_cached_pred
:
...
...
components/src/dynamo/planner/utils/planner_core.py
View file @
63e7b7da
...
...
@@ -248,7 +248,7 @@ class BasePlanner:
def
__init__
(
self
,
runtime
:
Optional
[
DistributedRuntime
]
,
runtime
:
DistributedRuntime
,
config
:
PlannerConfig
,
dryrun
:
bool
=
False
,
shared_state
:
Optional
[
PlannerSharedState
]
=
None
,
...
...
@@ -389,6 +389,7 @@ class BasePlanner:
self
.
config
.
backend
].
decode_worker_k8s_name
self
.
prometheus_metrics
:
PlannerPrometheusMetrics
|
None
=
None
if
not
self
.
dryrun
:
self
.
prefill_client
=
None
self
.
workers_client
=
None
...
...
@@ -665,7 +666,7 @@ class BasePlanner:
self
.
isl_predictor
.
add_data_point
(
metrics
.
isl
)
self
.
osl_predictor
.
add_data_point
(
metrics
.
osl
)
def
predict_load
(
self
):
def
predict_load
(
self
)
->
tuple
[
Optional
[
float
],
Optional
[
float
],
Optional
[
float
]]
:
try
:
# predict the next load
next_num_req
=
self
.
num_req_predictor
.
predict_next
()
...
...
@@ -948,7 +949,7 @@ class BasePlanner:
logger
.
info
(
f
"Detected model name from deployment:
{
model_name
}
"
)
self
.
model_name
=
model_name
.
lower
()
else
:
model_name
=
getattr
(
self
.
config
,
"model_name"
,
None
)
model_name
=
getattr
(
self
.
config
,
"model_name"
,
""
)
if
not
model_name
:
raise
ValueError
(
"Model name is required in no-operation mode. "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment