Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
63e7b7da
Unverified
Commit
63e7b7da
authored
Mar 04, 2026
by
jh-nv
Committed by
GitHub
Mar 04, 2026
Browse files
chore: add mypy to planner and mocker (#6862)
parent
6216ae55
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
51 additions
and
38 deletions
+51
-38
components/src/dynamo/mocker/args.py
components/src/dynamo/mocker/args.py
+3
-3
components/src/dynamo/mocker/main.py
components/src/dynamo/mocker/main.py
+8
-5
components/src/dynamo/mocker/utils/kv_cache.py
components/src/dynamo/mocker/utils/kv_cache.py
+2
-1
components/src/dynamo/planner/planner_connector.py
components/src/dynamo/planner/planner_connector.py
+8
-2
components/src/dynamo/planner/remote_planner_client.py
components/src/dynamo/planner/remote_planner_client.py
+4
-2
components/src/dynamo/planner/utils/agg_planner.py
components/src/dynamo/planner/utils/agg_planner.py
+1
-3
components/src/dynamo/planner/utils/load_predictor.py
components/src/dynamo/planner/utils/load_predictor.py
+21
-19
components/src/dynamo/planner/utils/planner_core.py
components/src/dynamo/planner/utils/planner_core.py
+4
-3
No files found.
components/src/dynamo/mocker/args.py
View file @
63e7b7da
...
@@ -90,7 +90,7 @@ def resolve_planner_profile_data(
...
@@ -90,7 +90,7 @@ def resolve_planner_profile_data(
)
)
def
create_temp_engine_args_file
(
args
)
->
Path
:
def
create_temp_engine_args_file
(
args
:
argparse
.
Namespace
)
->
Path
:
"""
"""
Create a temporary JSON file with MockEngineArgs from CLI arguments.
Create a temporary JSON file with MockEngineArgs from CLI arguments.
Returns the path to the temporary file.
Returns the path to the temporary file.
...
@@ -146,7 +146,7 @@ def create_temp_engine_args_file(args) -> Path:
...
@@ -146,7 +146,7 @@ def create_temp_engine_args_file(args) -> Path:
return
temp_path
return
temp_path
def
validate_worker_type_args
(
args
)
:
def
validate_worker_type_args
(
args
:
argparse
.
Namespace
)
->
None
:
"""
"""
Resolve disaggregation mode from --disaggregation-mode or legacy boolean flags.
Resolve disaggregation mode from --disaggregation-mode or legacy boolean flags.
Raises ValueError if validation fails.
Raises ValueError if validation fails.
...
@@ -199,7 +199,7 @@ def parse_bootstrap_ports(ports_str: str | None) -> list[int]:
...
@@ -199,7 +199,7 @@ def parse_bootstrap_ports(ports_str: str | None) -> list[int]:
return
[
int
(
p
.
strip
())
for
p
in
ports_str
.
split
(
","
)]
return
[
int
(
p
.
strip
())
for
p
in
ports_str
.
split
(
","
)]
def
parse_args
():
def
parse_args
()
->
argparse
.
Namespace
:
"""Parse command-line arguments for the Dynamo mocker engine.
"""Parse command-line arguments for the Dynamo mocker engine.
Returns:
Returns:
...
...
components/src/dynamo/mocker/main.py
View file @
63e7b7da
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
# Usage: `python -m dynamo.mocker --model-path /data/models/Qwen3-0.6B`
# Usage: `python -m dynamo.mocker --model-path /data/models/Qwen3-0.6B`
# Now supports vLLM-style individual arguments for MockEngineArgs
# Now supports vLLM-style individual arguments for MockEngineArgs
import
argparse
import
asyncio
import
asyncio
import
json
import
json
import
logging
import
logging
...
@@ -135,7 +136,7 @@ def compute_stagger_delay(num_workers: int, stagger_delay: float) -> float:
...
@@ -135,7 +136,7 @@ def compute_stagger_delay(num_workers: int, stagger_delay: float) -> float:
return
0.2
return
0.2
async
def
launch_workers
(
args
,
extra_engine_args_path
):
async
def
launch_workers
(
args
:
argparse
.
Namespace
,
extra_engine_args_path
:
Path
):
"""Launch mocker worker(s) with isolated DistributedRuntime instances.
"""Launch mocker worker(s) with isolated DistributedRuntime instances.
Each worker gets its own DistributedRuntime, which means:
Each worker gets its own DistributedRuntime, which means:
...
@@ -185,7 +186,9 @@ async def launch_workers(args, extra_engine_args_path):
...
@@ -185,7 +186,9 @@ async def launch_workers(args, extra_engine_args_path):
runtimes
.
append
(
runtime
)
runtimes
.
append
(
runtime
)
# Determine which engine args file to use
# Determine which engine args file to use
worker_engine_args_path
:
Path
|
str
if
needs_per_worker_args
:
if
needs_per_worker_args
:
assert
base_engine_args
is
not
None
worker_args
=
base_engine_args
.
copy
()
worker_args
=
base_engine_args
.
copy
()
if
args
.
bootstrap_ports_list
:
if
args
.
bootstrap_ports_list
:
worker_args
[
"bootstrap_port"
]
=
args
.
bootstrap_ports_list
[
worker_id
]
worker_args
[
"bootstrap_port"
]
=
args
.
bootstrap_ports_list
[
worker_id
]
...
@@ -195,9 +198,9 @@ async def launch_workers(args, extra_engine_args_path):
...
@@ -195,9 +198,9 @@ async def launch_workers(args, extra_engine_args_path):
]
]
with
tempfile
.
NamedTemporaryFile
(
with
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
suffix
=
".json"
,
delete
=
False
mode
=
"w"
,
suffix
=
".json"
,
delete
=
False
)
as
f
:
)
as
tmp
:
json
.
dump
(
worker_args
,
f
)
json
.
dump
(
worker_args
,
tmp
)
worker_engine_args_path
=
Path
(
f
.
name
)
worker_engine_args_path
=
Path
(
tmp
.
name
)
per_worker_temp_files
.
append
(
worker_engine_args_path
)
per_worker_temp_files
.
append
(
worker_engine_args_path
)
logger
.
debug
(
f
"Worker
{
worker_id
}
: per-worker args
{
worker_args
}
"
)
logger
.
debug
(
f
"Worker
{
worker_id
}
: per-worker args
{
worker_args
}
"
)
else
:
else
:
...
@@ -209,7 +212,7 @@ async def launch_workers(args, extra_engine_args_path):
...
@@ -209,7 +212,7 @@ async def launch_workers(args, extra_engine_args_path):
model_path
=
args
.
model_path
,
model_path
=
args
.
model_path
,
model_name
=
args
.
model_name
,
model_name
=
args
.
model_name
,
endpoint_id
=
args
.
endpoint
,
endpoint_id
=
args
.
endpoint
,
extra_engine_args
=
worker_engine_args_path
,
extra_engine_args
=
str
(
worker_engine_args_path
)
,
is_prefill
=
args
.
is_prefill_worker
,
is_prefill
=
args
.
is_prefill_worker
,
)
)
...
...
components/src/dynamo/mocker/utils/kv_cache.py
View file @
63e7b7da
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
logging
import
logging
from
typing
import
Any
from
transformers
import
AutoConfig
from
transformers
import
AutoConfig
...
@@ -43,7 +44,7 @@ def _normalize_dtype_str(dtype) -> str:
...
@@ -43,7 +44,7 @@ def _normalize_dtype_str(dtype) -> str:
return
s
return
s
def
get_kv_cache_dtype_bytes
(
config
,
kv_cache_dtype
:
str
=
"auto"
)
->
int
:
def
get_kv_cache_dtype_bytes
(
config
:
Any
,
kv_cache_dtype
:
str
=
"auto"
)
->
int
:
"""Get the byte size per element for KV cache based on dtype.
"""Get the byte size per element for KV cache based on dtype.
When kv_cache_dtype is "auto", uses the model's dtype from config.
When kv_cache_dtype is "auto", uses the model's dtype from config.
...
...
components/src/dynamo/planner/planner_connector.py
View file @
63e7b7da
...
@@ -15,15 +15,21 @@
...
@@ -15,15 +15,21 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dynamo.planner.defaults
import
SubComponentType
# TODO: add ability to scale component to X replicas
# TODO: add ability to scale component to X replicas
class
PlannerConnector
(
ABC
):
class
PlannerConnector
(
ABC
):
@
abstractmethod
@
abstractmethod
async
def
add_component
(
self
,
component_name
):
async
def
add_component
(
self
,
sub_component_type
:
SubComponentType
,
blocking
:
bool
=
True
)
->
None
:
"""Add a component to the planner"""
"""Add a component to the planner"""
pass
pass
@
abstractmethod
@
abstractmethod
async
def
remove_component
(
self
,
component_name
):
async
def
remove_component
(
self
,
sub_component_type
:
SubComponentType
,
blocking
:
bool
=
True
)
->
None
:
"""Remove a component from the planner"""
"""Remove a component from the planner"""
pass
pass
components/src/dynamo/planner/remote_planner_client.py
View file @
63e7b7da
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
import
asyncio
import
asyncio
import
logging
import
logging
from
dynamo._core
import
Client
from
dynamo.planner.defaults
import
SubComponentType
from
dynamo.planner.defaults
import
SubComponentType
from
dynamo.planner.scale_protocol
import
ScaleRequest
,
ScaleResponse
from
dynamo.planner.scale_protocol
import
ScaleRequest
,
ScaleResponse
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime
import
DistributedRuntime
...
@@ -29,7 +30,7 @@ class RemotePlannerClient:
...
@@ -29,7 +30,7 @@ class RemotePlannerClient:
self
.
central_component
=
central_component
self
.
central_component
=
central_component
self
.
connection_timeout
=
connection_timeout
self
.
connection_timeout
=
connection_timeout
self
.
max_retries
=
max_retries
self
.
max_retries
=
max_retries
self
.
_client
=
None
self
.
_client
:
Client
|
None
=
None
async
def
_ensure_client
(
self
):
async
def
_ensure_client
(
self
):
"""Lazy initialization of endpoint client with retry mechanism"""
"""Lazy initialization of endpoint client with retry mechanism"""
...
@@ -39,7 +40,7 @@ class RemotePlannerClient:
...
@@ -39,7 +40,7 @@ class RemotePlannerClient:
)
)
# Retry logic with exponential backoff
# Retry logic with exponential backoff
last_error
=
None
last_error
:
Exception
|
None
=
None
for
attempt
in
range
(
self
.
max_retries
):
for
attempt
in
range
(
self
.
max_retries
):
try
:
try
:
logger
.
info
(
logger
.
info
(
...
@@ -101,6 +102,7 @@ class RemotePlannerClient:
...
@@ -101,6 +102,7 @@ class RemotePlannerClient:
# Send request via the runtime client's generate method (the correct API for
# Send request via the runtime client's generate method (the correct API for
# calling any dynamo endpoint, regardless of its registered name)
# calling any dynamo endpoint, regardless of its registered name)
request_json
=
request
.
model_dump_json
()
request_json
=
request
.
model_dump_json
()
assert
self
.
_client
is
not
None
stream
=
await
self
.
_client
.
generate
(
request_json
)
stream
=
await
self
.
_client
.
generate
(
request_json
)
response_data
=
None
response_data
=
None
...
...
components/src/dynamo/planner/utils/agg_planner.py
View file @
63e7b7da
...
@@ -40,9 +40,7 @@ class AggPlanner:
...
@@ -40,9 +40,7 @@ class AggPlanner:
# Engine metrics from agg workers are labeled "decode" by the router
# Engine metrics from agg workers are labeled "decode" by the router
ENGINE_WORKER_TYPE
=
"decode"
ENGINE_WORKER_TYPE
=
"decode"
def
__init__
(
def
__init__
(
self
,
runtime
:
DistributedRuntime
,
config
:
PlannerConfig
)
->
None
:
self
,
runtime
:
Optional
[
DistributedRuntime
],
config
:
PlannerConfig
)
->
None
:
self
.
config
=
config
self
.
config
=
config
self
.
shared_state
=
PlannerSharedState
()
self
.
shared_state
=
PlannerSharedState
()
...
...
components/src/dynamo/planner/utils/load_predictor.py
View file @
63e7b7da
...
@@ -19,6 +19,7 @@ import warnings
...
@@ -19,6 +19,7 @@ import warnings
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
datetime
import
datetime
,
timedelta
from
datetime
import
datetime
,
timedelta
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Any
import
numpy
as
np
import
numpy
as
np
import
pandas
as
pd
import
pandas
as
pd
...
@@ -55,19 +56,19 @@ for _name in (
...
@@ -55,19 +56,19 @@ for _name in (
class
BasePredictor
(
ABC
):
class
BasePredictor
(
ABC
):
"""Base class for all load predictors"""
"""Base class for all load predictors"""
def
__init__
(
self
,
minimum_data_points
=
5
)
:
def
__init__
(
self
,
minimum_data_points
:
int
=
5
)
->
None
:
self
.
minimum_data_points
=
minimum_data_points
self
.
minimum_data_points
=
minimum_data_points
self
.
data_buffer
=
[]
self
.
data_buffer
:
list
[
Any
]
=
[]
# Even if we preload historical data, we still want to ignore the initial
# Even if we preload historical data, we still want to ignore the initial
# post-deployment idle period (a run of zeros) until we see the first
# post-deployment idle period (a run of zeros) until we see the first
# non-zero datapoint from live traffic.
# non-zero datapoint from live traffic.
self
.
_seen_nonzero_since_idle_reset
=
False
self
.
_seen_nonzero_since_idle_reset
=
False
def
reset_idle_skip
(
self
):
def
reset_idle_skip
(
self
)
->
None
:
"""Reset idle-period skipping state (e.g., after warmup, before live)."""
"""Reset idle-period skipping state (e.g., after warmup, before live)."""
self
.
_seen_nonzero_since_idle_reset
=
False
self
.
_seen_nonzero_since_idle_reset
=
False
def
add_data_point
(
self
,
value
)
:
def
add_data_point
(
self
,
value
:
float
)
->
None
:
"""Add new data point to the buffer"""
"""Add new data point to the buffer"""
if
math
.
isnan
(
value
):
if
math
.
isnan
(
value
):
value
=
0
value
=
0
...
@@ -82,14 +83,14 @@ class BasePredictor(ABC):
...
@@ -82,14 +83,14 @@ class BasePredictor(ABC):
self
.
data_buffer
.
append
(
value
)
self
.
data_buffer
.
append
(
value
)
def
get_last_value
(
self
):
def
get_last_value
(
self
)
->
float
:
"""Get the last value from the buffer"""
"""Get the last value from the buffer"""
if
not
self
.
data_buffer
:
if
not
self
.
data_buffer
:
return
0
return
0
return
self
.
data_buffer
[
-
1
]
return
self
.
data_buffer
[
-
1
]
@
abstractmethod
@
abstractmethod
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
"""Predict the next value"""
"""Predict the next value"""
pass
pass
...
@@ -99,10 +100,10 @@ class ConstantPredictor(BasePredictor):
...
@@ -99,10 +100,10 @@ class ConstantPredictor(BasePredictor):
Assume load is constant and predict the next load to be the same as most recent load
Assume load is constant and predict the next load to be the same as most recent load
"""
"""
def
__init__
(
self
,
_config
:
PlannerConfig
):
def
__init__
(
self
,
_config
:
PlannerConfig
)
->
None
:
super
().
__init__
(
minimum_data_points
=
1
)
super
().
__init__
(
minimum_data_points
=
1
)
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
return
self
.
get_last_value
()
return
self
.
get_last_value
()
...
@@ -112,7 +113,7 @@ class ARIMAPredictor(BasePredictor):
...
@@ -112,7 +113,7 @@ class ARIMAPredictor(BasePredictor):
RAW
=
"raw"
RAW
=
"raw"
LOG1P
=
"log1p"
LOG1P
=
"log1p"
def
__init__
(
self
,
config
:
PlannerConfig
):
def
__init__
(
self
,
config
:
PlannerConfig
)
->
None
:
super
().
__init__
(
minimum_data_points
=
5
)
super
().
__init__
(
minimum_data_points
=
5
)
self
.
model
=
None
self
.
model
=
None
# Keep raw values so we can fit in raw space first, then fallback to log1p space.
# Keep raw values so we can fit in raw space first, then fallback to log1p space.
...
@@ -125,7 +126,7 @@ class ARIMAPredictor(BasePredictor):
...
@@ -125,7 +126,7 @@ class ARIMAPredictor(BasePredictor):
)
)
self
.
_mode
:
ARIMAPredictor
.
Mode
=
self
.
_requested_mode
self
.
_mode
:
ARIMAPredictor
.
Mode
=
self
.
_requested_mode
def
get_last_value
(
self
):
def
get_last_value
(
self
)
->
float
:
"""Return last value in original scale."""
"""Return last value in original scale."""
if
self
.
_raw_buffer
:
if
self
.
_raw_buffer
:
return
float
(
self
.
_raw_buffer
[
-
1
])
return
float
(
self
.
_raw_buffer
[
-
1
])
...
@@ -133,7 +134,7 @@ class ARIMAPredictor(BasePredictor):
...
@@ -133,7 +134,7 @@ class ARIMAPredictor(BasePredictor):
return
0
return
0
return
float
(
self
.
data_buffer
[
-
1
])
return
float
(
self
.
data_buffer
[
-
1
])
def
add_data_point
(
self
,
value
)
:
def
add_data_point
(
self
,
value
:
float
)
->
None
:
prev_len
=
len
(
self
.
data_buffer
)
prev_len
=
len
(
self
.
data_buffer
)
# Use raw value for idle skipping in BasePredictor. We may transform later.
# Use raw value for idle skipping in BasePredictor. We may transform later.
super
().
add_data_point
(
value
)
super
().
add_data_point
(
value
)
...
@@ -145,7 +146,7 @@ class ARIMAPredictor(BasePredictor):
...
@@ -145,7 +146,7 @@ class ARIMAPredictor(BasePredictor):
if
self
.
_mode
==
ARIMAPredictor
.
Mode
.
LOG1P
:
if
self
.
_mode
==
ARIMAPredictor
.
Mode
.
LOG1P
:
self
.
data_buffer
[
-
1
]
=
math
.
log1p
(
raw
)
self
.
data_buffer
[
-
1
]
=
math
.
log1p
(
raw
)
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
"""Predict the next value(s)"""
"""Predict the next value(s)"""
if
len
(
self
.
_raw_buffer
)
<
self
.
minimum_data_points
:
if
len
(
self
.
_raw_buffer
)
<
self
.
minimum_data_points
:
return
self
.
get_last_value
()
return
self
.
get_last_value
()
...
@@ -234,6 +235,7 @@ class ARIMAPredictor(BasePredictor):
...
@@ -234,6 +235,7 @@ class ARIMAPredictor(BasePredictor):
self
.
_pending_raw_updates
=
[]
self
.
_pending_raw_updates
=
[]
# Make prediction
# Make prediction
assert
self
.
model
is
not
None
forecast
=
float
(
self
.
model
.
predict
(
n_periods
=
1
)[
0
])
forecast
=
float
(
self
.
model
.
predict
(
n_periods
=
1
)[
0
])
if
self
.
_mode
==
ARIMAPredictor
.
Mode
.
LOG1P
:
if
self
.
_mode
==
ARIMAPredictor
.
Mode
.
LOG1P
:
return
max
(
0.0
,
math
.
expm1
(
forecast
))
return
max
(
0.0
,
math
.
expm1
(
forecast
))
...
@@ -247,7 +249,7 @@ class ARIMAPredictor(BasePredictor):
...
@@ -247,7 +249,7 @@ class ARIMAPredictor(BasePredictor):
# Time-series forecasting model from Meta
# Time-series forecasting model from Meta
class
ProphetPredictor
(
BasePredictor
):
class
ProphetPredictor
(
BasePredictor
):
def
__init__
(
self
,
config
:
PlannerConfig
):
def
__init__
(
self
,
config
:
PlannerConfig
)
->
None
:
super
().
__init__
(
minimum_data_points
=
5
)
super
().
__init__
(
minimum_data_points
=
5
)
self
.
_use_log1p
=
config
.
load_predictor_log1p
self
.
_use_log1p
=
config
.
load_predictor_log1p
self
.
window_size
=
config
.
prophet_window_size
self
.
window_size
=
config
.
prophet_window_size
...
@@ -257,7 +259,7 @@ class ProphetPredictor(BasePredictor):
...
@@ -257,7 +259,7 @@ class ProphetPredictor(BasePredictor):
self
.
data_buffer
=
[]
# Override to store dicts instead of values
self
.
data_buffer
=
[]
# Override to store dicts instead of values
self
.
_seen_nonzero_since_idle_reset
=
False
self
.
_seen_nonzero_since_idle_reset
=
False
def
add_data_point
(
self
,
value
)
:
def
add_data_point
(
self
,
value
:
float
)
->
None
:
"""Add new data point to the buffer"""
"""Add new data point to the buffer"""
# Use proper datetime for Prophet
# Use proper datetime for Prophet
timestamp
=
self
.
start_date
+
timedelta
(
seconds
=
self
.
curr_step
*
self
.
step_size
)
timestamp
=
self
.
start_date
+
timedelta
(
seconds
=
self
.
curr_step
*
self
.
step_size
)
...
@@ -279,14 +281,14 @@ class ProphetPredictor(BasePredictor):
...
@@ -279,14 +281,14 @@ class ProphetPredictor(BasePredictor):
if
len
(
self
.
data_buffer
)
>
self
.
window_size
:
if
len
(
self
.
data_buffer
)
>
self
.
window_size
:
self
.
data_buffer
=
self
.
data_buffer
[
-
self
.
window_size
:]
self
.
data_buffer
=
self
.
data_buffer
[
-
self
.
window_size
:]
def
get_last_value
(
self
):
def
get_last_value
(
self
)
->
float
:
"""Get the last value from the buffer"""
"""Get the last value from the buffer"""
if
not
self
.
data_buffer
:
if
not
self
.
data_buffer
:
return
0
return
0
y
=
float
(
self
.
data_buffer
[
-
1
][
"y"
])
y
=
float
(
self
.
data_buffer
[
-
1
][
"y"
])
return
max
(
0.0
,
math
.
expm1
(
y
))
if
self
.
_use_log1p
else
y
return
max
(
0.0
,
math
.
expm1
(
y
))
if
self
.
_use_log1p
else
y
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
"""Predict the next value"""
"""Predict the next value"""
if
len
(
self
.
data_buffer
)
<
self
.
minimum_data_points
:
if
len
(
self
.
data_buffer
)
<
self
.
minimum_data_points
:
return
self
.
get_last_value
()
return
self
.
get_last_value
()
...
@@ -322,7 +324,7 @@ class KalmanPredictor(BasePredictor):
...
@@ -322,7 +324,7 @@ class KalmanPredictor(BasePredictor):
forecasting in bursty systems.
forecasting in bursty systems.
"""
"""
def
__init__
(
self
,
config
:
PlannerConfig
):
def
__init__
(
self
,
config
:
PlannerConfig
)
->
None
:
super
().
__init__
(
minimum_data_points
=
config
.
kalman_min_points
)
super
().
__init__
(
minimum_data_points
=
config
.
kalman_min_points
)
self
.
_use_log1p
=
config
.
load_predictor_log1p
self
.
_use_log1p
=
config
.
load_predictor_log1p
q_level
=
config
.
kalman_q_level
q_level
=
config
.
kalman_q_level
...
@@ -348,7 +350,7 @@ class KalmanPredictor(BasePredictor):
...
@@ -348,7 +350,7 @@ class KalmanPredictor(BasePredictor):
self
.
_has_cached_pred
=
False
self
.
_has_cached_pred
=
False
self
.
_cached_pred
:
float
=
0.0
self
.
_cached_pred
:
float
=
0.0
def
add_data_point
(
self
,
value
)
:
def
add_data_point
(
self
,
value
:
float
)
->
None
:
prev_len
=
len
(
self
.
data_buffer
)
prev_len
=
len
(
self
.
data_buffer
)
super
().
add_data_point
(
value
)
super
().
add_data_point
(
value
)
if
len
(
self
.
data_buffer
)
==
prev_len
:
if
len
(
self
.
data_buffer
)
==
prev_len
:
...
@@ -367,7 +369,7 @@ class KalmanPredictor(BasePredictor):
...
@@ -367,7 +369,7 @@ class KalmanPredictor(BasePredictor):
# Consumed this step; clear cached forecast for next interval.
# Consumed this step; clear cached forecast for next interval.
self
.
_has_cached_pred
=
False
self
.
_has_cached_pred
=
False
def
predict_next
(
self
):
def
predict_next
(
self
)
->
float
:
if
not
self
.
_initialized
:
if
not
self
.
_initialized
:
return
self
.
get_last_value
()
return
self
.
get_last_value
()
if
self
.
_has_cached_pred
:
if
self
.
_has_cached_pred
:
...
...
components/src/dynamo/planner/utils/planner_core.py
View file @
63e7b7da
...
@@ -248,7 +248,7 @@ class BasePlanner:
...
@@ -248,7 +248,7 @@ class BasePlanner:
def
__init__
(
def
__init__
(
self
,
self
,
runtime
:
Optional
[
DistributedRuntime
]
,
runtime
:
DistributedRuntime
,
config
:
PlannerConfig
,
config
:
PlannerConfig
,
dryrun
:
bool
=
False
,
dryrun
:
bool
=
False
,
shared_state
:
Optional
[
PlannerSharedState
]
=
None
,
shared_state
:
Optional
[
PlannerSharedState
]
=
None
,
...
@@ -389,6 +389,7 @@ class BasePlanner:
...
@@ -389,6 +389,7 @@ class BasePlanner:
self
.
config
.
backend
self
.
config
.
backend
].
decode_worker_k8s_name
].
decode_worker_k8s_name
self
.
prometheus_metrics
:
PlannerPrometheusMetrics
|
None
=
None
if
not
self
.
dryrun
:
if
not
self
.
dryrun
:
self
.
prefill_client
=
None
self
.
prefill_client
=
None
self
.
workers_client
=
None
self
.
workers_client
=
None
...
@@ -665,7 +666,7 @@ class BasePlanner:
...
@@ -665,7 +666,7 @@ class BasePlanner:
self
.
isl_predictor
.
add_data_point
(
metrics
.
isl
)
self
.
isl_predictor
.
add_data_point
(
metrics
.
isl
)
self
.
osl_predictor
.
add_data_point
(
metrics
.
osl
)
self
.
osl_predictor
.
add_data_point
(
metrics
.
osl
)
def
predict_load
(
self
):
def
predict_load
(
self
)
->
tuple
[
Optional
[
float
],
Optional
[
float
],
Optional
[
float
]]
:
try
:
try
:
# predict the next load
# predict the next load
next_num_req
=
self
.
num_req_predictor
.
predict_next
()
next_num_req
=
self
.
num_req_predictor
.
predict_next
()
...
@@ -948,7 +949,7 @@ class BasePlanner:
...
@@ -948,7 +949,7 @@ class BasePlanner:
logger
.
info
(
f
"Detected model name from deployment:
{
model_name
}
"
)
logger
.
info
(
f
"Detected model name from deployment:
{
model_name
}
"
)
self
.
model_name
=
model_name
.
lower
()
self
.
model_name
=
model_name
.
lower
()
else
:
else
:
model_name
=
getattr
(
self
.
config
,
"model_name"
,
None
)
model_name
=
getattr
(
self
.
config
,
"model_name"
,
""
)
if
not
model_name
:
if
not
model_name
:
raise
ValueError
(
raise
ValueError
(
"Model name is required in no-operation mode. "
"Model name is required in no-operation mode. "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment