Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f7e0b3fd
Unverified
Commit
f7e0b3fd
authored
Apr 09, 2026
by
Hongkuan Zhou
Committed by
GitHub
Apr 09, 2026
Browse files
refactor(planner): extract discrete-event state machine with explicit inputs (#8046)
parent
39a6a240
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
2086 additions
and
2807 deletions
+2086
-2807
components/src/dynamo/planner/__main__.py
components/src/dynamo/planner/__main__.py
+6
-4
components/src/dynamo/planner/core/__init__.py
components/src/dynamo/planner/core/__init__.py
+28
-0
components/src/dynamo/planner/core/adapters.py
components/src/dynamo/planner/core/adapters.py
+192
-0
components/src/dynamo/planner/core/agg.py
components/src/dynamo/planner/core/agg.py
+0
-406
components/src/dynamo/planner/core/base.py
components/src/dynamo/planner/core/base.py
+348
-644
components/src/dynamo/planner/core/decode.py
components/src/dynamo/planner/core/decode.py
+0
-97
components/src/dynamo/planner/core/disagg.py
components/src/dynamo/planner/core/disagg.py
+0
-259
components/src/dynamo/planner/core/load_scaling.py
components/src/dynamo/planner/core/load_scaling.py
+313
-0
components/src/dynamo/planner/core/prefill.py
components/src/dynamo/planner/core/prefill.py
+0
-104
components/src/dynamo/planner/core/state_machine.py
components/src/dynamo/planner/core/state_machine.py
+357
-0
components/src/dynamo/planner/core/throughput_scaling.py
components/src/dynamo/planner/core/throughput_scaling.py
+184
-0
components/src/dynamo/planner/core/types.py
components/src/dynamo/planner/core/types.py
+125
-0
components/src/dynamo/planner/tests/unit/test_load_based_scaling.py
.../src/dynamo/planner/tests/unit/test_load_based_scaling.py
+22
-299
components/src/dynamo/planner/tests/unit/test_replica_calculation.py
...src/dynamo/planner/tests/unit/test_replica_calculation.py
+0
-627
components/src/dynamo/planner/tests/unit/test_sla_planner_scaling.py
...src/dynamo/planner/tests/unit/test_sla_planner_scaling.py
+0
-367
components/src/dynamo/planner/tests/unit/test_state_machine.py
...nents/src/dynamo/planner/tests/unit/test_state_machine.py
+511
-0
No files found.
components/src/dynamo/planner/__main__.py
View file @
f7e0b3fd
...
...
@@ -21,10 +21,12 @@ from typing import Union
from
pydantic
import
BaseModel
from
dynamo.planner.config.planner_config
import
PlannerConfig
from
dynamo.planner.core.agg
import
AggPlanner
from
dynamo.planner.core.decode
import
DecodePlanner
from
dynamo.planner.core.disagg
import
DisaggPlanner
from
dynamo.planner.core.prefill
import
PrefillPlanner
from
dynamo.planner.core.adapters
import
(
AggPlanner
,
DecodePlanner
,
DisaggPlanner
,
PrefillPlanner
,
)
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
logger
=
logging
.
getLogger
(
__name__
)
...
...
components/src/dynamo/planner/core/__init__.py
View file @
f7e0b3fd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
dynamo.planner.core.state_machine
import
PlannerStateMachine
from
dynamo.planner.core.types
import
(
EngineCapabilities
,
FpmObservations
,
PlannerEffects
,
ScalingDecision
,
ScheduledTick
,
TickInput
,
TrafficObservation
,
WorkerCapabilities
,
WorkerCounts
,
)
__all__
=
[
"EngineCapabilities"
,
"FpmObservations"
,
"PlannerEffects"
,
"PlannerStateMachine"
,
"ScalingDecision"
,
"ScheduledTick"
,
"TickInput"
,
"TrafficObservation"
,
"WorkerCapabilities"
,
"WorkerCounts"
,
]
components/src/dynamo/planner/core/adapters.py
0 → 100644
View file @
f7e0b3fd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Native planner adapter subclasses (one per mode).
Each subclass sets ``require_prefill`` / ``require_decode`` and overrides
``_bootstrap_regression()`` and ``_apply_effects()``. Everything else
(connector, Prometheus, FPM subscribers, tick loop) is in ``NativePlannerBase``.
"""
import
logging
from
dynamo.planner.config.defaults
import
SubComponentType
,
TargetReplica
from
dynamo.planner.core.base
import
NativePlannerBase
from
dynamo.planner.core.types
import
PlannerEffects
from
dynamo.planner.monitoring.perf_metrics
import
fetch_pre_deployment_metrics
logger
=
logging
.
getLogger
(
__name__
)
class PrefillPlanner(NativePlannerBase):
    """Planner adapter for deployments that only scale prefill workers.

    Declares that only the prefill component is required, feeds prefill
    benchmark data into the state machine and forwards ``num_prefill``
    decisions to ``_apply_scaling_targets``.
    """

    require_prefill = True
    require_decode = False

    async def _bootstrap_regression(self) -> None:
        """Load pre-deployment prefill FPMs into the state machine.

        Raises:
            Exception: whatever the fetch/load step raised, but only when
                ``config.enable_throughput_scaling`` is set. Otherwise the
                failure is logged as a warning and swallowed.
        """
        try:
            benchmark_fpms = await fetch_pre_deployment_metrics(
                runtime=self.runtime,
                namespace=self.namespace,
                worker_info=self.prefill_worker_info,
                profile_results_dir=self.config.profile_results_dir,
                component_type=SubComponentType.PREFILL,
            )
            self.state_machine.load_benchmark_fpms(prefill_fpms=benchmark_fpms)
        except Exception as e:
            if not self.config.enable_throughput_scaling:
                # Best effort: without throughput scaling the benchmark
                # data is optional, so only report the problem.
                logger.warning(f"No pre-deployment data for prefill: {e}")
                return
            raise

    async def _apply_effects(self, effects: PlannerEffects) -> None:
        """Apply the prefill replica count carried by ``effects``, if any.

        Does nothing when there is no scaling decision or the decision has
        no ``num_prefill`` value.
        """
        decision = effects.scale_to
        if decision is None or decision.num_prefill is None:
            return

        replicas = decision.num_prefill
        # A reporting port of 0 means the gauge is not published.
        if self.prometheus_port != 0:
            self.prometheus_metrics.predicted_num_p.set(replicas)

        prefill_target = TargetReplica(
            sub_component_type=SubComponentType.PREFILL,
            component_name=self.prefill_worker_info.k8s_name,
            desired_replicas=replicas,
        )
        await self._apply_scaling_targets([prefill_target])
class DecodePlanner(NativePlannerBase):
    """Planner adapter for deployments that only scale decode workers.

    Declares that only the decode component is required, feeds decode
    benchmark data into the state machine and forwards ``num_decode``
    decisions to ``_apply_scaling_targets``.
    """

    require_prefill = False
    require_decode = True

    async def _bootstrap_regression(self) -> None:
        """Load pre-deployment decode FPMs into the state machine.

        Raises:
            Exception: whatever the fetch/load step raised, but only when
                ``config.enable_throughput_scaling`` is set. Otherwise the
                failure is logged as a warning and swallowed.
        """
        try:
            benchmark_fpms = await fetch_pre_deployment_metrics(
                runtime=self.runtime,
                namespace=self.namespace,
                worker_info=self.decode_worker_info,
                profile_results_dir=self.config.profile_results_dir,
                component_type=SubComponentType.DECODE,
            )
            self.state_machine.load_benchmark_fpms(decode_fpms=benchmark_fpms)
        except Exception as e:
            if not self.config.enable_throughput_scaling:
                # Benchmark data is optional in this configuration.
                logger.warning(f"No pre-deployment data for decode: {e}")
                return
            raise

    async def _apply_effects(self, effects: PlannerEffects) -> None:
        """Apply the decode replica count carried by ``effects``, if any.

        Does nothing when there is no scaling decision or the decision has
        no ``num_decode`` value.
        """
        decision = effects.scale_to
        if decision is None or decision.num_decode is None:
            return

        replicas = decision.num_decode
        # A reporting port of 0 means the gauge is not published.
        if self.prometheus_port != 0:
            self.prometheus_metrics.predicted_num_d.set(replicas)

        decode_target = TargetReplica(
            sub_component_type=SubComponentType.DECODE,
            component_name=self.decode_worker_info.k8s_name,
            desired_replicas=replicas,
        )
        await self._apply_scaling_targets([decode_target])
class AggPlanner(NativePlannerBase):
    """Planner adapter for aggregated mode.

    A single engine type handles both prefill and decode. The adapter uses
    the decode worker info and the ``DECODE`` sub-component for fetching
    benchmark data and for scaling, and hands the benchmark data to the
    state machine as ``agg_fpms``.
    """

    require_prefill = False
    require_decode = True

    async def _bootstrap_regression(self) -> None:
        """Load pre-deployment FPMs into the state machine as ``agg_fpms``.

        Raises:
            Exception: whatever the fetch/load step raised, but only when
                ``config.enable_throughput_scaling`` is set. Otherwise the
                failure is logged as a warning and swallowed.
        """
        try:
            benchmark_fpms = await fetch_pre_deployment_metrics(
                runtime=self.runtime,
                namespace=self.namespace,
                worker_info=self.decode_worker_info,
                profile_results_dir=self.config.profile_results_dir,
                component_type=SubComponentType.DECODE,
            )
            self.state_machine.load_benchmark_fpms(agg_fpms=benchmark_fpms)
        except Exception as e:
            if not self.config.enable_throughput_scaling:
                # Benchmark data is optional in this configuration.
                logger.warning(f"No pre-deployment data for agg: {e}")
                return
            raise

    async def _apply_effects(self, effects: PlannerEffects) -> None:
        """Apply the replica count carried in ``scale_to.num_decode``, if any.

        Does nothing when there is no scaling decision or the decision has
        no ``num_decode`` value.
        """
        decision = effects.scale_to
        if decision is None or decision.num_decode is None:
            return

        replicas = decision.num_decode
        # A reporting port of 0 means the gauge is not published.
        if self.prometheus_port != 0:
            self.prometheus_metrics.predicted_num_d.set(replicas)

        engine_target = TargetReplica(
            sub_component_type=SubComponentType.DECODE,
            component_name=self.decode_worker_info.k8s_name,
            desired_replicas=replicas,
        )
        await self._apply_scaling_targets([engine_target])
class DisaggPlanner(NativePlannerBase):
    """Planner adapter for disaggregated mode.

    Prefill and decode run on separate engines; both components are
    required, bootstrapped and scaled independently.
    """

    require_prefill = True
    require_decode = True

    async def _bootstrap_regression(self) -> None:
        """Load pre-deployment FPMs for prefill, then decode.

        Each component is fetched and loaded on its own, so a failure for
        one does not prevent the other from being attempted when failures
        are tolerated.

        Raises:
            Exception: whatever the fetch/load step raised, but only when
                ``config.enable_throughput_scaling`` is set. Otherwise the
                failure is logged as a warning and the loop continues.
        """
        for component, kwarg in [
            (SubComponentType.PREFILL, "prefill_fpms"),
            (SubComponentType.DECODE, "decode_fpms"),
        ]:
            worker_info = (
                self.prefill_worker_info
                if component == SubComponentType.PREFILL
                else self.decode_worker_info
            )
            try:
                fpms = await fetch_pre_deployment_metrics(
                    runtime=self.runtime,
                    namespace=self.namespace,
                    worker_info=worker_info,
                    profile_results_dir=self.config.profile_results_dir,
                    component_type=component,
                )
                # The keyword name selects which model receives the data.
                self.state_machine.load_benchmark_fpms(**{kwarg: fpms})
            except Exception as e:
                if self.config.enable_throughput_scaling:
                    raise
                logger.warning(
                    f"No pre-deployment data for {component.value}: {e}"
                )

    async def _apply_effects(self, effects: PlannerEffects) -> None:
        """Apply whichever prefill/decode replica counts ``effects`` carries.

        Each count that is not ``None`` is published to its gauge (when the
        reporting port is non-zero) and turned into a ``TargetReplica``.
        If the decision carries neither count, nothing is sent to
        ``_apply_scaling_targets``; this matches the single-component
        adapters, which return early when their count is ``None``.
        """
        decision = effects.scale_to
        if decision is None:
            return

        # A reporting port of 0 means the gauges are not published.
        publish_gauges = self.prometheus_port != 0
        targets = []

        if decision.num_prefill is not None:
            if publish_gauges:
                self.prometheus_metrics.predicted_num_p.set(decision.num_prefill)
            targets.append(
                TargetReplica(
                    sub_component_type=SubComponentType.PREFILL,
                    component_name=self.prefill_worker_info.k8s_name,
                    desired_replicas=decision.num_prefill,
                )
            )

        if decision.num_decode is not None:
            if publish_gauges:
                self.prometheus_metrics.predicted_num_d.set(decision.num_decode)
            targets.append(
                TargetReplica(
                    sub_component_type=SubComponentType.DECODE,
                    component_name=self.decode_worker_info.k8s_name,
                    desired_replicas=decision.num_decode,
                )
            )

        if not targets:
            # Empty decision: avoid issuing a scaling request with no targets.
            return
        await self._apply_scaling_targets(targets)
components/src/dynamo/planner/core/agg.py
deleted
100644 → 0
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
logging
import
math
import
time
from
typing
import
TYPE_CHECKING
,
Optional
from
dynamo.planner.config.backend_components
import
WORKER_COMPONENT_NAMES
from
dynamo.planner.config.defaults
import
SubComponentType
,
TargetReplica
from
dynamo.planner.config.planner_config
import
PlannerConfig
from
dynamo.planner.core.base
import
BasePlanner
from
dynamo.planner.core.budget
import
(
_apply_component_gpu_budget
,
_initialize_gpu_counts
,
)
from
dynamo.planner.core.perf_model
import
AggRegressionModel
from
dynamo.planner.core.state
import
PlannerSharedState
from
dynamo.planner.monitoring.perf_metrics
import
fetch_pre_deployment_metrics
from
dynamo.planner.monitoring.planner_metrics
import
PlannerPrometheusMetrics
from
dynamo.runtime
import
DistributedRuntime
if
TYPE_CHECKING
:
from
dynamo.common.forward_pass_metrics
import
ForwardPassMetrics
from
dynamo.runtime.logging
import
configure_dynamo_logging
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
class
AggPlanner
:
"""Aggregated planner: FPM-driven scaling for single engine type.
In aggregated mode, engines handle both prefill and decode (chunked prefill).
A single AggRegressionModel maps (sum_prefill_tokens, sum_decode_kv_tokens)
to wall_time using 2D linear regression.
Supports load-only, throughput-only, or both scaling modes.
Scaling logic (load-based):
- Estimate next TTFT per engine by simulating prefill chunking with
piggybacked decode (steady-state decode load).
- Estimate next ITL per engine by predicting decode iteration time with
average piggybacked prefill load.
- Scale up if (ALL TTFT > SLA) OR (ALL ITL > SLA).
- Scale down if (ALL TTFT < SLA * sensitivity) AND (ALL ITL < SLA * sensitivity).
Scaling logic (throughput-based):
- Use compute_agg_replicas() to find minimum replicas where both SLAs
are met under predicted traffic load.
"""
def
__init__
(
self
,
runtime
:
DistributedRuntime
,
config
:
PlannerConfig
)
->
None
:
self
.
config
=
config
self
.
runtime
=
runtime
self
.
shared_state
=
PlannerSharedState
()
self
.
enable_throughput
=
config
.
enable_throughput_scaling
self
.
enable_load
=
config
.
enable_load_scaling
if
not
self
.
enable_throughput
and
not
self
.
enable_load
:
raise
ValueError
(
"Aggregated planner requires at least one scaling mode enabled."
)
prometheus_metrics
=
PlannerPrometheusMetrics
()
self
.
planner
=
BasePlanner
(
runtime
,
config
,
shared_state
=
self
.
shared_state
,
prometheus_metrics
=
prometheus_metrics
,
start_prometheus_server
=
True
,
component_type
=
SubComponentType
.
DECODE
,
)
self
.
regression
=
AggRegressionModel
(
max_num_fpm_samples
=
config
.
max_num_fpm_samples
,
min_observations
=
config
.
load_min_observations
,
bucket_count
=
config
.
fpm_sample_bucket_size
,
)
async
def
_async_init
(
self
):
defaults
=
WORKER_COMPONENT_NAMES
.
get
(
self
.
config
.
backend
)
if
not
self
.
config
.
no_operation
:
connector
=
getattr
(
self
.
planner
,
"connector"
,
None
)
if
connector
and
hasattr
(
connector
,
"_async_init"
):
await
connector
.
_async_init
()
logger
.
info
(
"Validating deployment..."
)
await
self
.
planner
.
connector
.
validate_deployment
(
prefill_component_name
=
None
,
decode_component_name
=
(
defaults
.
decode_worker_k8s_name
if
defaults
else
None
),
require_prefill
=
False
,
require_decode
=
True
,
)
logger
.
info
(
"Successfully validated the deployment"
)
_initialize_gpu_counts
(
self
.
config
,
self
.
planner
.
connector
,
require_prefill
=
False
,
require_decode
=
True
,
)
await
self
.
planner
.
connector
.
wait_for_deployment_ready
(
include_planner
=
False
)
await
self
.
planner
.
_init_worker_info
(
require_prefill
=
False
,
require_decode
=
True
)
if
self
.
runtime
is
not
None
:
await
self
.
planner
.
_init_fpm_subscriber
()
await
self
.
_bootstrap_regression
()
async
def
_bootstrap_regression
(
self
)
->
None
:
"""Bootstrap agg regression from pre-deployment benchmark data."""
worker_info
=
self
.
planner
.
decode_worker_info
try
:
fpms
=
await
fetch_pre_deployment_metrics
(
runtime
=
self
.
runtime
,
namespace
=
self
.
config
.
namespace
,
worker_info
=
worker_info
,
profile_results_dir
=
self
.
config
.
profile_results_dir
,
component_type
=
SubComponentType
.
DECODE
,
)
self
.
regression
.
load_benchmark_fpms
(
fpms
)
logger
.
info
(
f
"Bootstrapped agg regression with
{
len
(
fpms
)
}
pre-deployment FPMs"
)
except
Exception
as
e
:
if
self
.
enable_throughput
:
raise
logger
.
warning
(
f
"No pre-deployment data for agg regression:
{
e
}
. "
"Load-based scaling will learn from live FPM only."
)
async
def
run
(
self
):
"""Main scaling loop. Call _async_init() before this."""
self
.
shared_state
.
last_adjustment_time
=
time
.
time
()
loops
=
[]
if
self
.
enable_throughput
:
loops
.
append
(
self
.
_throughput_loop
())
loops
.
append
(
self
.
_load_and_fpm_update_loop
())
await
asyncio
.
gather
(
*
loops
)
async
def
_throughput_loop
(
self
)
->
None
:
"""Throughput-based scaling loop for agg mode."""
while
True
:
current_time
=
time
.
time
()
if
(
current_time
-
self
.
shared_state
.
last_adjustment_time
>=
self
.
config
.
throughput_adjustment_interval
):
self
.
shared_state
.
last_adjustment_time
=
time
.
time
()
logger
.
info
(
"New agg throughput adjustment interval started!"
)
await
self
.
planner
.
observe_traffic_stats
(
require_prefill
=
False
,
require_decode
=
True
)
metrics
=
self
.
shared_state
.
last_metrics
if
not
metrics
.
is_valid
():
logger
.
info
(
"Metrics invalid, skipping agg throughput adjustment"
)
await
asyncio
.
sleep
(
self
.
config
.
throughput_adjustment_interval
/
10
)
continue
next_num_req
=
self
.
planner
.
num_req_predictor
.
predict_next
()
next_isl
=
self
.
planner
.
isl_predictor
.
predict_next
()
next_osl
=
self
.
planner
.
osl_predictor
.
predict_next
()
max_num_batched_tokens
=
getattr
(
self
.
planner
.
decode_worker_info
,
"max_num_batched_tokens"
,
None
)
if
not
max_num_batched_tokens
or
max_num_batched_tokens
<=
0
:
logger
.
warning
(
"max_num_batched_tokens not available, skipping agg throughput"
)
await
asyncio
.
sleep
(
self
.
config
.
throughput_adjustment_interval
/
10
)
continue
(
engine_rps
,
actual_ttft
,
actual_itl
,
)
=
self
.
regression
.
find_best_engine_agg_rps
(
isl
=
next_isl
,
osl
=
next_osl
,
max_num_batched_tokens
=
max_num_batched_tokens
,
ttft_sla
=
self
.
config
.
ttft
,
itl_sla
=
self
.
config
.
itl
,
)
if
engine_rps
<=
0
:
logger
.
warning
(
"Agg perf model not ready, skipping throughput scaling"
)
await
asyncio
.
sleep
(
self
.
config
.
throughput_adjustment_interval
/
10
)
continue
if
actual_ttft
>
self
.
config
.
ttft
or
actual_itl
>
self
.
config
.
itl
:
logger
.
warning
(
f
"Agg SLA not fully met: TTFT=
{
actual_ttft
:.
1
f
}
ms "
f
"(target
{
self
.
config
.
ttft
:.
1
f
}
ms), "
f
"ITL=
{
actual_itl
:.
1
f
}
ms (target
{
self
.
config
.
itl
:.
1
f
}
ms), "
"scaling with best achievable rate"
)
demand_rps
=
next_num_req
/
self
.
config
.
throughput_adjustment_interval
desired
=
math
.
ceil
(
demand_rps
/
engine_rps
)
desired
=
max
(
desired
,
self
.
config
.
min_endpoint
)
logger
.
info
(
f
"Agg:
{
demand_rps
:.
2
f
}
(demand rps) / "
f
"
{
engine_rps
:.
2
f
}
(engine rps) =
{
desired
}
(replicas), "
f
"est_ttft=
{
actual_ttft
:.
1
f
}
ms, est_itl=
{
actual_itl
:.
1
f
}
ms"
)
if
self
.
enable_load
:
self
.
shared_state
.
throughput_lower_bound_d
=
desired
logger
.
info
(
f
"Agg throughput lower bound set to
{
desired
}
"
)
else
:
assert
self
.
config
.
decode_engine_num_gpu
is
not
None
desired
=
_apply_component_gpu_budget
(
desired
,
self
.
config
.
decode_engine_num_gpu
,
self
.
config
)
if
(
self
.
planner
.
prometheus_port
!=
0
and
self
.
planner
.
prometheus_metrics
is
not
None
):
self
.
planner
.
prometheus_metrics
.
predicted_num_d
.
set
(
desired
)
if
not
self
.
config
.
no_operation
:
target_replicas
=
[
TargetReplica
(
sub_component_type
=
SubComponentType
.
DECODE
,
component_name
=
self
.
planner
.
decode_worker_info
.
k8s_name
,
desired_replicas
=
desired
,
)
]
await
self
.
planner
.
connector
.
set_component_replicas
(
target_replicas
,
blocking
=
False
)
await
asyncio
.
sleep
(
self
.
config
.
throughput_adjustment_interval
/
10
)
async
def
_load_and_fpm_update_loop
(
self
)
->
None
:
"""FPM observation and (optionally) load-based scaling for agg mode.
Always updates regression with live FPM. When load-based scaling
is enabled, makes scaling decisions immediately after.
"""
pending_desired
:
Optional
[
int
]
=
None
while
True
:
await
asyncio
.
sleep
(
self
.
config
.
load_adjustment_interval
)
logger
.
info
(
"New agg load/FPM update interval started!"
)
_
,
num_d
,
_
=
await
self
.
planner
.
get_workers_info
(
require_prefill
=
False
,
require_decode
=
True
)
self
.
shared_state
.
num_d_workers
=
num_d
num_workers
=
num_d
fpm_stats
=
self
.
planner
.
_get_fpm_stats
()
if
not
fpm_stats
:
continue
for
(
wid
,
dp
),
fpm
in
fpm_stats
.
items
():
BasePlanner
.
_log_fpm
(
wid
,
dp
,
fpm
,
"agg"
)
self
.
regression
.
add_observation
(
fpm
)
if
not
self
.
enable_load
:
continue
if
pending_desired
is
not
None
:
if
num_workers
==
pending_desired
:
logger
.
info
(
f
"Scaling to
{
pending_desired
}
complete, resuming decisions"
)
pending_desired
=
None
else
:
logger
.
info
(
f
"Scaling in progress (
{
num_workers
}
->
{
pending_desired
}
), "
"observing only"
)
continue
if
not
BasePlanner
.
_reconcile_fpm_worker_count
(
fpm_stats
,
num_workers
,
"agg"
):
continue
if
not
self
.
regression
.
has_sufficient_data
():
logger
.
info
(
f
"Agg regression: insufficient data "
f
"(
{
self
.
regression
.
num_observations
}
/
{
self
.
regression
.
min_observations
}
)"
)
continue
max_num_batched_tokens
=
getattr
(
self
.
planner
.
decode_worker_info
,
"max_num_batched_tokens"
,
None
)
if
not
max_num_batched_tokens
or
max_num_batched_tokens
<=
0
:
logger
.
warning
(
"max_num_batched_tokens not available from WorkerInfo, "
"skipping agg scaling"
)
continue
p_desired
=
self
.
_prefill_scaling_decision
(
fpm_stats
,
num_workers
,
max_num_batched_tokens
)
d_desired
=
self
.
_decode_scaling_decision
(
fpm_stats
,
num_workers
)
logger
.
info
(
f
"Agg scaling decisions: prefill=
{
p_desired
}
, decode=
{
d_desired
}
"
f
"(current=
{
num_workers
}
)"
)
if
p_desired
is
not
None
and
p_desired
>
num_workers
:
desired
=
p_desired
elif
d_desired
is
not
None
and
d_desired
>
num_workers
:
desired
=
d_desired
elif
(
p_desired
is
not
None
and
p_desired
<
num_workers
and
d_desired
is
not
None
and
d_desired
<
num_workers
):
desired
=
max
(
p_desired
,
d_desired
)
else
:
logger
.
info
(
"Agg scaling: no scaling needed"
)
continue
desired
=
max
(
desired
,
self
.
config
.
min_endpoint
)
if
self
.
enable_throughput
:
desired
=
max
(
desired
,
self
.
shared_state
.
throughput_lower_bound_d
)
assert
self
.
config
.
decode_engine_num_gpu
is
not
None
desired
=
_apply_component_gpu_budget
(
desired
,
self
.
config
.
decode_engine_num_gpu
,
self
.
config
)
logger
.
info
(
f
"Agg load-based scaling:
{
num_workers
}
->
{
desired
}
"
)
if
(
self
.
planner
.
prometheus_port
!=
0
and
self
.
planner
.
prometheus_metrics
is
not
None
):
self
.
planner
.
prometheus_metrics
.
predicted_num_d
.
set
(
desired
)
if
not
self
.
config
.
no_operation
:
pending_desired
=
desired
target_replicas
=
[
TargetReplica
(
sub_component_type
=
SubComponentType
.
DECODE
,
component_name
=
self
.
planner
.
decode_worker_info
.
k8s_name
,
desired_replicas
=
desired
,
)
]
await
self
.
planner
.
connector
.
set_component_replicas
(
target_replicas
,
blocking
=
False
)
def
_prefill_scaling_decision
(
self
,
fpm_stats
:
"dict[tuple[str, int], ForwardPassMetrics]"
,
num_workers
:
int
,
max_num_batched_tokens
:
int
,
)
->
Optional
[
int
]:
estimated_ttfts
:
list
[
float
]
=
[]
for
(
wid
,
dp
),
fpm
in
fpm_stats
.
items
():
est
=
self
.
regression
.
estimate_next_ttft
(
queued_prefill_tokens
=
fpm
.
queued_requests
.
sum_prefill_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
current_decode_kv
=
fpm
.
scheduled_requests
.
sum_decode_kv_tokens
,
)
if
est
is
not
None
:
estimated_ttfts
.
append
(
est
*
1000
)
return
self
.
planner
.
_load_based_scaling_decision_from_estimates
(
estimated_ttfts
,
self
.
config
.
ttft
,
num_workers
,
"agg TTFT"
)
def
_decode_scaling_decision
(
self
,
fpm_stats
:
"dict[tuple[str, int], ForwardPassMetrics]"
,
num_workers
:
int
,
)
->
Optional
[
int
]:
estimated_itls
:
list
[
float
]
=
[]
for
(
wid
,
dp
),
fpm
in
fpm_stats
.
items
():
est
=
self
.
regression
.
estimate_next_itl
(
scheduled_decode_kv
=
fpm
.
scheduled_requests
.
sum_decode_kv_tokens
,
queued_decode_kv
=
fpm
.
queued_requests
.
sum_decode_kv_tokens
,
)
if
est
is
not
None
:
estimated_itls
.
append
(
est
*
1000
)
return
self
.
planner
.
_load_based_scaling_decision_from_estimates
(
estimated_itls
,
self
.
config
.
itl
,
num_workers
,
"agg ITL"
)
components/src/dynamo/planner/core/base.py
View file @
f7e0b3fd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Runtime I/O plumbing for the native planner.
This module contains **zero decision logic**. It only gathers data from the
outside world (Prometheus, FPM subscribers, K8s connectors) and applies
scaling decisions back. All scaling logic lives in
:class:`~dynamo.planner.core.state_machine.PlannerStateMachine`.
Subclasses (PrefillPlanner, DecodePlanner, AggPlanner, DisaggPlanner) set
mode-specific flags and override ``_bootstrap_regression`` and
``_apply_effects``.
"""
from
__future__
import
annotations
import
asyncio
import
logging
import
time
...
...
@@ -9,19 +23,23 @@ from typing import TYPE_CHECKING, Optional, Union
from
prometheus_client
import
start_http_server
from
dynamo.planner.config.backend_components
import
WORKER_COMPONENT_NAMES
from
dynamo.planner.config.defaults
import
SubComponentType
,
TargetReplica
from
dynamo.planner.config.defaults
import
TargetReplica
from
dynamo.planner.config.planner_config
import
PlannerConfig
from
dynamo.planner.connectors.global_planner
import
GlobalPlannerConnector
from
dynamo.planner.connectors.kubernetes
import
KubernetesConnector
from
dynamo.planner.connectors.virtual
import
VirtualConnector
from
dynamo.planner.core.budget
import
(
_apply_component_gpu_budget
,
_initialize_gpu_counts
,
from
dynamo.planner.core.budget
import
_initialize_gpu_counts
from
dynamo.planner.core.state_machine
import
PlannerStateMachine
from
dynamo.planner.core.types
import
(
EngineCapabilities
,
FpmObservations
,
PlannerEffects
,
ScheduledTick
,
TickInput
,
TrafficObservation
,
WorkerCapabilities
,
WorkerCounts
,
)
from
dynamo.planner.core.load.predictors
import
LOAD_PREDICTORS
from
dynamo.planner.core.perf_model
import
DecodeRegressionModel
,
PrefillRegressionModel
from
dynamo.planner.core.state
import
PlannerSharedState
from
dynamo.planner.monitoring.perf_metrics
import
fetch_pre_deployment_metrics
from
dynamo.planner.monitoring.planner_metrics
import
PlannerPrometheusMetrics
from
dynamo.planner.monitoring.traffic_metrics
import
Metrics
,
PrometheusAPIClient
from
dynamo.planner.monitoring.worker_info
import
WorkerInfo
,
resolve_worker_info
...
...
@@ -31,7 +49,6 @@ if TYPE_CHECKING:
from
dynamo.common.forward_pass_metrics
import
ForwardPassMetrics
from
dynamo.llm
import
FpmEventSubscriber
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime.logging
import
configure_dynamo_logging
...
...
@@ -41,34 +58,63 @@ configure_dynamo_logging()
logger
=
logging
.
getLogger
(
__name__
)
class
BasePlanner
:
component_type
:
SubComponentType
# ------------------------------------------------------------------
# Helpers for building WorkerCapabilities from resolved WorkerInfo
# ------------------------------------------------------------------
def
__init__
(
self
,
runtime
:
Optional
[
DistributedRuntime
],
config
:
PlannerConfig
,
shared_state
:
Optional
[
PlannerSharedState
]
=
None
,
prometheus_metrics
:
Optional
[
PlannerPrometheusMetrics
]
=
None
,
prometheus_traffic_client
:
Optional
[
PrometheusAPIClient
]
=
None
,
connector
:
Optional
[
ConnectorType
]
=
None
,
start_prometheus_server
:
bool
=
True
,
component_type
:
Optional
[
SubComponentType
]
=
None
,
):
if
component_type
is
not
None
:
self
.
component_type
=
component_type
self
.
config
=
config
self
.
shared_state
=
shared_state
or
PlannerSharedState
()
def
_engine_caps
(
worker_info
:
Optional
[
WorkerInfo
],
num_gpu
:
Optional
[
int
]
)
->
Optional
[
EngineCapabilities
]:
if
worker_info
is
None
and
num_gpu
is
None
:
return
None
return
EngineCapabilities
(
num_gpu
=
num_gpu
,
max_num_batched_tokens
=
worker_info
.
max_num_batched_tokens
if
worker_info
else
None
,
max_num_seqs
=
worker_info
.
max_num_seqs
if
worker_info
else
None
,
context_length
=
worker_info
.
context_length
if
worker_info
else
None
,
)
def build_worker_capabilities(
    config: PlannerConfig,
    prefill_worker_info: Optional[WorkerInfo] = None,
    decode_worker_info: Optional[WorkerInfo] = None,
) -> WorkerCapabilities:
    """Combine per-engine capabilities into a ``WorkerCapabilities`` record.

    Args:
        config: Planner configuration; supplies ``prefill_engine_num_gpu``
            and ``decode_engine_num_gpu``.
        prefill_worker_info: Resolved prefill worker info, if available.
        decode_worker_info: Resolved decode worker info, if available.

    Returns:
        A ``WorkerCapabilities`` whose ``prefill`` and ``decode`` fields are
        the results of ``_engine_caps`` for the respective engine type.
    """
    prefill_caps = _engine_caps(prefill_worker_info, config.prefill_engine_num_gpu)
    decode_caps = _engine_caps(decode_worker_info, config.decode_engine_num_gpu)
    return WorkerCapabilities(prefill=prefill_caps, decode=decode_caps)
# ------------------------------------------------------------------
# Base adapter
# ------------------------------------------------------------------
class
NativePlannerBase
:
"""Base adapter: runtime I/O plumbing shared by all planner modes.
Subclasses set ``require_prefill`` / ``require_decode`` and override
``_bootstrap_regression()`` and ``_apply_effects()``.
"""
require_prefill
:
bool
=
False
require_decode
:
bool
=
False
def
__init__
(
self
,
runtime
:
Optional
[
DistributedRuntime
],
config
:
PlannerConfig
)
->
None
:
self
.
config
=
config
self
.
runtime
=
runtime
self
.
namespace
=
config
.
namespace
self
.
model_name
:
Optional
[
str
]
=
None
self
.
connector
:
ConnectorType
if
c
onnector
is
not
None
:
self
.
connector
=
c
onnector
el
if
not
config
.
no_operation
:
# C
onnector
self
.
connector
:
C
onnector
Type
if
not
config
.
no_operation
:
if
config
.
environment
==
"global-planner"
:
assert
config
.
global_planner_namespace
is
not
None
assert
runtime
is
not
None
...
...
@@ -84,217 +130,161 @@ class BasePlanner:
elif
config
.
environment
==
"virtual"
:
assert
runtime
is
not
None
self
.
connector
=
VirtualConnector
(
runtime
,
self
.
namespace
,
config
.
model_name
,
runtime
,
self
.
namespace
,
config
.
model_name
)
else
:
raise
ValueError
(
f
"Invalid environment:
{
config
.
environment
}
"
)
self
.
prometheus_traffic_client
=
(
prometheus_traffic_client
or
PrometheusAPIClient
(
config
.
metric_pulling_prometheus_endpoint
,
config
.
namespace
,
metrics_source
=
config
.
throughput_metrics_source
,
)
# Prometheus
self
.
prometheus_traffic_client
=
PrometheusAPIClient
(
config
.
metric_pulling_prometheus_endpoint
,
config
.
namespace
,
metrics_source
=
config
.
throughput_metrics_source
,
)
if
config
.
throughput_metrics_source
==
"router"
:
self
.
prometheus_traffic_client
.
warn_if_router_not_scraped
()
predictor_cls
=
LOAD_PREDICTORS
[
config
.
load_predictor
]
self
.
num_req_predictor
=
predictor_cls
(
config
)
self
.
isl_predictor
=
predictor_cls
(
config
)
self
.
osl_predictor
=
predictor_cls
(
config
)
# Optional warmup: preload predictors with historical observations from a
# mooncake-style JSONL trace (request_count/avg_isl/avg_osl per interval).
if
config
.
load_predictor_warmup_trace
is
not
None
:
warmup_trace
=
config
.
load_predictor_warmup_trace
self
.
prometheus_port
=
config
.
metric_reporting_prometheus_port
self
.
prometheus_metrics
=
PlannerPrometheusMetrics
()
if
self
.
prometheus_port
!=
0
:
try
:
metrics
=
extract_metrics_from_mooncake
(
warmup_trace
,
config
.
throughput_adjustment_interval
)
for
m
in
metrics
:
self
.
num_req_predictor
.
add_data_point
(
float
(
m
[
"request_count"
]))
self
.
isl_predictor
.
add_data_point
(
float
(
m
[
"avg_isl"
]))
self
.
osl_predictor
.
add_data_point
(
float
(
m
[
"avg_osl"
]))
start_http_server
(
self
.
prometheus_port
)
logger
.
info
(
f
"
W
ar
m
ed
load predictors with
{
len
(
metrics
)
}
intervals from
{
warmup_trace
}
"
f
"
St
ar
t
ed
Prometheus metrics server on port
{
self
.
prometheus_port
}
"
)
except
Exception
as
e
:
logger
.
warning
(
f
"Failed to warm load predictors from
{
warmup_trace
}
:
{
e
}
"
)
finally
:
# Even with warmup data, ignore the initial post-deploy idle
# period (leading zeros) when live metrics start coming in.
for
p
in
(
self
.
num_req_predictor
,
self
.
isl_predictor
,
self
.
osl_predictor
,
):
if
hasattr
(
p
,
"reset_idle_skip"
):
p
.
reset_idle_skip
()
self
.
enable_load
=
config
.
enable_load_scaling
self
.
enable_throughput
=
config
.
enable_throughput_scaling
logger
.
error
(
f
"Failed to start Prometheus metrics server:
{
e
}
"
)
# Worker info (resolved during _async_init)
self
.
prefill_worker_info
=
WorkerInfo
()
self
.
decode_worker_info
=
WorkerInfo
()
self
.
prefill_client
=
None
self
.
workers_client
=
None
# FPM subscribers (one per component type, populated during _async_init)
self
.
_prefill_fpm_sub
:
Optional
[
FpmEventSubscriber
]
=
None
self
.
_decode_fpm_sub
:
Optional
[
FpmEventSubscriber
]
=
None
self
.
prometheus_port
=
config
.
metric_reporting_prometheus_port
self
.
prometheus_metrics
:
PlannerPrometheusMetrics
|
None
=
None
# Runtime client caches
self
.
_prefill_client
=
None
self
.
_decode_client
=
None
if
prometheus_metrics
is
None
:
self
.
prometheus_metrics
=
PlannerPrometheusMetrics
()
else
:
self
.
prometheus_metrics
=
prometheus_metrics
# Shared metrics state
self
.
_last_metrics
=
Metrics
()
self
.
_cumulative_gpu_hours
:
float
=
0.0
if
start_prometheus_server
and
self
.
prometheus_port
!=
0
:
try
:
start_http_server
(
self
.
prometheus_port
)
logger
.
info
(
f
"Started Prometheus metrics server on port
{
self
.
prometheus_port
}
"
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to start Prometheus metrics server:
{
e
}
"
)
# State machine (created after WorkerInfo is resolved)
self
.
_state_machine
:
Optional
[
PlannerStateMachine
]
=
None
self
.
fpm_subscriber
:
"Optional[FpmEventSubscriber]"
=
None
# ------------------------------------------------------------------
# State machine access
# ------------------------------------------------------------------
if
self
.
component_type
==
SubComponentType
.
PREFILL
:
self
.
ttft_regression
=
PrefillRegressionModel
(
max_num_fpm_samples
=
self
.
config
.
max_num_fpm_samples
,
min_observations
=
self
.
config
.
load_min_observations
,
bucket_count
=
self
.
config
.
fpm_sample_bucket_size
,
)
elif
self
.
component_type
==
SubComponentType
.
DECODE
:
self
.
itl_regression
=
DecodeRegressionModel
(
max_num_fpm_samples
=
self
.
config
.
max_num_fpm_samples
,
min_observations
=
self
.
config
.
load_min_observations
,
bucket_count
=
self
.
config
.
fpm_sample_bucket_size
,
def _ensure_state_machine(self) -> PlannerStateMachine:
    """Return the planner state machine, creating it on first use.

    On first use the machine is built from the current config and the
    prefill/decode worker info, and the load predictors are then warmed
    via ``_warm_predictors``. Later calls return the cached instance.
    """
    if self._state_machine is not None:
        return self._state_machine

    capabilities = build_worker_capabilities(
        self.config,
        self.prefill_worker_info,
        self.decode_worker_info,
    )
    self._state_machine = PlannerStateMachine(self.config, capabilities)
    # Warm-up needs the machine to exist, so it runs after the assignment.
    self._warm_predictors()
    return self._state_machine
@
property
def
last_metrics
(
self
)
->
Metrics
:
return
self
.
sha
re
d
_state
.
last_metrics
def
state_machine
(
self
)
->
PlannerStateMachine
:
return
self
.
_ensu
re_state
_machine
()
@
last_metrics
.
setter
def
last_metrics
(
self
,
value
:
Metrics
)
->
None
:
self
.
shared_state
.
last_metrics
=
value
def
_warm_predictors
(
self
)
->
None
:
if
self
.
config
.
load_predictor_warmup_trace
is
None
:
return
assert
self
.
_state_machine
is
not
None
try
:
metrics
=
extract_metrics_from_mooncake
(
self
.
config
.
load_predictor_warmup_trace
,
self
.
config
.
throughput_adjustment_interval
,
)
self
.
_state_machine
.
warm_load_predictors
(
[
TrafficObservation
(
duration_s
=
self
.
config
.
throughput_adjustment_interval
,
num_req
=
float
(
m
[
"request_count"
]),
isl
=
float
(
m
[
"avg_isl"
]),
osl
=
float
(
m
[
"avg_osl"
]),
)
for
m
in
metrics
]
)
except
Exception
as
e
:
logger
.
warning
(
f
"Failed to warm load predictors:
{
e
}
"
)
async
def
_init_worker_info
(
self
,
require_prefill
:
bool
,
require_decode
:
bool
)
->
None
:
"""Initialize WorkerInfo and model name in a single step."""
connector
=
getattr
(
self
,
"connector"
,
None
)
self
.
prefill_worker_info
,
self
.
decode_worker_info
=
resolve_worker_info
(
backend
=
self
.
config
.
backend
,
require_prefill
=
require_prefill
,
require_decode
=
require_decode
,
connector
=
connector
,
config_model_name
=
getattr
(
self
.
config
,
"model_name"
,
""
),
no_operation
=
self
.
config
.
no_operation
,
)
# model_name is resolved and written into both WorkerInfo objects
self
.
model_name
=
(
self
.
decode_worker_info
.
model_name
or
self
.
prefill_worker_info
.
model_name
)
# ------------------------------------------------------------------
# Async init
# ------------------------------------------------------------------
async
def
_async_init
(
self
):
"""Async initialization: connector init, deployment validation, WorkerInfo."""
async
def
_async_init
(
self
)
->
None
:
if
hasattr
(
self
,
"connector"
)
and
hasattr
(
self
.
connector
,
"_async_init"
):
await
self
.
connector
.
_async_init
()
require_prefill
=
self
.
component_type
==
SubComponentType
.
PREFILL
require_decode
=
self
.
component_type
==
SubComponentType
.
DECODE
if
not
self
.
config
.
no_operation
:
defaults
=
WORKER_COMPONENT_NAMES
.
get
(
self
.
config
.
backend
)
logger
.
info
(
"Validating deployment..."
)
await
self
.
connector
.
validate_deployment
(
prefill_component_name
=
(
defaults
.
prefill_worker_k8s_name
if
require_prefill
and
defaults
if
self
.
require_prefill
and
defaults
else
None
),
decode_component_name
=
(
defaults
.
decode_worker_k8s_name
if
require_decode
and
defaults
if
self
.
require_decode
and
defaults
else
None
),
require_prefill
=
require_prefill
,
require_decode
=
require_decode
,
require_prefill
=
self
.
require_prefill
,
require_decode
=
self
.
require_decode
,
)
logger
.
info
(
"Successfully validated the deployment"
)
_initialize_gpu_counts
(
self
.
config
,
self
.
connector
,
require_prefill
=
require_prefill
,
require_decode
=
require_decode
,
require_prefill
=
self
.
require_prefill
,
require_decode
=
self
.
require_decode
,
)
await
self
.
connector
.
wait_for_deployment_ready
(
include_planner
=
False
)
await
self
.
_init_worker_info
(
require_prefill
=
require_prefill
,
require_decode
=
require_decode
,
)
await
self
.
_init_worker_info
()
if
self
.
runtime
is
not
None
:
await
self
.
_init_fpm_subscriber
()
if
self
.
require_prefill
:
await
self
.
_init_fpm_subscriber
(
"prefill"
)
if
self
.
require_decode
:
await
self
.
_init_fpm_subscriber
(
"decode"
)
await
self
.
_bootstrap_regression
()
async
def
_bootstrap_regression
(
self
)
->
None
:
"""Fetch pre-deployment FPM data and bootstrap the regression model."""
worker_info
=
(
self
.
prefill_worker_info
if
self
.
component_type
==
SubComponentType
.
PREFILL
else
self
.
decode_worker_info
async
def
_init_worker_info
(
self
)
->
None
:
connector
=
getattr
(
self
,
"connector"
,
None
)
self
.
prefill_worker_info
,
self
.
decode_worker_info
=
resolve_worker_info
(
backend
=
self
.
config
.
backend
,
require_prefill
=
self
.
require_prefill
,
require_decode
=
self
.
require_decode
,
connector
=
connector
,
config_model_name
=
getattr
(
self
.
config
,
"model_name"
,
""
),
no_operation
=
self
.
config
.
no_operation
,
)
self
.
model_name
=
(
self
.
decode_worker_info
.
model_name
or
self
.
prefill_worker_info
.
model_name
)
try
:
fpms
=
await
fetch_pre_deployment_metrics
(
runtime
=
self
.
runtime
,
namespace
=
self
.
namespace
,
worker_info
=
worker_info
,
profile_results_dir
=
self
.
config
.
profile_results_dir
,
component_type
=
self
.
component_type
,
)
if
self
.
component_type
==
SubComponentType
.
PREFILL
:
self
.
ttft_regression
.
load_benchmark_fpms
(
fpms
)
elif
self
.
component_type
==
SubComponentType
.
DECODE
:
self
.
itl_regression
.
load_benchmark_fpms
(
fpms
)
logger
.
info
(
f
"Bootstrapped
{
self
.
component_type
.
value
}
regression with "
f
"
{
len
(
fpms
)
}
pre-deployment FPMs"
)
except
Exception
as
e
:
if
self
.
enable_throughput
:
raise
logger
.
warning
(
f
"No pre-deployment data for
{
self
.
component_type
.
value
}
regression:
{
e
}
. "
"Load-based scaling will learn from live FPM only."
)
async
def
_init_fpm_subscriber
(
self
)
->
None
:
"""Create and start the FPM subscriber for load-based scaling."""
async
def
_init_fpm_subscriber
(
self
,
component
:
str
)
->
None
:
from
dynamo.llm
import
FpmEventSubscriber
worker_info
=
(
self
.
prefill_worker_info
if
self
.
component
_type
==
SubComponentType
.
PREFILL
if
component
==
"prefill"
else
self
.
decode_worker_info
)
if
not
worker_info
.
component_name
or
not
worker_info
.
endpoint
:
logger
.
warning
(
"WorkerInfo missing component_name or endpoint, "
"cannot create FPM subscriber"
f
"WorkerInfo missing for
{
component
}
, cannot create FPM subscriber"
)
return
...
...
@@ -302,50 +292,49 @@ class BasePlanner:
endpoint
=
self
.
runtime
.
endpoint
(
f
"
{
self
.
namespace
}
.
{
worker_info
.
component_name
}
.
{
worker_info
.
endpoint
}
"
)
s
elf
.
fpm_subscriber
=
FpmEventSubscriber
(
endpoint
)
s
elf
.
fpm_subscriber
.
start_tracking
()
s
ub
=
FpmEventSubscriber
(
endpoint
)
s
ub
.
start_tracking
()
logger
.
info
(
f
"FPM tracker started for
{
worker_info
.
component_name
}
.
{
worker_info
.
endpoint
}
"
)
def
_get_fpm_stats
(
self
)
->
"dict[tuple[str, int], ForwardPassMetrics]"
:
"""Get decoded FPM stats from the subscriber, keyed by (worker_id, dp_rank)."""
if
component
==
"prefill"
:
self
.
_prefill_fpm_sub
=
sub
else
:
self
.
_decode_fpm_sub
=
sub
async
def
_bootstrap_regression
(
self
)
->
None
:
"""Override in subclasses to bootstrap regression models."""
pass
# ------------------------------------------------------------------
# Data collection (runtime I/O)
# ------------------------------------------------------------------
def
_decode_fpm_bytes
(
self
,
subscriber
:
Optional
[
FpmEventSubscriber
]
)
->
dict
[
tuple
[
str
,
int
],
ForwardPassMetrics
]:
from
dynamo.common.forward_pass_metrics
import
decode
as
decode_fpm
if
self
.
fpm_
subscriber
is
None
:
if
subscriber
is
None
:
return
{}
raw_stats
=
self
.
fpm_subscriber
.
get_recent_stats
()
result
=
{}
for
key
,
raw_bytes
in
raw
_stats
.
items
():
for
key
,
raw_bytes
in
subscriber
.
get_recent
_stats
()
.
items
():
fpm
=
decode_fpm
(
raw_bytes
)
if
fpm
is
not
None
:
result
[
key
]
=
fpm
return
result
async
def
_get_or_create_client
(
self
,
component_name
:
str
,
endpoint_name
:
str
):
"""Create a client for the given component and endpoint, with a brief sleep for state sync."""
assert
self
.
runtime
is
not
None
,
"Runtime is not initialized"
assert
self
.
runtime
is
not
None
client
=
await
self
.
runtime
.
endpoint
(
f
"
{
self
.
namespace
}
.
{
component_name
}
.
{
endpoint_name
}
"
).
client
()
# TODO: remove this sleep after rust client() is blocking until watching state
await
asyncio
.
sleep
(
0.1
)
return
client
async
def
get_workers_info
(
self
,
require_prefill
:
bool
=
True
,
require_decode
:
bool
=
True
)
->
tuple
[
int
,
int
,
bool
]:
"""
Get worker counts for prefill and decode components.
Returns:
tuple[int, int, bool]: (num_p_workers, num_d_workers, is_stable)
- is_stable: False if rollout in progress (scaling should be skipped)
"""
num_p_workers
=
0
num_d_workers
=
0
# For Kubernetes, use DGD status instead of runtime client
async
def
_get_worker_counts_raw
(
self
)
->
tuple
[
int
,
int
,
bool
]:
"""Returns (num_prefill, num_decode, is_stable) from connector or runtime."""
if
hasattr
(
self
,
"connector"
)
and
isinstance
(
self
.
connector
,
KubernetesConnector
):
...
...
@@ -355,515 +344,230 @@ class BasePlanner:
is_stable
,
)
=
self
.
connector
.
get_actual_worker_counts
(
prefill_component_name
=
(
self
.
prefill_worker_info
.
k8s_name
if
require_prefill
else
None
self
.
prefill_worker_info
.
k8s_name
if
self
.
require_prefill
else
None
),
decode_component_name
=
(
self
.
decode_worker_info
.
k8s_name
if
require_decode
else
None
self
.
decode_worker_info
.
k8s_name
if
self
.
require_decode
else
None
),
)
num_p_workers
=
prefill_count
if
require_prefill
else
0
num_d_workers
=
decode_count
if
require_decode
else
0
return
num_p_workers
,
num_d_workers
,
is_stable
return
(
prefill_count
if
self
.
require_prefill
else
0
,
decode_count
if
self
.
require_decode
else
0
,
is_stable
,
)
# Fall back to runtime client for non-Kubernetes environments
if
self
.
runtime
is
None
:
raise
RuntimeError
(
"Runtime is not initialized"
)
if
require_prefill
:
num_p
,
num_d
=
0
,
0
if
self
.
require_prefill
:
try
:
if
self
.
prefill_client
is
None
:
if
self
.
_
prefill_client
is
None
:
assert
self
.
prefill_worker_info
.
component_name
is
not
None
assert
self
.
prefill_worker_info
.
endpoint
is
not
None
self
.
prefill_client
=
await
self
.
_get_or_create_client
(
self
.
_
prefill_client
=
await
self
.
_get_or_create_client
(
self
.
prefill_worker_info
.
component_name
,
self
.
prefill_worker_info
.
endpoint
,
)
num_p
_workers
=
len
(
self
.
prefill_client
.
instance_ids
())
# type: ignore
num_p
=
len
(
self
.
_
prefill_client
.
instance_ids
())
# type: ignore
except
Exception
:
num_p_workers
=
0
logger
.
warning
(
"No prefill workers found, aggregated mode is not supported yet"
)
logger
.
warning
(
"No prefill workers found"
)
if
require_decode
:
if
self
.
require_decode
:
try
:
if
self
.
workers
_client
is
None
:
if
self
.
_decode
_client
is
None
:
assert
self
.
decode_worker_info
.
component_name
is
not
None
assert
self
.
decode_worker_info
.
endpoint
is
not
None
self
.
workers
_client
=
await
self
.
_get_or_create_client
(
self
.
_decode
_client
=
await
self
.
_get_or_create_client
(
self
.
decode_worker_info
.
component_name
,
self
.
decode_worker_info
.
endpoint
,
)
num_d
_workers
=
len
(
self
.
workers
_client
.
instance_ids
())
# type: ignore
num_d
=
len
(
self
.
_decode
_client
.
instance_ids
())
# type: ignore
except
Exception
as
e
:
raise
RuntimeError
(
f
"Failed to get decode worker endpoints:
{
e
}
"
)
return
num_p_workers
,
num_d_workers
,
True
# Always stable for non-K8s
async
def
observe_traffic_stats
(
self
,
require_prefill
:
bool
=
True
,
require_decode
:
bool
=
True
)
->
None
:
"""
Observe metrics from Prometheus and update shared state.
"""
num_p_workers
,
num_d_workers
,
_
=
await
self
.
get_workers_info
(
require_prefill
=
require_prefill
,
require_decode
=
require_decode
)
self
.
shared_state
.
num_p_workers
=
num_p_workers
self
.
shared_state
.
num_d_workers
=
num_d_workers
logger
.
debug
(
f
"Number of prefill workers:
{
num_p_workers
}
, number of decode workers:
{
num_d_workers
}
"
)
return
num_p
,
num_d
,
True
# Update Prometheus metrics if server is running
if
self
.
prometheus_port
!=
0
and
self
.
prometheus_metrics
is
not
None
:
self
.
prometheus_metrics
.
num_p_workers
.
set
(
num_p_workers
)
self
.
prometheus_metrics
.
num_d_workers
.
set
(
num_d_workers
)
async
def
_collect_traffic
(
self
)
->
Optional
[
TrafficObservation
]:
"""Pull traffic metrics from Prometheus."""
num_p
,
num_d
,
_
=
await
self
.
_get_worker_counts_raw
()
# Calculate and accumulate GPU hours for this interval
# TODO: track startup and shutdown times to get more accurate GPU hours
interval_gpu_hours
=
(
if
self
.
prometheus_port
!=
0
:
self
.
prometheus_metrics
.
num_p_workers
.
set
(
num_p
)
self
.
prometheus_metrics
.
num_d_workers
.
set
(
num_d
)
gpu_hours
=
(
(
num_p
_workers
*
(
self
.
config
.
prefill_engine_num_gpu
or
0
)
+
num_d
_workers
*
(
self
.
config
.
decode_engine_num_gpu
or
0
)
num_p
*
(
self
.
config
.
prefill_engine_num_gpu
or
0
)
+
num_d
*
(
self
.
config
.
decode_engine_num_gpu
or
0
)
)
*
self
.
config
.
throughput_adjustment_interval
/
3600
)
self
.
shared_state
.
cumulative_gpu_hours
+=
interval_gpu_hours
self
.
prometheus_metrics
.
gpu_hours
.
set
(
self
.
shared_state
.
cumulative_gpu_hours
)
# Prometheus returns seconds, convert to milliseconds
assert
(
self
.
model_name
is
not
None
),
"model_name must be set before observing traffic stats"
self
.
_cumulative_gpu_hours
+=
gpu_hours
self
.
prometheus_metrics
.
gpu_hours
.
set
(
self
.
_cumulative_gpu_hours
)
assert
self
.
model_name
is
not
None
interval_str
=
f
"
{
self
.
config
.
throughput_adjustment_interval
}
s"
self
.
last_metrics
.
ttft
=
(
m
=
self
.
_last_metrics
m
.
ttft
=
(
self
.
prometheus_traffic_client
.
get_avg_time_to_first_token
(
interval_str
,
self
.
model_name
,
interval_str
,
self
.
model_name
)
*
1000
)
self
.
last_metrics
.
itl
=
(
m
.
itl
=
(
self
.
prometheus_traffic_client
.
get_avg_inter_token_latency
(
interval_str
,
self
.
model_name
,
interval_str
,
self
.
model_name
)
*
1000
)
self
.
last_metrics
.
num_req
=
(
self
.
prometheus_traffic_client
.
get_avg_request_count
(
interval_str
,
self
.
model_name
,
)
m
.
num_req
=
self
.
prometheus_traffic_client
.
get_avg_request_count
(
interval_str
,
self
.
model_name
)
self
.
last_metrics
.
request_duration
=
(
self
.
prometheus_traffic_client
.
get_avg_request_duration
(
interval_str
,
self
.
model_name
,
)
m
.
request_duration
=
self
.
prometheus_traffic_client
.
get_avg_request_duration
(
interval_str
,
self
.
model_name
)
self
.
last_metrics
.
isl
=
(
self
.
prometheus_traffic_client
.
get_avg_input_sequence_tokens
(
interval_str
,
self
.
model_name
,
)
m
.
isl
=
self
.
prometheus_traffic_client
.
get_avg_input_sequence_tokens
(
interval_str
,
self
.
model_name
)
self
.
last_metrics
.
osl
=
(
self
.
prometheus_traffic_client
.
get_avg_output_sequence_tokens
(
interval_str
,
self
.
model_name
,
)
m
.
osl
=
self
.
prometheus_traffic_client
.
get_avg_output_sequence_tokens
(
interval_str
,
self
.
model_name
)
logger
.
info
(
f
"Observed num_req:
{
self
.
last_metrics
.
num_req
:.
2
f
}
isl:
{
self
.
last_metrics
.
isl
:.
2
f
}
osl:
{
self
.
last_metrics
.
osl
:.
2
f
}
"
)
logger
.
info
(
f
"Observed ttft:
{
self
.
last_metrics
.
ttft
:.
2
f
}
ms itl:
{
self
.
last_metrics
.
itl
:.
2
f
}
ms"
f
"Observed num_req:
{
m
.
num_req
:.
2
f
}
isl:
{
m
.
isl
:.
2
f
}
osl:
{
m
.
osl
:.
2
f
}
"
)
# Update observed metrics in Prometheus
if
self
.
prometheus_port
!=
0
and
self
.
prometheus_metrics
is
not
None
:
self
.
prometheus_metrics
.
observed_ttft
.
set
(
self
.
last_metrics
.
ttft
)
self
.
prometheus_metrics
.
observed_itl
.
set
(
self
.
last_metrics
.
itl
)
if
self
.
prometheus_port
!=
0
:
self
.
prometheus_metrics
.
observed_ttft
.
set
(
m
.
ttft
)
self
.
prometheus_metrics
.
observed_itl
.
set
(
m
.
itl
)
self
.
prometheus_metrics
.
observed_request_rate
.
set
(
self
.
last_metrics
.
num_req
/
self
.
config
.
throughput_adjustment_interval
)
self
.
prometheus_metrics
.
observed_request_duration
.
set
(
self
.
last_metrics
.
request_duration
m
.
num_req
/
self
.
config
.
throughput_adjustment_interval
)
self
.
prometheus_metrics
.
observed_isl
.
set
(
self
.
last_metrics
.
isl
)
self
.
prometheus_metrics
.
observed_osl
.
set
(
self
.
last_metrics
.
osl
)
self
.
prometheus_metrics
.
observed_request_duration
.
set
(
m
.
request_duration
)
self
.
prometheus_metrics
.
observed_isl
.
set
(
m
.
isl
)
self
.
prometheus_metrics
.
observed_osl
.
set
(
m
.
osl
)
self
.
update_predictors_from_metrics
(
self
.
last_metrics
)
def
update_predictors_from_metrics
(
self
,
metrics
:
Metrics
)
->
None
:
if
metrics
.
num_req
is
not
None
:
self
.
num_req_predictor
.
add_data_point
(
metrics
.
num_req
)
if
metrics
.
isl
is
not
None
:
self
.
isl_predictor
.
add_data_point
(
metrics
.
isl
)
if
metrics
.
osl
is
not
None
:
self
.
osl_predictor
.
add_data_point
(
metrics
.
osl
)
def
predict_load
(
self
)
->
tuple
[
Optional
[
float
],
Optional
[
float
],
Optional
[
float
]]:
try
:
next_num_req
=
self
.
num_req_predictor
.
predict_next
()
next_isl
=
self
.
isl_predictor
.
predict_next
()
next_osl
=
self
.
osl_predictor
.
predict_next
()
logger
.
info
(
f
"Predicted load: num_req=
{
next_num_req
:.
2
f
}
, isl=
{
next_isl
:.
2
f
}
, osl=
{
next_osl
:.
2
f
}
"
)
return
next_num_req
,
next_isl
,
next_osl
except
Exception
as
e
:
logger
.
error
(
f
"Failed to predict load:
{
e
}
"
)
return
None
,
None
,
None
def
plan_adjustment
(
self
)
->
Optional
[
int
]:
if
not
self
.
last_metrics
.
is_valid
():
logger
.
info
(
"Metrics contain None or NaN values (no active requests), skipping adjustment"
)
return
None
next_num_req
,
next_isl
,
next_osl
=
self
.
predict_load
()
if
next_num_req
is
None
or
next_isl
is
None
or
next_osl
is
None
:
if
not
m
.
is_valid
():
logger
.
info
(
"Metrics contain None or NaN values, skipping"
)
return
None
# Update predicted load metrics in Prometheus
if
self
.
prometheus_port
!=
0
and
self
.
prometheus_metrics
is
not
None
:
self
.
prometheus_metrics
.
predicted_request_rate
.
set
(
next_num_req
/
self
.
config
.
throughput_adjustment_interval
)
self
.
prometheus_metrics
.
predicted_isl
.
set
(
next_isl
)
self
.
prometheus_metrics
.
predicted_osl
.
set
(
next_osl
)
try
:
return
self
.
_compute_replica_requirements
(
next_num_req
,
next_isl
,
next_osl
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to compute number of replicas:
{
e
}
"
)
return
None
def
update_predicted_replicas_metric
(
self
,
desired_replicas
:
int
)
->
None
:
raise
NotImplementedError
def
_compute_replica_requirements
(
self
,
next_num_req
:
float
,
next_isl
:
float
,
next_osl
:
float
)
->
Optional
[
int
]:
raise
NotImplementedError
def
_component_name
(
self
)
->
str
:
if
self
.
component_type
==
SubComponentType
.
PREFILL
:
assert
self
.
prefill_worker_info
.
k8s_name
is
not
None
return
self
.
prefill_worker_info
.
k8s_name
assert
self
.
decode_worker_info
.
k8s_name
is
not
None
return
self
.
decode_worker_info
.
k8s_name
def
_engine_num_gpu
(
self
)
->
int
:
if
self
.
component_type
==
SubComponentType
.
PREFILL
:
assert
self
.
config
.
prefill_engine_num_gpu
is
not
None
return
self
.
config
.
prefill_engine_num_gpu
assert
self
.
config
.
decode_engine_num_gpu
is
not
None
return
self
.
config
.
decode_engine_num_gpu
def
apply_component_budget
(
self
,
desired_replicas
:
int
)
->
int
:
return
_apply_component_gpu_budget
(
max
(
desired_replicas
,
self
.
config
.
min_endpoint
),
self
.
_engine_num_gpu
(),
self
.
config
,
)
async
def
_apply_scaling
(
self
,
desired_replicas
:
int
)
->
None
:
if
self
.
config
.
no_operation
:
return
target_replicas
=
[
TargetReplica
(
sub_component_type
=
self
.
component_type
,
component_name
=
self
.
_component_name
(),
desired_replicas
=
desired_replicas
,
)
]
await
self
.
connector
.
set_component_replicas
(
target_replicas
,
blocking
=
False
)
_apply_scaling_blocking
=
_apply_scaling
@
staticmethod
def
_reconcile_fpm_worker_count
(
fpm_stats
:
"dict[tuple[str, int], ForwardPassMetrics]"
,
dgd_count
:
int
,
label
:
str
,
)
->
bool
:
"""Validate that FPM coverage matches DGD worker count, accounting for DP.
With attention DP, each worker emits FPM per dp_rank. We check that
the number of unique worker IDs matches DGD, and that all workers
have the same number of dp_ranks (complete coverage).
Returns True if counts match, False otherwise.
"""
workers_to_dp
:
dict
[
str
,
set
[
int
]]
=
{}
for
wid
,
dp
in
fpm_stats
:
workers_to_dp
.
setdefault
(
wid
,
set
()).
add
(
dp
)
fpm_worker_count
=
len
(
workers_to_dp
)
if
fpm_worker_count
!=
dgd_count
:
logger
.
warning
(
f
"Worker count mismatch: DGD reports
{
dgd_count
}
, "
f
"FPM reports
{
fpm_worker_count
}
workers for
{
label
}
. "
"Skipping scaling."
)
return
False
dp_sizes
=
{
len
(
dps
)
for
dps
in
workers_to_dp
.
values
()}
if
len
(
dp_sizes
)
>
1
:
logger
.
warning
(
f
"Inconsistent DP ranks across workers for
{
label
}
: "
f
"
{
dict
(
workers_to_dp
)
}
. Skipping scaling."
)
return
False
dp_size
=
dp_sizes
.
pop
()
if
dp_sizes
else
1
expected_total
=
dgd_count
*
dp_size
actual_total
=
len
(
fpm_stats
)
if
actual_total
!=
expected_total
:
logger
.
warning
(
f
"Incomplete FPM coverage for
{
label
}
: expected "
f
"
{
dgd_count
}
workers ×
{
dp_size
}
dp_ranks =
{
expected_total
}
, "
f
"got
{
actual_total
}
. Skipping scaling."
)
return
False
if
dp_size
>
1
:
logger
.
info
(
f
"FPM
{
label
}
:
{
fpm_worker_count
}
workers ×
{
dp_size
}
dp_ranks "
f
"=
{
actual_total
}
engines"
)
return
True
@
staticmethod
def
_log_fpm
(
wid
:
str
,
dp
:
int
,
fpm
:
"ForwardPassMetrics"
,
label
:
str
)
->
None
:
sched
=
fpm
.
scheduled_requests
queued
=
fpm
.
queued_requests
logger
.
info
(
f
"FPM
{
label
}
engine
{
wid
}
:dp
{
dp
}
: "
f
"wall_time=
{
fpm
.
wall_time
:.
4
f
}
s, "
f
"sched(prefill_tok=
{
sched
.
sum_prefill_tokens
}
, "
f
"prefill_req=
{
sched
.
num_prefill_requests
}
, "
f
"decode_kv=
{
sched
.
sum_decode_kv_tokens
}
, "
f
"decode_req=
{
sched
.
num_decode_requests
}
), "
f
"queued(prefill_tok=
{
queued
.
sum_prefill_tokens
}
, "
f
"decode_kv=
{
queued
.
sum_decode_kv_tokens
}
)"
return
TrafficObservation
(
duration_s
=
self
.
config
.
throughput_adjustment_interval
,
num_req
=
m
.
num_req
,
isl
=
m
.
isl
,
osl
=
m
.
osl
,
)
def
observe_fpm_load_stats
(
self
,
)
->
"dict[tuple[str, int], ForwardPassMetrics]"
:
"""Get latest FPM stats and feed observations into the regression model.
Returns:
The decoded FPM stats dict for use by load_plan_adjustment().
"""
fpm_stats
=
self
.
_get_fpm_stats
()
if
not
fpm_stats
:
logger
.
warning
(
f
"No FPM data available for
{
self
.
component_type
.
value
}
(tracker empty)"
)
return
{}
for
(
wid
,
dp
),
fpm
in
fpm_stats
.
items
():
self
.
_log_fpm
(
wid
,
dp
,
fpm
,
self
.
component_type
.
value
)
if
self
.
component_type
==
SubComponentType
.
PREFILL
:
self
.
ttft_regression
.
add_observation
(
fpm
)
elif
self
.
component_type
==
SubComponentType
.
DECODE
:
self
.
itl_regression
.
add_observation
(
fpm
)
logger
.
info
(
f
"FPM load stats:
{
len
(
fpm_stats
)
}
engines observed for "
f
"
{
self
.
component_type
.
value
}
"
def
_collect_fpm
(
self
)
->
FpmObservations
:
"""Collect FPM from active subscribers."""
prefill_stats
=
None
decode_stats
=
None
if
self
.
_prefill_fpm_sub
is
not
None
:
stats
=
self
.
_decode_fpm_bytes
(
self
.
_prefill_fpm_sub
)
if
stats
:
for
(
wid
,
dp
),
fpm
in
stats
.
items
():
_log_fpm
(
wid
,
dp
,
fpm
,
"prefill"
)
prefill_stats
=
stats
if
self
.
_decode_fpm_sub
is
not
None
:
stats
=
self
.
_decode_fpm_bytes
(
self
.
_decode_fpm_sub
)
if
stats
:
for
(
wid
,
dp
),
fpm
in
stats
.
items
():
_log_fpm
(
wid
,
dp
,
fpm
,
"decode"
)
decode_stats
=
stats
return
FpmObservations
(
prefill
=
prefill_stats
,
decode
=
decode_stats
)
async
def
_collect_worker_counts
(
self
)
->
WorkerCounts
:
num_p
,
num_d
,
is_stable
=
await
self
.
_get_worker_counts_raw
()
return
WorkerCounts
(
ready_num_prefill
=
num_p
if
self
.
require_prefill
else
None
,
ready_num_decode
=
num_d
if
self
.
require_decode
else
None
,
expected_num_prefill
=
(
num_p
if
is_stable
else
None
)
if
self
.
require_prefill
else
None
,
expected_num_decode
=
(
num_d
if
is_stable
else
None
)
if
self
.
require_decode
else
None
,
)
return
fpm_stats
def
_load_based_scaling_decision_from_estimates
(
self
,
estimates
:
list
[
float
],
sla
:
float
,
num_workers
:
int
,
label
:
str
,
)
->
Optional
[
int
]:
"""Shared scale-up/down logic from per-engine latency estimates (ms).
Args:
estimates: per-engine estimated latencies in ms.
sla: target SLA in ms (e.g. config.ttft or config.itl).
num_workers: current worker count for this component.
label: human-readable label for log messages (e.g. "prefill TTFT").
Returns:
Desired replica count, or None if no scaling action needed.
"""
if
not
estimates
:
return
None
sensitivity
=
self
.
config
.
load_scaling_down_sensitivity
/
100.0
logger
.
info
(
f
"Load-based
{
label
}
: workers=
{
num_workers
}
, sla=
{
sla
:.
1
f
}
ms, "
f
"estimates=
{
[
f
'
{
t
:.
1
f
}
' for t in estimates]
}
"
# ------------------------------------------------------------------
# Gather tick input
# ------------------------------------------------------------------
async
def
_gather_tick_input
(
self
,
tick
:
ScheduledTick
)
->
TickInput
:
now
=
time
.
time
()
traffic
=
None
worker_counts
=
None
fpm_obs
=
None
if
tick
.
need_traffic_metrics
:
traffic
=
await
self
.
_collect_traffic
()
if
tick
.
need_worker_states
:
worker_counts
=
await
self
.
_collect_worker_counts
()
if
tick
.
need_worker_fpm
:
fpm_obs
=
self
.
_collect_fpm
()
return
TickInput
(
now_s
=
now
,
traffic
=
traffic
,
worker_counts
=
worker_counts
,
fpm_observations
=
fpm_obs
,
)
if
all
(
t
>
sla
for
t
in
estimates
):
logger
.
info
(
f
"Load-based
{
label
}
: ALL engines above SLA (
{
sla
:.
1
f
}
ms), "
f
"scaling up to
{
num_workers
+
1
}
"
)
return
num_workers
+
1
if
num_workers
>
1
:
threshold
=
sla
*
sensitivity
if
all
(
t
<
threshold
for
t
in
estimates
):
desired
=
max
(
num_workers
-
1
,
self
.
config
.
min_endpoint
)
if
desired
==
num_workers
:
logger
.
info
(
f
"Load-based
{
label
}
: ALL engines below threshold "
f
"(
{
threshold
:.
1
f
}
ms), but at min_endpoint (
{
self
.
config
.
min_endpoint
}
)"
)
else
:
logger
.
info
(
f
"Load-based
{
label
}
: ALL engines below threshold "
f
"(
{
threshold
:.
1
f
}
ms), scaling down to
{
desired
}
"
)
return
desired
return
None
# ------------------------------------------------------------------
# Apply effects (override in subclasses for mode-specific metrics)
# ------------------------------------------------------------------
def
load_plan_adjustment
(
self
)
->
Optional
[
int
]
:
"""
Load-based scaling decision. Override in subclasses
."""
raise
NotImplementedError
async
def
_apply_effects
(
self
,
effects
:
PlannerEffects
)
->
None
:
"""
Override in subclasses to report metrics and apply scaling
."""
pass
async
def
_
throughput_loop
(
self
,
require_prefill
:
bool
,
require_decode
:
bool
async
def
_
apply_scaling_targets
(
self
,
targets
:
list
[
TargetReplica
],
blocking
:
bool
=
False
)
->
None
:
"""Throughput-based scaling loop (existing behavior, extracted from run())."""
while
True
:
current_time
=
time
.
time
()
if
(
current_time
-
self
.
shared_state
.
last_adjustment_time
>=
self
.
config
.
throughput_adjustment_interval
):
self
.
shared_state
.
last_adjustment_time
=
time
.
time
()
logger
.
info
(
"New throughput adjustment interval started!"
)
await
self
.
observe_traffic_stats
(
require_prefill
=
require_prefill
,
require_decode
=
require_decode
)
desired_replicas
=
self
.
plan_adjustment
()
if
desired_replicas
is
not
None
:
if
self
.
enable_load
:
# When load-based is also enabled: just set lower bound
if
self
.
component_type
==
SubComponentType
.
PREFILL
:
self
.
shared_state
.
throughput_lower_bound_p
=
(
desired_replicas
)
else
:
self
.
shared_state
.
throughput_lower_bound_d
=
(
desired_replicas
)
logger
.
info
(
f
"Throughput lower bound set to
{
desired_replicas
}
for
{
self
.
component_type
.
value
}
"
)
else
:
# Throughput-only: apply scaling directly
desired_replicas
=
self
.
apply_component_budget
(
desired_replicas
)
self
.
update_predicted_replicas_metric
(
desired_replicas
)
# Throughput planner does not needs blocking scaling because it monitors
# and predicts the load, not relying on the current status of the engine.
await
self
.
_apply_scaling
(
desired_replicas
)
await
asyncio
.
sleep
(
self
.
config
.
throughput_adjustment_interval
/
10
)
async
def
_load_and_fpm_update_loop
(
self
,
require_prefill
:
bool
,
require_decode
:
bool
)
->
None
:
"""FPM observation and (optionally) load-based scaling loop.
Runs every load_adjustment_interval. Always updates the FPM
regression model with live observations. When load-based scaling
is enabled, also makes scaling decisions immediately after the
FPM update.
"""
pending_desired
:
Optional
[
int
]
=
None
while
True
:
await
asyncio
.
sleep
(
self
.
config
.
load_adjustment_interval
)
logger
.
info
(
"New load/FPM update interval started!"
)
num_p
,
num_d
,
is_stable
=
await
self
.
get_workers_info
(
require_prefill
=
require_prefill
,
require_decode
=
require_decode
)
self
.
shared_state
.
num_p_workers
=
num_p
self
.
shared_state
.
num_d_workers
=
num_d
fpm_stats
=
self
.
observe_fpm_load_stats
()
if
not
fpm_stats
:
continue
"""Shared helper: send scaling targets to connector."""
if
self
.
config
.
no_operation
or
not
targets
:
return
await
self
.
connector
.
set_component_replicas
(
targets
,
blocking
=
blocking
)
if
not
self
.
enable_load
:
continue
# ------------------------------------------------------------------
# Main loop
# ------------------------------------------------------------------
if
pending_desired
is
not
None
:
dgd_count
=
(
num_p
if
self
.
component_type
==
SubComponentType
.
PREFILL
else
num_d
)
if
dgd_count
==
pending_desired
:
logger
.
info
(
f
"Scaling to
{
pending_desired
}
complete, resuming decisions"
)
pending_desired
=
None
else
:
logger
.
info
(
f
"Scaling in progress (
{
dgd_count
}
->
{
pending_desired
}
), "
"observing only"
)
continue
async
def
run
(
self
)
->
None
:
next_tick
=
self
.
state_machine
.
initial_tick
(
time
.
time
())
poll_interval
=
self
.
config
.
load_adjustment_interval
/
10
dgd_count
=
(
num_p
if
self
.
component_type
==
SubComponentType
.
PREFILL
else
num_d
)
if
not
self
.
_reconcile_fpm_worker_count
(
fpm_stats
,
dgd_count
,
self
.
component_type
.
value
):
while
True
:
now
=
time
.
time
()
if
now
<
next_tick
.
at_s
:
await
asyncio
.
sleep
(
min
(
next_tick
.
at_s
-
now
,
poll_interval
))
continue
desired_replicas
=
self
.
load_plan_adjustment
()
if
desired_replicas
is
not
None
:
if
self
.
enable_throughput
:
if
self
.
component_type
==
SubComponentType
.
PREFILL
:
lower_bound
=
self
.
shared_state
.
throughput_lower_bound_p
else
:
lower_bound
=
self
.
shared_state
.
throughput_lower_bound_d
desired_replicas
=
max
(
desired_replicas
,
lower_bound
)
desired_replicas
=
self
.
apply_component_budget
(
desired_replicas
)
self
.
update_predicted_replicas_metric
(
desired_replicas
)
pending_desired
=
desired_replicas
await
self
.
_apply_scaling_blocking
(
desired_replicas
)
async
def
run
(
self
):
"""Main scaling loop. Call _async_init() before this."""
require_prefill
=
self
.
component_type
==
SubComponentType
.
PREFILL
require_decode
=
self
.
component_type
==
SubComponentType
.
DECODE
self
.
shared_state
.
last_adjustment_time
=
time
.
time
()
self
.
shared_state
.
last_load_adjustment_time
=
time
.
time
()
loops
=
[]
if
self
.
enable_throughput
:
loops
.
append
(
self
.
_throughput_loop
(
require_prefill
,
require_decode
))
loops
.
append
(
self
.
_load_and_fpm_update_loop
(
require_prefill
,
require_decode
))
await
asyncio
.
gather
(
*
loops
)
tick_input
=
await
self
.
_gather_tick_input
(
next_tick
)
effects
=
self
.
state_machine
.
on_tick
(
next_tick
,
tick_input
)
await
self
.
_apply_effects
(
effects
)
assert
effects
.
next_tick
is
not
None
next_tick
=
effects
.
next_tick
# ------------------------------------------------------------------
# Shared utility
# ------------------------------------------------------------------
def _log_fpm(wid: str, dp: int, fpm: ForwardPassMetrics, label: str) -> None:
    """Emit one INFO line summarising a forward-pass-metrics sample.

    Args:
        wid: Worker identifier shown in the log line.
        dp: Rank shown after ``dp`` in the log line.
        fpm: Sample whose ``wall_time``, ``scheduled_requests`` and
            ``queued_requests`` counters are reported.
        label: Free-form tag placed right after the ``FPM`` prefix.
    """
    scheduled, waiting = fpm.scheduled_requests, fpm.queued_requests

    header = f"FPM {label} engine {wid}:dp{dp}: wall_time={fpm.wall_time:.4f}s, "
    scheduled_part = (
        f"sched(prefill_tok={scheduled.sum_prefill_tokens}, "
        f"prefill_req={scheduled.num_prefill_requests}, "
        f"decode_kv={scheduled.sum_decode_kv_tokens}, "
        f"decode_req={scheduled.num_decode_requests}), "
    )
    queued_part = (
        f"queued(prefill_tok={waiting.sum_prefill_tokens}, "
        f"decode_kv={waiting.sum_decode_kv_tokens})"
    )
    logger.info(header + scheduled_part + queued_part)
components/src/dynamo/planner/core/decode.py
deleted
100644 → 0
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
import
math
from
typing
import
Optional
from
dynamo.planner.config.defaults
import
SubComponentType
from
dynamo.planner.core.base
import
BasePlanner
from
dynamo.runtime.logging
import
configure_dynamo_logging
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
class DecodePlanner(BasePlanner):
    """Planner specialisation that sizes the decode worker pool."""

    component_type = SubComponentType.DECODE

    def load_plan_adjustment(self) -> Optional[int]:
        """Derive a decode replica count from live per-engine FPM data.

        For every engine the ITL regression estimates the next inter-token
        latency from its scheduled and queued decode KV tokens.  The
        estimates, converted to milliseconds, are passed with the ITL SLA to
        ``_load_based_scaling_decision_from_estimates``, which (per the
        original contract) scales up when ALL engines exceed the SLA and
        scales down when ALL are below ``SLA * sensitivity``.

        Returns:
            The desired replica count, or ``None`` when the regression lacks
            data, no FPM stats are available, or no decode workers are known.
        """
        regression = self.itl_regression
        if not regression.has_sufficient_data():
            logger.info(
                f"ITL regression: insufficient data ({regression.num_observations}"
                f"/{regression.min_observations}), skipping load-based scaling"
            )
            return None

        per_engine = self._get_fpm_stats()
        if not per_engine:
            return None

        worker_count = self.shared_state.num_d_workers
        if worker_count == 0:
            return None

        itl_estimates_ms: list[float] = []
        for (wid, dp), fpm in per_engine.items():
            scheduled_kv = fpm.scheduled_requests.sum_decode_kv_tokens
            queued_kv = fpm.queued_requests.sum_decode_kv_tokens
            seconds = regression.estimate_next_itl(
                scheduled_decode_kv=scheduled_kv,
                queued_decode_kv=queued_kv,
            )
            if seconds is not None:
                est_ms = seconds * 1000
                itl_estimates_ms.append(est_ms)
                logger.info(
                    f"Decode engine {wid}:dp{dp}: estimated ITL {est_ms:.2f}ms "
                    f"(sched_kv={scheduled_kv}, queued_kv={queued_kv}, "
                    f"avg_decode_len={regression.avg_decode_length:.1f})"
                )

        return self._load_based_scaling_decision_from_estimates(
            estimates=itl_estimates_ms,
            sla=self.config.itl,
            num_workers=worker_count,
            label="decode ITL",
        )

    def _compute_replica_requirements(
        self, next_num_req: float, next_isl: float, next_osl: float
    ) -> Optional[int]:
        """Translate predicted traffic into a decode replica count.

        Demand (requests per second over one throughput interval) is divided
        by the best per-engine decode rate the ITL regression reports for the
        predicted context length (``next_isl + next_osl / 2``).

        Returns:
            ``ceil(demand / engine_rate)`` floored at ``config.min_endpoint``,
            or ``None`` when the regression reports a non-positive rate.
        """
        cfg = self.config
        demand_rps = next_num_req / cfg.throughput_adjustment_interval
        engine_rps, actual_itl_ms = self.itl_regression.find_best_engine_decode_rps(
            itl=cfg.itl,
            context_length=next_isl + next_osl / 2,
            osl=next_osl,
        )

        if engine_rps <= 0:
            logger.warning("Decode perf model not ready, skipping throughput scaling")
            return None

        # The SLA miss is only reported; scaling proceeds with the rate found.
        if actual_itl_ms > cfg.itl:
            logger.warning(
                f"Decode ITL SLA not met: {actual_itl_ms:.1f}ms > "
                f"{cfg.itl:.1f}ms, scaling with best achievable rate"
            )

        next_num_d = max(math.ceil(demand_rps / engine_rps), cfg.min_endpoint)
        logger.info(
            f"Decode: {demand_rps:.2f}(demand rps) / "
            f"{engine_rps:.2f}(engine rps) = {next_num_d}(num_d), "
            f"est_itl={actual_itl_ms:.1f}ms"
        )
        return next_num_d

    def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
        """Publish ``desired_replicas`` on the ``predicted_num_d`` gauge.

        Does nothing when the Prometheus port is 0 or no metrics object is set.
        """
        if self.prometheus_port == 0 or self.prometheus_metrics is None:
            return
        self.prometheus_metrics.predicted_num_d.set(desired_replicas)
components/src/dynamo/planner/core/disagg.py
deleted
100644 → 0
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
logging
import
time
from
dynamo.planner.config.backend_components
import
WORKER_COMPONENT_NAMES
from
dynamo.planner.config.defaults
import
SubComponentType
,
TargetReplica
from
dynamo.planner.config.planner_config
import
PlannerConfig
from
dynamo.planner.core.base
import
BasePlanner
from
dynamo.planner.core.budget
import
_apply_global_gpu_budget
,
_initialize_gpu_counts
from
dynamo.planner.core.decode
import
DecodePlanner
from
dynamo.planner.core.prefill
import
PrefillPlanner
from
dynamo.planner.core.state
import
PlannerSharedState
from
dynamo.planner.monitoring.planner_metrics
import
PlannerPrometheusMetrics
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime.logging
import
configure_dynamo_logging
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
class DisaggPlanner:
    """Coordinates a prefill planner and a decode planner for disagg mode.

    Both sub-planners are built around one ``PlannerSharedState`` and one
    ``PlannerPrometheusMetrics`` instance; the decode planner also reuses the
    prefill planner's traffic client and connector when those attributes exist.
    """

    def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
        """Build the shared state and both sub-planners from ``config``."""
        self.config = config
        self.shared_state = PlannerSharedState()
        prometheus_metrics = PlannerPrometheusMetrics()

        self.enable_throughput = config.enable_throughput_scaling
        self.enable_load = config.enable_load_scaling

        # Only the prefill planner starts the Prometheus server.
        self.prefill_planner = PrefillPlanner(
            runtime,
            config,
            shared_state=self.shared_state,
            prometheus_metrics=prometheus_metrics,
            start_prometheus_server=True,
        )
        shared_traffic_client = getattr(
            self.prefill_planner, "prometheus_traffic_client", None
        )
        shared_connector = getattr(self.prefill_planner, "connector", None)
        self.decode_planner = DecodePlanner(
            runtime,
            config,
            shared_state=self.shared_state,
            prometheus_metrics=prometheus_metrics,
            prometheus_traffic_client=shared_traffic_client,
            connector=shared_connector,
            start_prometheus_server=False,
        )

    async def _async_init(self):
        """Initialise both sub-planners, sharing worker info between them.

        Unless ``config.no_operation`` is set, the shared connector is
        initialised, the deployment is validated, GPU counts are initialised
        and the deployment is awaited until ready.  Worker info is then
        discovered once through the prefill planner and copied to the decode
        planner before FPM subscribers and regressions are set up.
        """
        names = WORKER_COMPONENT_NAMES.get(self.config.backend)

        if not self.config.no_operation:
            # Connector init (prefill/decode share the same connector)
            shared = getattr(self.prefill_planner, "connector", None)
            if shared and hasattr(shared, "_async_init"):
                await shared._async_init()

            logger.info("Validating deployment...")
            prefill_name = names.prefill_worker_k8s_name if names else None
            decode_name = names.decode_worker_k8s_name if names else None
            await self.prefill_planner.connector.validate_deployment(
                prefill_component_name=prefill_name,
                decode_component_name=decode_name,
                require_prefill=True,
                require_decode=True,
            )
            logger.info("Successfully validated the deployment")

            _initialize_gpu_counts(
                self.config,
                self.prefill_planner.connector,
                require_prefill=True,
                require_decode=True,
            )
            await self.prefill_planner.connector.wait_for_deployment_ready(
                include_planner=False
            )

        await self.prefill_planner._init_worker_info(
            require_prefill=True, require_decode=True
        )

        # Share WorkerInfo and model name with decode planner
        source, target = self.prefill_planner, self.decode_planner
        target.prefill_worker_info = source.prefill_worker_info
        target.decode_worker_info = source.decode_worker_info
        target.model_name = source.model_name

        if self.prefill_planner.runtime is not None:
            await self.prefill_planner._init_fpm_subscriber()
        if self.decode_planner.runtime is not None:
            await self.decode_planner._init_fpm_subscriber()

        await self.prefill_planner._bootstrap_regression()
        await self.decode_planner._bootstrap_regression()

    async def run(self):
        """Main scaling loop. Call _async_init() before this."""
        self.shared_state.last_adjustment_time = time.time()
        self.shared_state.last_load_adjustment_time = time.time()

        coroutines = [self._throughput_loop()] if self.enable_throughput else []
        coroutines.append(self._load_and_fpm_update_loop())
        await asyncio.gather(*coroutines)

    def _build_target_replicas(self, num_prefill: int, num_decode: int) -> list:
        """Return the prefill/decode ``TargetReplica`` pair for the connector."""
        planner = self.prefill_planner
        return [
            TargetReplica(
                sub_component_type=SubComponentType.PREFILL,
                component_name=planner.prefill_worker_info.k8s_name,
                desired_replicas=num_prefill,
            ),
            TargetReplica(
                sub_component_type=SubComponentType.DECODE,
                component_name=planner.decode_worker_info.k8s_name,
                desired_replicas=num_decode,
            ),
        ]

    async def _throughput_step(self) -> None:
        """Run one throughput adjustment: observe, predict, then bound or scale."""
        self.shared_state.last_adjustment_time = time.time()
        logger.info("New throughput adjustment interval started!")

        await self.prefill_planner.observe_traffic_stats(
            require_prefill=True, require_decode=True
        )
        self.decode_planner.update_predictors_from_metrics(
            self.shared_state.last_metrics
        )

        next_num_p = self.prefill_planner.plan_adjustment()
        next_num_d = self.decode_planner.plan_adjustment()
        if next_num_p is None or next_num_d is None:
            # Nothing to apply; the caller's regular poll sleep follows.
            return

        if self.enable_load:
            # When load-based is also enabled: just set lower bounds
            self.shared_state.throughput_lower_bound_p = next_num_p
            self.shared_state.throughput_lower_bound_d = next_num_d
            logger.info(
                f"Throughput lower bounds set: prefill={next_num_p}, decode={next_num_d}"
            )
            return

        # Throughput-only: apply scaling directly
        next_num_p, next_num_d = _apply_global_gpu_budget(
            next_num_p, next_num_d, self.config
        )
        self.prefill_planner.update_predicted_replicas_metric(next_num_p)
        self.decode_planner.update_predicted_replicas_metric(next_num_d)

        if self.config.no_operation:
            return
        await self.prefill_planner.connector.set_component_replicas(
            self._build_target_replicas(next_num_p, next_num_d), blocking=False
        )

    async def _throughput_loop(self) -> None:
        """Throughput-based scaling loop for disagg mode.

        Polls every tenth of ``throughput_adjustment_interval`` and runs an
        adjustment step once a full interval has elapsed since the last one.
        """
        while True:
            elapsed = time.time() - self.shared_state.last_adjustment_time
            if elapsed >= self.config.throughput_adjustment_interval:
                await self._throughput_step()
            await asyncio.sleep(self.config.throughput_adjustment_interval / 10)

    async def _load_and_fpm_update_loop(self) -> None:
        """FPM observation and (optionally) load-based scaling for disagg mode.

        Every ``load_adjustment_interval`` seconds the worker counts are
        refreshed and both sub-planners observe their FPM load stats.  When
        load-based scaling is enabled a scaling decision follows immediately.
        """
        while True:
            await asyncio.sleep(self.config.load_adjustment_interval)
            logger.info("New load/FPM update interval started!")

            num_p, num_d, _ = await self.prefill_planner.get_workers_info(
                require_prefill=True, require_decode=True
            )
            self.shared_state.num_p_workers = num_p
            self.shared_state.num_d_workers = num_d

            p_stats = self.prefill_planner.observe_fpm_load_stats()
            d_stats = self.decode_planner.observe_fpm_load_stats()

            if not self.enable_load:
                continue

            if not (p_stats or d_stats):
                logger.warning("No FPM data for either prefill or decode, skipping")
                continue

            # Skip this round when the FPM view disagrees with the worker count.
            if p_stats and not BasePlanner._reconcile_fpm_worker_count(
                p_stats, num_p, "prefill"
            ):
                continue
            if d_stats and not BasePlanner._reconcile_fpm_worker_count(
                d_stats, num_d, "decode"
            ):
                continue

            p_desired = self.prefill_planner.load_plan_adjustment()
            d_desired = self.decode_planner.load_plan_adjustment()

            # A side without a decision keeps its current worker count.
            final_p = (
                self.shared_state.num_p_workers if p_desired is None else p_desired
            )
            final_d = (
                self.shared_state.num_d_workers if d_desired is None else d_desired
            )

            unchanged = (
                final_p == self.shared_state.num_p_workers
                and final_d == self.shared_state.num_d_workers
            )
            if unchanged:
                logger.info("Load-based scaling: no scaling needed")
                continue

            if self.enable_throughput:
                final_p = max(final_p, self.shared_state.throughput_lower_bound_p)
                final_d = max(final_d, self.shared_state.throughput_lower_bound_d)

            final_p = max(final_p, self.config.min_endpoint)
            final_d = max(final_d, self.config.min_endpoint)
            final_p, final_d = _apply_global_gpu_budget(final_p, final_d, self.config)

            logger.info(
                f"Load-based disagg scaling: prefill {self.shared_state.num_p_workers} -> {final_p}, "
                f"decode {self.shared_state.num_d_workers} -> {final_d}"
            )

            self.prefill_planner.update_predicted_replicas_metric(final_p)
            self.decode_planner.update_predicted_replicas_metric(final_d)

            if self.config.no_operation:
                continue
            await self.prefill_planner.connector.set_component_replicas(
                self._build_target_replicas(final_p, final_d), blocking=True
            )
components/src/dynamo/planner/core/load_scaling.py
0 → 100644
View file @
f7e0b3fd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# mypy: disable-error-code="attr-defined"
"""Load-based scaling logic (FPM-driven, reactive).
Mixin consumed by ``PlannerStateMachine``. All methods access state
via ``self._config``, ``self._capabilities``, and regression models.
"""
from
__future__
import
annotations
import
logging
from
typing
import
TYPE_CHECKING
,
Optional
from
dynamo.planner.core.types
import
FpmObservations
,
ScalingDecision
if
TYPE_CHECKING
:
from
dynamo.common.forward_pass_metrics
import
ForwardPassMetrics
logger = logging.getLogger(__name__)


class LoadScalingMixin:
    """FPM-driven load-based scaling decisions.

    The host class is expected to provide ``_config``, ``_capabilities``, the
    regression models, worker counts, throughput lower bounds and the helper
    methods referenced below (``_scaling_in_progress``,
    ``_reconcile_fpm_worker_count``, ``_apply_single_budget``,
    ``_apply_global_budget``).
    """

    def _advance_load(self, obs: FpmObservations) -> Optional[ScalingDecision]:
        """Route a load-scaling step to the handler for the configured mode.

        Returns ``None`` immediately when load scaling is disabled.  Modes
        other than ``"agg"`` and ``"disagg"`` are treated as a single
        component whose name equals the mode string.
        """
        if not self._config.enable_load_scaling:
            return None

        mode = self._config.mode
        handlers = {
            "agg": self._advance_load_agg,
            "disagg": self._advance_load_disagg,
        }
        handler = handlers.get(mode)
        if handler is None:
            return self._advance_load_single(obs, mode)
        return handler(obs)

    def _advance_load_single(
        self, obs: FpmObservations, component: str
    ) -> Optional[ScalingDecision]:
        """Load-scaling step for a prefill-only or decode-only deployment.

        Any ``component`` other than ``"prefill"`` is handled as decode.
        Returns ``None`` while scaling is in progress, when no FPM data is
        present, when reconciliation fails, or when no decision is reached.
        """
        if self._scaling_in_progress(component):
            logger.info(f"Scaling in progress for {component}, observing only")
            return None

        is_prefill = component == "prefill"
        if is_prefill:
            fpm_stats, num_workers = obs.prefill, self._num_p_workers
        else:
            fpm_stats, num_workers = obs.decode, self._num_d_workers

        if not fpm_stats:
            return None
        if not self._reconcile_fpm_worker_count(fpm_stats, num_workers, component):
            return None

        decide = (
            self._prefill_load_decision if is_prefill else self._decode_load_decision
        )
        desired = decide(fpm_stats, num_workers)
        if desired is None:
            return None

        if self._config.enable_throughput_scaling:
            # Never go below what throughput-based planning asked for.
            floor = (
                self._throughput_lower_bound_p
                if is_prefill
                else self._throughput_lower_bound_d
            )
            desired = max(desired, floor)

        desired = self._apply_single_budget(desired, component)
        if is_prefill:
            return ScalingDecision(num_prefill=desired)
        return ScalingDecision(num_decode=desired)

    def _advance_load_disagg(self, obs: FpmObservations) -> Optional[ScalingDecision]:
        """Load-scaling step for disaggregated prefill + decode.

        A side with no FPM data or no decision keeps its current worker
        count.  Returns ``None`` when neither side has data, reconciliation
        fails for a side that has data, or both targets equal the current
        counts.
        """
        prefill_stats, decode_stats = obs.prefill, obs.decode
        if not (prefill_stats or decode_stats):
            logger.warning("No FPM data for either prefill or decode, skipping")
            return None

        sides = (
            (prefill_stats, self._num_p_workers, "prefill"),
            (decode_stats, self._num_d_workers, "decode"),
        )
        for stats, count, side in sides:
            if stats and not self._reconcile_fpm_worker_count(stats, count, side):
                return None

        want_p = None
        if prefill_stats:
            want_p = self._prefill_load_decision(prefill_stats, self._num_p_workers)
        want_d = None
        if decode_stats:
            want_d = self._decode_load_decision(decode_stats, self._num_d_workers)

        final_p = self._num_p_workers if want_p is None else want_p
        final_d = self._num_d_workers if want_d is None else want_d

        if (final_p, final_d) == (self._num_p_workers, self._num_d_workers):
            logger.info("Load-based scaling: no scaling needed")
            return None

        if self._config.enable_throughput_scaling:
            final_p = max(final_p, self._throughput_lower_bound_p)
            final_d = max(final_d, self._throughput_lower_bound_d)

        floor = self._config.min_endpoint
        final_p = max(final_p, floor)
        final_d = max(final_d, self._config.min_endpoint)
        final_p, final_d = self._apply_global_budget(final_p, final_d)

        logger.info(
            f"Load-based disagg scaling: prefill {self._num_p_workers} -> {final_p}, "
            f"decode {self._num_d_workers} -> {final_d}"
        )
        return ScalingDecision(num_prefill=final_p, num_decode=final_d)

    def _advance_load_agg(self, obs: FpmObservations) -> Optional[ScalingDecision]:
        """Load-scaling step for aggregated mode (FPM data read from ``obs.decode``).

        Scale-up is taken from whichever of the TTFT / ITL decisions asks for
        more workers (TTFT checked first); scale-down only happens when both
        decisions ask for fewer workers, and then the larger target is used.
        """
        fpm_stats = obs.decode
        if not fpm_stats:
            return None

        num_workers = self._num_d_workers
        if self._scaling_in_progress("decode"):
            logger.info(
                f"Scaling in progress ({num_workers} -> {self._expected_num_d}), observing only"
            )
            return None

        if not self._reconcile_fpm_worker_count(fpm_stats, num_workers, "agg"):
            return None

        regression = self._agg_regression
        if not regression.has_sufficient_data():
            logger.info(
                f"Agg regression: insufficient data "
                f"({regression.num_observations}/{regression.min_observations})"
            )
            return None

        caps = self._capabilities.decode
        max_tokens = caps.max_num_batched_tokens if caps else None
        if not max_tokens or max_tokens <= 0:
            logger.warning("max_num_batched_tokens not available, skipping agg scaling")
            return None

        p_desired = self._agg_prefill_scaling(fpm_stats, num_workers, max_tokens)
        d_desired = self._agg_decode_scaling(fpm_stats, num_workers)
        logger.info(
            f"Agg scaling decisions: prefill={p_desired}, decode={d_desired} (current={num_workers})"
        )

        p_up = p_desired is not None and p_desired > num_workers
        d_up = d_desired is not None and d_desired > num_workers
        both_down = (
            p_desired is not None
            and p_desired < num_workers
            and d_desired is not None
            and d_desired < num_workers
        )
        if p_up:
            desired = p_desired
        elif d_up:
            desired = d_desired
        elif both_down:
            desired = max(p_desired, d_desired)
        else:
            logger.info("Agg scaling: no scaling needed")
            return None

        desired = max(desired, self._config.min_endpoint)
        if self._config.enable_throughput_scaling:
            desired = max(desired, self._throughput_lower_bound_d)
        desired = self._apply_single_budget(desired, "decode")

        logger.info(f"Agg load-based scaling: {num_workers} -> {desired}")
        return ScalingDecision(num_decode=desired)

    # ------------------------------------------------------------------
    # Per-engine latency estimation
    # ------------------------------------------------------------------

    def _prefill_load_decision(
        self, fpm_stats: dict[tuple[str, int], ForwardPassMetrics], num_workers: int
    ) -> Optional[int]:
        """Estimate per-engine TTFT and turn the estimates into a replica count.

        Returns ``None`` when the prefill regression lacks data, there are no
        workers, or ``max_num_batched_tokens`` is unavailable or non-positive.
        """
        regression = self._prefill_regression
        if not regression.has_sufficient_data():
            logger.info(
                f"TTFT regression: insufficient data "
                f"({regression.num_observations}/{regression.min_observations})"
            )
            return None
        if num_workers == 0:
            return None

        caps = self._capabilities.prefill
        max_tokens = caps.max_num_batched_tokens if caps else None
        if not max_tokens or max_tokens <= 0:
            logger.warning(
                "max_num_batched_tokens not available, skipping prefill load scaling"
            )
            return None

        ttft_ms: list[float] = []
        for (wid, dp), fpm in fpm_stats.items():
            seconds = regression.estimate_next_ttft(
                queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
                max_num_batched_tokens=max_tokens,
            )
            if seconds is None:
                continue
            est_ms = seconds * 1000
            ttft_ms.append(est_ms)
            logger.info(
                f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
                f"(queued={fpm.queued_requests.sum_prefill_tokens}, "
                f"avg_isl={regression.avg_isl:.1f})"
            )

        return self._scale_decision(
            ttft_ms, self._config.ttft, num_workers, "prefill TTFT"
        )

    def _decode_load_decision(
        self, fpm_stats: dict[tuple[str, int], ForwardPassMetrics], num_workers: int
    ) -> Optional[int]:
        """Estimate per-engine ITL and turn the estimates into a replica count.

        Returns ``None`` when the decode regression lacks data or there are
        no workers.
        """
        regression = self._decode_regression
        if not regression.has_sufficient_data():
            logger.info(
                f"ITL regression: insufficient data "
                f"({regression.num_observations}/{regression.min_observations})"
            )
            return None
        if num_workers == 0:
            return None

        itl_ms: list[float] = []
        for (wid, dp), fpm in fpm_stats.items():
            seconds = regression.estimate_next_itl(
                scheduled_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
                queued_decode_kv=fpm.queued_requests.sum_decode_kv_tokens,
            )
            if seconds is None:
                continue
            est_ms = seconds * 1000
            itl_ms.append(est_ms)
            logger.info(
                f"Decode engine {wid}:dp{dp}: estimated ITL {est_ms:.2f}ms "
                f"(sched_kv={fpm.scheduled_requests.sum_decode_kv_tokens}, "
                f"queued_kv={fpm.queued_requests.sum_decode_kv_tokens})"
            )

        return self._scale_decision(itl_ms, self._config.itl, num_workers, "decode ITL")

    def _agg_prefill_scaling(
        self,
        fpm_stats: dict[tuple[str, int], ForwardPassMetrics],
        num_workers: int,
        max_tokens: int,
    ) -> Optional[int]:
        """TTFT-side decision for aggregated engines (estimates in ms)."""
        raw = (
            self._agg_regression.estimate_next_ttft(
                queued_prefill_tokens=fpm.queued_requests.sum_prefill_tokens,
                max_num_batched_tokens=max_tokens,
                current_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
            )
            for fpm in fpm_stats.values()
        )
        estimates = [seconds * 1000 for seconds in raw if seconds is not None]
        return self._scale_decision(
            estimates, self._config.ttft, num_workers, "agg TTFT"
        )

    def _agg_decode_scaling(
        self,
        fpm_stats: dict[tuple[str, int], ForwardPassMetrics],
        num_workers: int,
    ) -> Optional[int]:
        """ITL-side decision for aggregated engines (estimates in ms)."""
        raw = (
            self._agg_regression.estimate_next_itl(
                scheduled_decode_kv=fpm.scheduled_requests.sum_decode_kv_tokens,
                queued_decode_kv=fpm.queued_requests.sum_decode_kv_tokens,
            )
            for fpm in fpm_stats.values()
        )
        estimates = [seconds * 1000 for seconds in raw if seconds is not None]
        return self._scale_decision(estimates, self._config.itl, num_workers, "agg ITL")

    def _scale_decision(
        self, estimates: list[float], sla: float, num_workers: int, label: str
    ) -> Optional[int]:
        """Compare per-engine latency estimates against an SLA.

        Args:
            estimates: Per-engine latency estimates, same unit as ``sla``.
            sla: Latency target.
            num_workers: Current worker count.
            label: Tag used in log lines.

        Returns:
            ``num_workers + 1`` when every estimate exceeds ``sla``;
            ``max(num_workers - 1, min_endpoint)`` when there is more than one
            worker and every estimate is below ``sla`` scaled by
            ``load_scaling_down_sensitivity`` percent; otherwise ``None``
            (also for an empty ``estimates`` list).
        """
        if not estimates:
            return None

        down_ratio = self._config.load_scaling_down_sensitivity / 100.0
        rendered = [f"{t:.1f}" for t in estimates]
        logger.info(
            f"Load-based {label}: workers={num_workers}, sla={sla:.1f}ms, "
            f"estimates={rendered}"
        )

        if all(t > sla for t in estimates):
            logger.info(
                f"Load-based {label}: ALL above SLA, scaling up to {num_workers + 1}"
            )
            return num_workers + 1

        # A single worker is never scaled down.
        if num_workers <= 1:
            return None

        threshold = sla * down_ratio
        if not all(t < threshold for t in estimates):
            return None

        desired = max(num_workers - 1, self._config.min_endpoint)
        logger.info(
            f"Load-based {label}: ALL below threshold ({threshold:.1f}ms), -> {desired}"
        )
        return desired
components/src/dynamo/planner/core/prefill.py
deleted
100644 → 0
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
import
math
from
typing
import
Optional
from
dynamo.planner.config.defaults
import
SubComponentType
from
dynamo.planner.core.base
import
BasePlanner
from
dynamo.runtime.logging
import
configure_dynamo_logging
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
class PrefillPlanner(BasePlanner):
    """Planner specialisation that sizes the prefill worker pool."""

    component_type = SubComponentType.PREFILL

    def load_plan_adjustment(self) -> Optional[int]:
        """Derive a prefill replica count from live per-engine FPM data.

        For every engine the TTFT regression estimates the next
        time-to-first-token from the queued prefill tokens and the engine's
        ``max_num_batched_tokens``.  The estimates, converted to
        milliseconds, are passed with the TTFT SLA to
        ``_load_based_scaling_decision_from_estimates``, which (per the
        original contract) scales up when ALL engines exceed the SLA and
        scales down when ALL are below ``SLA * sensitivity``.

        Returns:
            The desired replica count, or ``None`` when the regression lacks
            data, no FPM stats are available, no prefill workers are known,
            or ``max_num_batched_tokens`` is missing or non-positive.
        """
        regression = self.ttft_regression
        if not regression.has_sufficient_data():
            logger.info(
                f"TTFT regression: insufficient data ({regression.num_observations}"
                f"/{regression.min_observations}), skipping load-based scaling"
            )
            return None

        per_engine = self._get_fpm_stats()
        if not per_engine:
            return None

        worker_count = self.shared_state.num_p_workers
        if worker_count == 0:
            return None

        batch_limit = getattr(self.prefill_worker_info, "max_num_batched_tokens", None)
        if not batch_limit or batch_limit <= 0:
            logger.warning(
                "max_num_batched_tokens not available from WorkerInfo, "
                "skipping prefill load-based scaling"
            )
            return None

        ttft_estimates_ms: list[float] = []
        for (wid, dp), fpm in per_engine.items():
            queued_prefill = fpm.queued_requests.sum_prefill_tokens
            seconds = regression.estimate_next_ttft(
                queued_prefill_tokens=queued_prefill,
                max_num_batched_tokens=batch_limit,
            )
            if seconds is not None:
                est_ms = seconds * 1000
                ttft_estimates_ms.append(est_ms)
                logger.info(
                    f"Prefill engine {wid}:dp{dp}: estimated TTFT {est_ms:.2f}ms "
                    f"(queued_prefill={queued_prefill}, avg_isl={regression.avg_isl:.1f})"
                )

        return self._load_based_scaling_decision_from_estimates(
            estimates=ttft_estimates_ms,
            sla=self.config.ttft,
            num_workers=worker_count,
            label="prefill TTFT",
        )

    def _compute_replica_requirements(
        self, next_num_req: float, next_isl: float, next_osl: float
    ) -> Optional[int]:
        """Translate predicted traffic into a prefill replica count.

        Demand (requests per second over one throughput interval) is divided
        by the best per-engine prefill rate the TTFT regression reports for
        the predicted input length.  ``next_osl`` is not used here.

        Returns:
            ``ceil(demand / engine_rate)`` floored at ``config.min_endpoint``,
            or ``None`` when the regression reports a non-positive rate.
        """
        cfg = self.config
        demand_rps = next_num_req / cfg.throughput_adjustment_interval
        engine_rps, actual_ttft_ms = self.ttft_regression.find_best_engine_prefill_rps(
            ttft_sla=cfg.ttft, isl=next_isl
        )

        if engine_rps <= 0:
            logger.warning("Prefill perf model not ready, skipping throughput scaling")
            return None

        # The SLA miss is only reported; scaling proceeds with the rate found.
        if actual_ttft_ms > cfg.ttft:
            logger.warning(
                f"Prefill TTFT SLA not met: {actual_ttft_ms:.1f}ms > "
                f"{cfg.ttft:.1f}ms, scaling with best achievable rate"
            )

        next_num_p = max(math.ceil(demand_rps / engine_rps), cfg.min_endpoint)
        logger.info(
            f"Prefill: {demand_rps:.2f}(demand rps) / "
            f"{engine_rps:.2f}(engine rps) = {next_num_p}(num_p), "
            f"est_ttft={actual_ttft_ms:.1f}ms"
        )
        return next_num_p

    def update_predicted_replicas_metric(self, desired_replicas: int) -> None:
        """Publish ``desired_replicas`` on the ``predicted_num_p`` gauge.

        Does nothing when the Prometheus port is 0 or no metrics object is set.
        """
        if self.prometheus_port == 0 or self.prometheus_metrics is None:
            return
        self.prometheus_metrics.predicted_num_p.set(desired_replicas)
components/src/dynamo/planner/core/state_machine.py
0 → 100644
View file @
f7e0b3fd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Pure discrete-event state machine for planner scaling decisions.
``PlannerStateMachine`` receives events (``ScheduledTick`` + ``TickInput``),
updates internal state (regression models, load predictors, worker inventory),
and returns effects (``PlannerEffects``: optional scaling decision + next tick).
This module contains **zero I/O** -- no runtime, connector, subscriber, asyncio,
or Prometheus dependencies. All external interaction is done by the adapter
layer (``NativePlannerBase`` and its subclasses) which feeds data in and
applies decisions out.
Load-based scaling logic lives in ``load_scaling.py``.
Throughput-based scaling logic lives in ``throughput_scaling.py``.
"""
from
__future__
import
annotations
import
logging
import
math
from
typing
import
TYPE_CHECKING
,
Optional
from
dynamo.planner.config.planner_config
import
PlannerConfig
from
dynamo.planner.core.load.predictors
import
LOAD_PREDICTORS
from
dynamo.planner.core.load_scaling
import
LoadScalingMixin
from
dynamo.planner.core.perf_model
import
(
AggRegressionModel
,
DecodeRegressionModel
,
PrefillRegressionModel
,
)
from
dynamo.planner.core.throughput_scaling
import
ThroughputScalingMixin
from
dynamo.planner.core.types
import
(
FpmObservations
,
PlannerEffects
,
ScheduledTick
,
TickInput
,
TrafficObservation
,
WorkerCapabilities
,
WorkerCounts
,
)
if
TYPE_CHECKING
:
from
dynamo.common.forward_pass_metrics
import
ForwardPassMetrics
logger
=
logging
.
getLogger
(
__name__
)
class
PlannerStateMachine
(
LoadScalingMixin
,
ThroughputScalingMixin
):
"""Discrete-event state machine for all planner modes.
Owns regression models, load predictors, throughput lower bounds,
and all scaling decision logic. Receives events, returns effects.
Has no runtime dependencies.
"""
def __init__(
    self,
    config: PlannerConfig,
    capabilities: Optional[WorkerCapabilities] = None,
) -> None:
    """Set up regression models, load predictors and scheduling state.

    Args:
        config: Planner configuration; ``config.mode`` selects which
            regression models are created (one aggregated model for
            ``"agg"``, otherwise prefill and/or decode models).
        capabilities: Worker capabilities; a default ``WorkerCapabilities``
            is used when omitted or falsy.
    """
    self._config = config
    self._capabilities = capabilities or WorkerCapabilities()

    self._is_agg = config.mode == "agg"
    self._has_prefill = config.mode in ("disagg", "prefill")
    self._has_decode = config.mode in ("disagg", "decode", "agg")

    def build(model_cls):
        # Every regression model takes the same three sizing parameters.
        return model_cls(
            max_num_fpm_samples=config.max_num_fpm_samples,
            min_observations=config.load_min_observations,
            bucket_count=config.fpm_sample_bucket_size,
        )

    if self._is_agg:
        self._agg_regression = build(AggRegressionModel)
    else:
        if self._has_prefill:
            self._prefill_regression = build(PrefillRegressionModel)
        if self._has_decode:
            self._decode_regression = build(DecodeRegressionModel)

    # One predictor per traffic dimension, all of the configured kind.
    make_predictor = LOAD_PREDICTORS[config.load_predictor]
    self._num_req_predictor = make_predictor(config)
    self._isl_predictor = make_predictor(config)
    self._osl_predictor = make_predictor(config)

    # Worker inventory and pending scaling targets.
    self._num_p_workers: int = 0
    self._num_d_workers: int = 0
    self._expected_num_p: Optional[int] = None
    self._expected_num_d: Optional[int] = None

    self._throughput_lower_bound_p: int = 1
    self._throughput_lower_bound_d: int = 1

    # Next due times; infinity until initial_tick() schedules them.
    self._next_load_s: float = float("inf")
    self._next_throughput_s: float = float("inf")
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def initial_tick(self, start_s: float) -> ScheduledTick:
    """Schedule the first tick relative to ``start_s``.

    The load cadence is always armed one ``load_adjustment_interval`` after
    ``start_s``; the throughput cadence is armed one
    ``throughput_adjustment_interval`` after it only when throughput scaling
    is enabled.

    Returns:
        The first ``ScheduledTick`` to run.
    """
    cfg = self._config
    self._next_load_s = start_s + cfg.load_adjustment_interval
    if cfg.enable_throughput_scaling:
        self._next_throughput_s = start_s + cfg.throughput_adjustment_interval
    return self._next_scheduled_tick()
def load_benchmark_fpms(
    self,
    prefill_fpms: Optional[list[ForwardPassMetrics]] = None,
    decode_fpms: Optional[list[ForwardPassMetrics]] = None,
    agg_fpms: Optional[list[ForwardPassMetrics]] = None,
) -> None:
    """Seed the regression models with benchmark FPM samples.

    In aggregated mode only ``agg_fpms`` is consumed; otherwise
    ``prefill_fpms`` / ``decode_fpms`` are loaded into the models that exist
    for the configured mode.  Empty or ``None`` lists are ignored.
    """
    if self._is_agg:
        if agg_fpms:
            self._agg_regression.load_benchmark_fpms(agg_fpms)
            logger.info(f"Bootstrapped agg regression with {len(agg_fpms)} FPMs")
        return

    if prefill_fpms and self._has_prefill:
        self._prefill_regression.load_benchmark_fpms(prefill_fpms)
        logger.info(f"Bootstrapped prefill regression with {len(prefill_fpms)} FPMs")

    if decode_fpms and self._has_decode:
        self._decode_regression.load_benchmark_fpms(decode_fpms)
        logger.info(f"Bootstrapped decode regression with {len(decode_fpms)} FPMs")
def warm_load_predictors(self, observations: list[TrafficObservation]) -> None:
    """Feed historical traffic observations into the three load predictors.

    Each observation contributes its ``num_req``, ``isl`` and ``osl`` to the
    matching predictor.  Afterwards ``reset_idle_skip()`` is called on every
    predictor that defines it.
    """
    requests, isl, osl = (
        self._num_req_predictor,
        self._isl_predictor,
        self._osl_predictor,
    )
    for sample in observations:
        requests.add_data_point(sample.num_req)
        isl.add_data_point(sample.isl)
        osl.add_data_point(sample.osl)

    logger.info(f"Warmed load predictors with {len(observations)} intervals")

    for predictor in (requests, isl, osl):
        if hasattr(predictor, "reset_idle_skip"):
            predictor.reset_idle_skip()
def on_tick(self, tick: ScheduledTick, tick_input: TickInput) -> PlannerEffects:
    """Process one scheduled tick and report the resulting effects.

    Worker inventory is refreshed first when counts are supplied.  The load
    and throughput steps each run only when the tick requests them and the
    corresponding input is present; their cadence is re-armed from
    ``tick_input.now_s`` whenever the tick requested them.  A load decision
    takes precedence: a throughput decision is used only when no load
    decision was produced on the same tick.

    Returns:
        ``PlannerEffects`` with an optional ``scale_to`` and the next tick.
    """
    effects = PlannerEffects()

    counts = tick_input.worker_counts
    if counts is not None:
        self._update_inventory(counts)

    if tick.run_load_scaling:
        fpm_obs = tick_input.fpm_observations
        if fpm_obs is not None:
            self._observe_fpm(fpm_obs)
            load_decision = self._advance_load(fpm_obs)
            if load_decision is not None:
                effects.scale_to = load_decision
        # Re-arm even without FPM data so the schedule keeps advancing.
        self._next_load_s = tick_input.now_s + self._config.load_adjustment_interval

    if tick.run_throughput_scaling:
        traffic = tick_input.traffic
        if traffic is not None:
            self._observe_traffic(traffic)
            throughput_decision = self._advance_throughput(traffic)
            if throughput_decision is not None and effects.scale_to is None:
                effects.scale_to = throughput_decision
        self._next_throughput_s = (
            tick_input.now_s + self._config.throughput_adjustment_interval
        )

    effects.next_tick = self._next_scheduled_tick()
    return effects
# ------------------------------------------------------------------
# Tick scheduling
# ------------------------------------------------------------------
_MERGE_TOLERANCE_S
=
0.5
def _next_scheduled_tick(self) -> ScheduledTick:
    """Build the single next tick, merging cadences if they coincide.

    The tick fires at the earlier of the two deadlines. A cadence is part
    of the tick when its deadline is no later than that time plus
    ``_MERGE_TOLERANCE_S``. Worker states are always requested; FPM data is
    requested for load ticks and traffic metrics (over one throughput
    interval) for throughput ticks.
    """
    load_due = self._next_load_s
    throughput_due = self._next_throughput_s

    fire_at = min(load_due, throughput_due)
    merge_cutoff = fire_at + self._MERGE_TOLERANCE_S
    runs_load = load_due <= merge_cutoff
    runs_throughput = throughput_due <= merge_cutoff

    if runs_throughput:
        traffic_window_s = self._config.throughput_adjustment_interval
    else:
        traffic_window_s = 0.0

    return ScheduledTick(
        at_s=fire_at,
        run_load_scaling=runs_load,
        run_throughput_scaling=runs_throughput,
        need_worker_states=True,
        need_worker_fpm=runs_load,
        need_traffic_metrics=runs_throughput,
        traffic_metrics_duration_s=traffic_window_s,
    )
# ------------------------------------------------------------------
# Inventory
# ------------------------------------------------------------------
def _update_inventory(self, counts: WorkerCounts) -> None:
    """Copy the adapter-reported worker counts into the planner state.

    Ready counts only overwrite the cached value when they are not ``None``.
    Expected counts are always overwritten, including with ``None``.

    Args:
        counts: Worker inventory for the current tick.
    """
    ready_prefill = counts.ready_num_prefill
    ready_decode = counts.ready_num_decode

    if ready_prefill is not None:
        self._num_p_workers = ready_prefill
    if ready_decode is not None:
        self._num_d_workers = ready_decode

    self._expected_num_p = counts.expected_num_prefill
    self._expected_num_d = counts.expected_num_decode
def _scaling_in_progress(self, component: str) -> bool:
    """Report whether ``component`` is still converging to its expected size.

    Args:
        component: ``"prefill"`` selects the prefill counters; any other
            value selects the decode counters.

    Returns:
        ``True`` when an expected count is known and differs from the ready
        count; ``False`` when they match or no expected count is known.
    """
    if component == "prefill":
        expected = self._expected_num_p
        if expected is None:
            return False
        return expected != self._num_p_workers

    expected = self._expected_num_d
    if expected is None:
        return False
    return expected != self._num_d_workers
# ------------------------------------------------------------------
# FPM / traffic observation
# ------------------------------------------------------------------
def _observe_fpm(self, obs: FpmObservations) -> None:
    """Feed per-engine forward-pass metrics into the regression models.

    In agg mode only ``obs.decode`` is read and it goes to the agg
    regression; the per-stage models are never touched. Otherwise the
    prefill and decode observations go to their own models, each only when
    that stage is present and the observation mapping is non-empty.

    Args:
        obs: FPM observations for this tick.
    """
    if self._is_agg:
        if obs.decode:
            agg_model = self._agg_regression
            for sample in obs.decode.values():
                agg_model.add_observation(sample)
            logger.info(f"FPM load stats: {len(obs.decode)} agg engines observed")
        # Agg mode never falls through to the per-stage models.
        return

    if obs.prefill and self._has_prefill:
        prefill_model = self._prefill_regression
        for sample in obs.prefill.values():
            prefill_model.add_observation(sample)
        logger.info(f"FPM load stats: {len(obs.prefill)} prefill engines observed")

    if obs.decode and self._has_decode:
        decode_model = self._decode_regression
        for sample in obs.decode.values():
            decode_model.add_observation(sample)
        logger.info(f"FPM load stats: {len(obs.decode)} decode engines observed")
def _observe_traffic(self, traffic: TrafficObservation) -> None:
    """Append one traffic window to the request-count, ISL and OSL predictors.

    Args:
        traffic: The observation whose ``num_req``, ``isl`` and ``osl``
            values are recorded, in that order.
    """
    updates = (
        (self._num_req_predictor, traffic.num_req),
        (self._isl_predictor, traffic.isl),
        (self._osl_predictor, traffic.osl),
    )
    for predictor, value in updates:
        predictor.add_data_point(value)
# ------------------------------------------------------------------
# Budget
# ------------------------------------------------------------------
def _apply_single_budget(self, desired: int, component: str) -> int:
    """Clamp a single component's replica count to the GPU budget.

    Args:
        desired: Requested replica count.
        component: ``"prefill"`` uses the prefill capabilities; any other
            value uses the decode capabilities.

    Returns:
        ``desired`` unchanged when the engine's GPU count is unknown;
        otherwise the result of ``_budget_clamp`` applied to ``desired``
        raised to at least ``min_endpoint``.
    """
    capabilities = self._capabilities
    if component == "prefill":
        caps = capabilities.prefill
    else:
        caps = capabilities.decode

    # Without a per-engine GPU count there is nothing to budget against.
    if not caps:
        return desired
    gpu = caps.num_gpu
    if gpu is None:
        return desired

    floored = max(desired, self._config.min_endpoint)
    return self._budget_clamp(floored, gpu)
def _apply_global_budget(self, num_p: int, num_d: int) -> tuple[int, int]:
    """Fit a joint (prefill, decode) replica request inside ``max_gpu_budget``.

    The request is returned unchanged when the budget is negative (treated
    as "no limit"), when either engine's GPU count is unknown, or when the
    request already fits. If the budget cannot even hold ``min_endpoint``
    replicas of each stage, ``(0, 0)`` is returned. Otherwise prefill is
    scaled down proportionally (while leaving room for ``min_endpoint``
    decode replicas) and decode receives what is left.

    Decode is capped at the requested ``num_d``: budget freed by rounding
    prefill down must not inflate decode beyond what was asked for (e.g.
    with 8-GPU prefill engines a 2P+1D request under a 16-GPU budget used
    to come back as 1P+8D).

    Args:
        num_p: Requested prefill replicas.
        num_d: Requested decode replicas.

    Returns:
        The ``(prefill, decode)`` replica counts after applying the budget.
    """
    budget = self._config.max_gpu_budget
    p_gpu = (
        self._capabilities.prefill.num_gpu if self._capabilities.prefill else None
    )
    d_gpu = self._capabilities.decode.num_gpu if self._capabilities.decode else None
    if budget < 0 or p_gpu is None or d_gpu is None:
        return num_p, num_d

    total = num_p * p_gpu + num_d * d_gpu
    if total <= budget:
        return num_p, num_d

    min_endpoint = self._config.min_endpoint
    min_req = min_endpoint * p_gpu + min_endpoint * d_gpu
    if budget < min_req:
        logger.warning(
            f"max_gpu_budget ({budget}) below min ({min_req}); zero replicas"
        )
        return 0, 0

    scale = budget / total
    # Largest prefill count that still leaves room for the decode minimum.
    max_p = math.floor((budget - min_endpoint * d_gpu) / p_gpu)
    num_p = max(min_endpoint, min(max_p, math.floor(num_p * scale)))
    remaining = budget - num_p * p_gpu
    # Cap at the requested decode count so leftover budget is not spent on
    # replicas nobody asked for.
    num_d = max(min_endpoint, min(num_d, math.floor(remaining / d_gpu)))
    logger.warning(f"GPUs ({total}) > budget ({budget}), -> {num_p}P + {num_d}D")
    return num_p, num_d
def _budget_clamp(self, desired: int, engine_gpu: int) -> int:
    """Limit a replica count so ``desired * engine_gpu`` fits ``max_gpu_budget``.

    Args:
        desired: Requested replica count.
        engine_gpu: GPUs consumed by one replica.

    Returns:
        ``desired`` when the budget is negative or the request fits; ``0``
        when the budget cannot hold ``min_endpoint`` replicas; otherwise
        the largest count that fits, but never less than ``min_endpoint``.
    """
    cfg = self._config
    budget = cfg.max_gpu_budget
    if budget < 0:
        return desired

    total = desired * engine_gpu
    if total > budget:
        min_req = cfg.min_endpoint * engine_gpu
        if budget < min_req:
            logger.warning(
                f"max_gpu_budget ({budget}) below min ({min_req}); zero replicas"
            )
            return 0
        result = max(cfg.min_endpoint, math.floor(budget / engine_gpu))
        logger.warning(f"GPUs ({total}) > budget ({budget}), -> {result} replicas")
        return result

    return desired
# ------------------------------------------------------------------
# FPM / worker count reconciliation
# ------------------------------------------------------------------
@staticmethod
def _reconcile_fpm_worker_count(
    fpm_stats: dict[tuple[str, int], ForwardPassMetrics], dgd_count: int, label: str
) -> bool:
    """Check that the FPM keys describe exactly the expected set of workers.

    Keys of ``fpm_stats`` are ``(worker_id, dp_rank)`` pairs. The check
    fails (with a warning) when the number of distinct worker ids differs
    from ``dgd_count``, when workers report different numbers of DP ranks,
    or when the entry count is not ``dgd_count`` times the DP size.

    Args:
        fpm_stats: Per-engine metrics keyed by ``(worker_id, dp_rank)``.
        dgd_count: Worker count the FPM data is compared against.
        label: Name used in the warning messages.

    Returns:
        ``True`` when the FPM data is consistent with ``dgd_count``.
    """
    # Group the reported DP ranks by worker id.
    workers_to_dp: dict[str, set[int]] = {}
    for key in fpm_stats:
        wid, dp = key
        if wid not in workers_to_dp:
            workers_to_dp[wid] = set()
        workers_to_dp[wid].add(dp)

    if len(workers_to_dp) != dgd_count:
        logger.warning(
            f"Worker count mismatch: DGD={dgd_count}, FPM={len(workers_to_dp)} for {label}"
        )
        return False

    dp_sizes = set(map(len, workers_to_dp.values()))
    if len(dp_sizes) > 1:
        logger.warning(f"Inconsistent DP ranks for {label}: {dict(workers_to_dp)}")
        return False

    # No workers at all counts as a DP size of 1.
    if dp_sizes:
        dp_size = dp_sizes.pop()
    else:
        dp_size = 1

    if len(fpm_stats) != dgd_count * dp_size:
        logger.warning(
            f"Incomplete FPM coverage for {label}: expected {dgd_count}x{dp_size}, got {len(fpm_stats)}"
        )
        return False
    return True
# ------------------------------------------------------------------
# Accessors
# ------------------------------------------------------------------
@property
def prefill_regression(self) -> PrefillRegressionModel:
    """The prefill regression model.

    Raises:
        AttributeError: If the planner has no prefill stage.
    """
    if self._has_prefill:
        return self._prefill_regression
    raise AttributeError(f"No prefill regression in mode={self._config.mode}")
@property
def decode_regression(self) -> DecodeRegressionModel:
    """The decode regression model.

    Raises:
        AttributeError: If the planner has no decode stage or runs in agg
            mode.
    """
    if self._has_decode and not self._is_agg:
        return self._decode_regression
    raise AttributeError(f"No decode regression in mode={self._config.mode}")
@property
def agg_regression(self) -> AggRegressionModel:
    """The aggregated regression model.

    Raises:
        AttributeError: If the planner is not in agg mode.
    """
    if self._is_agg:
        return self._agg_regression
    raise AttributeError(f"No agg regression in mode={self._config.mode}")
@property
def regression(self) -> AggRegressionModel:
    """Alias for ``agg_regression``; raises whatever that property raises."""
    model = self.agg_regression
    return model
components/src/dynamo/planner/core/throughput_scaling.py
0 → 100644
View file @
f7e0b3fd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# mypy: disable-error-code="attr-defined"
"""Throughput-based scaling logic (Prometheus traffic-driven, predictive).
Mixin consumed by ``PlannerStateMachine``. All methods access state
via ``self._config``, ``self._capabilities``, and regression models.
"""
from
__future__
import
annotations
import
logging
import
math
from
typing
import
Optional
from
dynamo.planner.core.types
import
ScalingDecision
,
TrafficObservation
logger
=
logging
.
getLogger
(
__name__
)
class ThroughputScalingMixin:
    """Traffic-driven throughput-based scaling decisions.

    The host class supplies ``_config``, ``_capabilities``, the three load
    predictors, the regression models and the budget helpers. When load
    scaling is also enabled, the methods here only record
    ``_throughput_lower_bound_*`` values and return no decision.
    """

    def _advance_throughput(
        self, traffic: TrafficObservation
    ) -> Optional[ScalingDecision]:
        """Turn the predicted load into a scaling decision for the current mode.

        Returns ``None`` when throughput scaling is disabled, when any
        prediction is missing, when the observation window has a
        non-positive duration, or when the mode-specific handler yields no
        decision.
        """
        cfg = self._config
        if not cfg.enable_throughput_scaling:
            return None

        next_num_req, next_isl, next_osl = self._predict_load()
        if None in (next_num_req, next_isl, next_osl):
            return None

        window_s = traffic.duration_s
        if window_s <= 0:
            logger.warning("Traffic observation has non-positive duration, skipping")
            return None

        demand_rps = next_num_req / window_s
        mode = cfg.mode
        if mode == "agg":
            handler = self._throughput_agg
        elif mode == "disagg":
            handler = self._throughput_disagg
        else:
            # Single-stage modes: the mode name doubles as the component.
            return self._throughput_single(demand_rps, next_isl, next_osl, mode)
        return handler(demand_rps, next_isl, next_osl)

    def _predict_load(
        self,
    ) -> tuple[Optional[float], Optional[float], Optional[float]]:
        """Query the three predictors; return ``(None, None, None)`` on any error.

        Formatting the log line is inside the guarded section, so a
        predictor returning a value that cannot be formatted as a float is
        also reported as a failure.
        """
        try:
            nr = self._num_req_predictor.predict_next()
            isl = self._isl_predictor.predict_next()
            osl = self._osl_predictor.predict_next()
            logger.info(
                f"Predicted load: num_req={nr:.2f}, isl={isl:.2f}, osl={osl:.2f}"
            )
        except Exception as e:
            logger.error(f"Failed to predict load: {e}")
            return None, None, None
        else:
            return nr, isl, osl

    def _throughput_single(
        self, demand_rps: float, isl: float, osl: float, component: str
    ) -> Optional[ScalingDecision]:
        """Size a single stage (``"prefill"``, otherwise decode) for the demand."""
        for_prefill = component == "prefill"
        if for_prefill:
            desired = self._compute_prefill_replicas(demand_rps, isl, osl)
        else:
            desired = self._compute_decode_replicas(demand_rps, isl, osl)
        if desired is None:
            return None

        if self._config.enable_load_scaling:
            # Load scaling owns the decision; only record the floor.
            if for_prefill:
                self._throughput_lower_bound_p = desired
            else:
                self._throughput_lower_bound_d = desired
            logger.info(f"Throughput lower bound set to {desired} for {component}")
            return None

        desired = self._apply_single_budget(desired, component)
        if for_prefill:
            return ScalingDecision(num_prefill=desired)
        return ScalingDecision(num_decode=desired)

    def _throughput_disagg(
        self, demand_rps: float, isl: float, osl: float
    ) -> Optional[ScalingDecision]:
        """Size prefill and decode together for the demand."""
        num_p = self._compute_prefill_replicas(demand_rps, isl, osl)
        num_d = self._compute_decode_replicas(demand_rps, isl, osl)
        if None in (num_p, num_d):
            return None

        if not self._config.enable_load_scaling:
            num_p, num_d = self._apply_global_budget(num_p, num_d)
            return ScalingDecision(num_prefill=num_p, num_decode=num_d)

        # Load scaling owns the decision; only record the floors.
        self._throughput_lower_bound_p = num_p
        self._throughput_lower_bound_d = num_d
        logger.info(f"Throughput lower bounds set: prefill={num_p}, decode={num_d}")
        return None

    def _throughput_agg(
        self, demand_rps: float, isl: float, osl: float
    ) -> Optional[ScalingDecision]:
        """Size an aggregated deployment for the demand.

        Requires ``max_num_batched_tokens`` from the decode capabilities;
        without a positive value no decision is made.
        """
        cfg = self._config
        d_caps = self._capabilities.decode
        max_tokens = d_caps.max_num_batched_tokens if d_caps else None
        if not max_tokens or max_tokens <= 0:
            logger.warning(
                "max_num_batched_tokens not available, skipping agg throughput"
            )
            return None

        best = self._agg_regression.find_best_engine_agg_rps(
            isl=isl,
            osl=osl,
            max_num_batched_tokens=max_tokens,
            ttft_sla=cfg.ttft,
            itl_sla=cfg.itl,
        )
        engine_rps, actual_ttft, actual_itl = best
        if engine_rps <= 0:
            logger.warning("Agg perf model not ready, skipping throughput scaling")
            return None

        # An SLA miss is only reported; sizing still proceeds.
        if actual_ttft > cfg.ttft or actual_itl > cfg.itl:
            logger.warning(
                f"Agg SLA not fully met: TTFT={actual_ttft:.1f}ms, ITL={actual_itl:.1f}ms"
            )

        desired = max(math.ceil(demand_rps / engine_rps), cfg.min_endpoint)
        logger.info(
            f"Agg: {demand_rps:.2f} rps / {engine_rps:.2f} engine_rps = {desired} replicas"
        )

        if cfg.enable_load_scaling:
            self._throughput_lower_bound_d = desired
            logger.info(f"Agg throughput lower bound set to {desired}")
            return None

        desired = self._apply_single_budget(desired, "decode")
        return ScalingDecision(num_decode=desired)

    def _compute_prefill_replicas(
        self, demand_rps: float, isl: float, osl: float
    ) -> Optional[int]:
        """Return the prefill replicas needed for ``demand_rps``, or ``None``.

        ``None`` means the prefill regression reported a non-positive
        per-engine rate. ``osl`` is accepted but not used here.
        """
        cfg = self._config
        engine_rps, ttft_ms = self._prefill_regression.find_best_engine_prefill_rps(
            ttft_sla=cfg.ttft, isl=isl
        )
        if engine_rps <= 0:
            logger.warning("Prefill perf model not ready, skipping throughput scaling")
            return None

        if ttft_ms > cfg.ttft:
            logger.warning(
                f"Prefill TTFT SLA not met: {ttft_ms:.1f}ms > {self._config.ttft:.1f}ms"
            )

        needed = math.ceil(demand_rps / engine_rps)
        result = max(needed, cfg.min_endpoint)
        logger.info(
            f"Prefill: {demand_rps:.2f} rps / {engine_rps:.2f} = {result}, est_ttft={ttft_ms:.1f}ms"
        )
        return result

    def _compute_decode_replicas(
        self, demand_rps: float, isl: float, osl: float
    ) -> Optional[int]:
        """Return the decode replicas needed for ``demand_rps``, or ``None``.

        ``None`` means the decode regression reported a non-positive
        per-engine rate. The context length passed to the model is
        ``isl + osl / 2``.
        """
        cfg = self._config
        engine_rps, itl_ms = self._decode_regression.find_best_engine_decode_rps(
            itl=cfg.itl,
            context_length=isl + osl / 2,
            osl=osl,
        )
        if engine_rps <= 0:
            logger.warning("Decode perf model not ready, skipping throughput scaling")
            return None

        if itl_ms > cfg.itl:
            logger.warning(
                f"Decode ITL SLA not met: {itl_ms:.1f}ms > {self._config.itl:.1f}ms"
            )

        needed = math.ceil(demand_rps / engine_rps)
        result = max(needed, cfg.min_endpoint)
        logger.info(
            f"Decode: {demand_rps:.2f} rps / {engine_rps:.2f} = {result}, est_itl={itl_ms:.1f}ms"
        )
        return result
components/src/dynamo/planner/core/types.py
0 → 100644
View file @
f7e0b3fd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Explicit-input types for the planner core.
These types form the boundary between the planner core (pure decision logic)
and any adapter (native runtime, replay harness, tests). The core receives
``TickInput`` and returns ``PlannerEffects``; the adapter fills the input
based on the previous tick's ``ScheduledTick`` requirements.
"""
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Optional
if
TYPE_CHECKING
:
from
dynamo.common.forward_pass_metrics
import
ForwardPassMetrics
@dataclass
class ScheduledTick:
    """Declares when the core next needs to be called, what data it needs,
    and what decisions to make.

    All times are absolute seconds (wall clock for native adapter,
    simulated clock for replay).
    """

    # Absolute time (seconds) at which the tick is due.
    at_s: float

    # What decisions the core will make on this tick
    run_load_scaling: bool = False
    run_throughput_scaling: bool = False

    # What data the adapter should collect before calling on_tick
    need_traffic_metrics: bool = False
    # Length (seconds) of the traffic window to collect; 0.0 when no traffic
    # metrics are needed.
    traffic_metrics_duration_s: float = 0.0
    need_worker_states: bool = False
    need_worker_fpm: bool = False
@dataclass
class TrafficObservation:
    """Aggregated traffic metrics over an observation window."""

    # Length of the observation window in seconds.
    duration_s: float
    # Request count for the window — presumably a total rather than a rate;
    # TODO confirm against the adapter that fills this in.
    num_req: float
    # NOTE(review): presumably the average input sequence length — confirm.
    isl: float
    # NOTE(review): presumably the average output sequence length — confirm.
    osl: float
@dataclass
class WorkerCounts:
    """Current worker inventory as reported by the adapter.

    Every field is optional; ``None`` means the adapter did not report
    that value.
    """

    # Workers currently ready, per stage.
    ready_num_prefill: Optional[int] = None
    ready_num_decode: Optional[int] = None
    # Workers expected, per stage — presumably the target replica counts of
    # the deployment; TODO confirm with the adapter.
    expected_num_prefill: Optional[int] = None
    expected_num_decode: Optional[int] = None
@dataclass
class FpmObservations:
    """Per-engine ForwardPassMetrics keyed by (worker_id, dp_rank)."""

    # Metrics from prefill engines; None when not collected.
    prefill: Optional[dict[tuple[str, int], ForwardPassMetrics]] = None
    # Metrics from decode engines; None when not collected.
    decode: Optional[dict[tuple[str, int], ForwardPassMetrics]] = None
@dataclass
class TickInput:
    """What the adapter provides to the core on each tick.

    Fields are filled according to the previous ``ScheduledTick``'s
    declared requirements.
    """

    # Current time in seconds, on the same clock as ``ScheduledTick.at_s``.
    now_s: float
    # Optional inputs; None when the adapter did not supply that data.
    traffic: Optional[TrafficObservation] = None
    worker_counts: Optional[WorkerCounts] = None
    fpm_observations: Optional[FpmObservations] = None
@dataclass
class ScalingDecision:
    """Desired replica counts.

    ``None`` means the core has no opinion on that component (e.g. a
    prefill-only planner leaves decode as None).
    """

    # Desired number of prefill replicas, or None for "no opinion".
    num_prefill: Optional[int] = None
    # Desired number of decode replicas, or None for "no opinion".
    num_decode: Optional[int] = None
@dataclass
class PlannerEffects:
    """What the core returns after processing a tick."""

    # Replica counts to scale to; None when no scaling decision was made.
    scale_to: Optional[ScalingDecision] = None
    # The next tick the core wants to be called for.
    next_tick: Optional[ScheduledTick] = None
@dataclass
class EngineCapabilities:
    """Static capabilities for a single engine stage (prefill or decode).

    Every field is optional; ``None`` means the value is unknown.
    """

    # GPUs used by one engine replica.
    num_gpu: Optional[int] = None
    # NOTE(review): the next three presumably mirror the engine's own launch
    # arguments of the same names — confirm.
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: Optional[int] = None
    context_length: Optional[int] = None
@dataclass
class WorkerCapabilities:
    """Static per-engine capabilities discovered at startup from MDC.

    Provided once when constructing the planner core. In native mode
    these come from ``WorkerInfo`` (resolved via MDC / DGD); in replay
    they come from the simulated engine args.

    For agg mode, only ``decode`` is populated (single engine type).
    """

    # Capabilities of the prefill stage; None when not populated.
    prefill: Optional[EngineCapabilities] = None
    # Capabilities of the decode stage (also the single engine in agg mode).
    decode: Optional[EngineCapabilities] = None
components/src/dynamo/planner/tests/unit/test_load_based_scaling.py
View file @
f7e0b3fd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
os
from
unittest.mock
import
Mock
,
patch
"""Regression model unit tests.
These test the perf_model classes directly (PrefillRegressionModel,
DecodeRegressionModel, AggRegressionModel) without any planner adapter.
FPM-driven scaling integration tests live in test_state_machine.py.
"""
import
pytest
...
...
@@ -15,18 +20,12 @@ from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics
,
QueuedRequestMetrics
,
ScheduledRequestMetrics
,
encode
,
)
from
dynamo.planner.config.planner_config
import
PlannerConfig
from
dynamo.planner.core.decode
import
DecodePlanner
from
dynamo.planner.core.perf_model
import
(
AggRegressionModel
,
DecodeRegressionModel
,
PrefillRegressionModel
,
)
from
dynamo.planner.core.prefill
import
PrefillPlanner
from
dynamo.planner.core.state
import
PlannerSharedState
from
dynamo.planner.monitoring.worker_info
import
WorkerInfo
pytestmark
=
[
pytest
.
mark
.
gpu_0
,
...
...
@@ -149,7 +148,6 @@ class TestPrefillRegressionModel:
ttft_sla
=
2000.0
,
isl
=
1000.0
)
assert
rps
>
0
# wall_time ~1.002s for 1000 tokens -> rps ~ 1/1.002 ~ 0.998
assert
0.5
<
rps
<
2.0
assert
actual_ttft_ms
>
0
assert
1000
<
actual_ttft_ms
<
2000
...
...
@@ -190,7 +188,6 @@ class TestPrefillRegressionModel:
class
TestBucketedRetirement
:
def
test_total_capped_at_max
(
self
):
"""Total observations never exceed max_num_fpm_samples."""
model
=
PrefillRegressionModel
(
max_num_fpm_samples
=
10
,
min_observations
=
3
,
bucket_count
=
4
)
...
...
@@ -204,12 +201,9 @@ class TestBucketedRetirement:
assert
model
.
num_observations
==
10
def
test_most_populated_bucket_loses_oldest
(
self
):
"""When evicting, the oldest entry from the most-populated bucket is removed."""
model
=
PrefillRegressionModel
(
max_num_fpm_samples
=
6
,
min_observations
=
1
,
bucket_count
=
4
)
# 3 observations at low tokens (bucket 0 area)
for
i
in
range
(
3
):
fpm
=
_make_fpm
(
sum_prefill_tokens
=
10
+
i
,
...
...
@@ -217,8 +211,6 @@ class TestBucketedRetirement:
wall_time
=
0.001
*
(
10
+
i
),
)
model
.
add_observation
(
fpm
)
# 3 observations at high tokens (different bucket)
for
i
in
range
(
3
):
fpm
=
_make_fpm
(
sum_prefill_tokens
=
1000
+
i
*
100
,
...
...
@@ -226,51 +218,31 @@ class TestBucketedRetirement:
wall_time
=
0.001
*
(
1000
+
i
*
100
),
)
model
.
add_observation
(
fpm
)
assert
model
.
num_observations
==
6
# One more at low tokens; total would exceed 6 so most-populated
# bucket loses its oldest entry.
fpm
=
_make_fpm
(
sum_prefill_tokens
=
15
,
num_prefill_requests
=
1
,
wall_time
=
0.015
,
)
fpm
=
_make_fpm
(
sum_prefill_tokens
=
15
,
num_prefill_requests
=
1
,
wall_time
=
0.015
)
model
.
add_observation
(
fpm
)
assert
model
.
num_observations
==
6
def
test_uniform_distribution_preserved
(
self
):
"""Bucketed eviction keeps observations across operating points."""
model
=
DecodeRegressionModel
(
max_num_fpm_samples
=
10
,
min_observations
=
3
,
bucket_count
=
16
)
# Many observations at a single operating point
for
_
in
range
(
15
):
fpm
=
_make_fpm
(
num_decode_requests
=
32
,
sum_decode_kv_tokens
=
32000
,
wall_time
=
0.01
,
num_decode_requests
=
32
,
sum_decode_kv_tokens
=
32000
,
wall_time
=
0.01
)
model
.
add_observation
(
fpm
)
assert
model
.
num_observations
==
10
# Add a different operating point; the concentrated bucket loses one
fpm
=
_make_fpm
(
num_decode_requests
=
4
,
sum_decode_kv_tokens
=
4000
,
wall_time
=
0.005
,
num_decode_requests
=
4
,
sum_decode_kv_tokens
=
4000
,
wall_time
=
0.005
)
model
.
add_observation
(
fpm
)
assert
model
.
num_observations
==
10
def
test_2d_bucketed_retirement
(
self
):
"""2D models retire from the most-populated grid cell."""
model
=
AggRegressionModel
(
max_num_fpm_samples
=
8
,
min_observations
=
1
,
bucket_count
=
16
)
# Fill with varied data
for
p
,
d
in
[(
100
,
500
),
(
200
,
1000
),
(
300
,
1500
),
(
400
,
2000
)]:
fpm
=
_make_fpm
(
sum_prefill_tokens
=
p
,
...
...
@@ -280,8 +252,6 @@ class TestBucketedRetirement:
wall_time
=
0.001
*
p
+
0.0001
*
d
,
)
model
.
add_observation
(
fpm
)
# Concentrate 4 more in one region
for
_
in
range
(
4
):
fpm
=
_make_fpm
(
sum_prefill_tokens
=
100
,
...
...
@@ -291,10 +261,7 @@ class TestBucketedRetirement:
wall_time
=
0.15
,
)
model
.
add_observation
(
fpm
)
assert
model
.
num_observations
==
8
# Overflow triggers retirement from the concentrated cell
fpm
=
_make_fpm
(
sum_prefill_tokens
=
350
,
num_prefill_requests
=
1
,
...
...
@@ -310,15 +277,8 @@ class TestBucketedRetirement:
class
TestDecodeRegressionModel
:
def
_train_2d
(
self
,
model
:
DecodeRegressionModel
)
->
None
:
"""Populate with 2D data: wall_time = f(num_decode_requests, sum_decode_kv_tokens)."""
for
n_req
,
kv
in
[
(
5
,
1000
),
(
10
,
2000
),
(
15
,
3000
),
(
20
,
4000
),
(
25
,
5000
),
]:
def
_train_2d
(
self
,
model
):
for
n_req
,
kv
in
[(
5
,
1000
),
(
10
,
2000
),
(
15
,
3000
),
(
20
,
4000
),
(
25
,
5000
)]:
fpm
=
_make_fpm
(
sum_decode_kv_tokens
=
kv
,
num_decode_requests
=
n_req
,
...
...
@@ -346,11 +306,9 @@ class TestDecodeRegressionModel:
max_num_fpm_samples
=
50
,
min_observations
=
3
,
bucket_count
=
16
)
self
.
_train_2d
(
model
)
assert
model
.
has_sufficient_data
()
est
=
model
.
estimate_next_itl
(
scheduled_decode_kv
=
3000
,
queued_decode_kv
=
0
)
assert
est
is
not
None
assert
est
>
0
assert
est
is
not
None
and
est
>
0
def
test_avg_decode_length_tracking
(
self
):
model
=
DecodeRegressionModel
(
...
...
@@ -365,8 +323,7 @@ class TestDecodeRegressionModel:
model
.
add_observation
(
fpm
)
assert
abs
(
model
.
avg_decode_length
-
200.0
)
<
1.0
def
_train_thpt_model
(
self
,
model
:
DecodeRegressionModel
)
->
None
:
"""Populate with 2D data at decode-realistic wall-time scale."""
def
_train_thpt_model
(
self
,
model
):
for
n_req
,
kv
in
[
(
5
,
5000
),
(
10
,
10000
),
...
...
@@ -386,13 +343,10 @@ class TestDecodeRegressionModel:
max_num_fpm_samples
=
50
,
min_observations
=
3
,
bucket_count
=
16
)
self
.
_train_thpt_model
(
model
)
rps
,
actual_itl
=
model
.
find_best_engine_decode_rps
(
itl
=
50.0
,
context_length
=
1000.0
,
osl
=
150.0
)
assert
rps
>
0
assert
actual_itl
>
0
assert
actual_itl
<=
50.0
assert
rps
>
0
and
actual_itl
>
0
and
actual_itl
<=
50.0
def
test_find_best_engine_decode_rps_zero_context
(
self
):
model
=
DecodeRegressionModel
(
...
...
@@ -402,8 +356,7 @@ class TestDecodeRegressionModel:
rps
,
itl_ms
=
model
.
find_best_engine_decode_rps
(
itl
=
50.0
,
context_length
=
0.0
,
osl
=
150.0
)
assert
rps
==
0.0
assert
itl_ms
==
0.0
assert
rps
==
0.0
and
itl_ms
==
0.0
def
test_load_benchmark_fpms
(
self
):
model
=
DecodeRegressionModel
(
...
...
@@ -418,15 +371,14 @@ class TestDecodeRegressionModel:
for
n
in
[
5
,
10
,
15
,
20
,
25
]
]
model
.
load_benchmark_fpms
(
fpms
)
assert
model
.
num_observations
==
5
assert
model
.
has_sufficient_data
()
assert
model
.
num_observations
==
5
and
model
.
has_sufficient_data
()
# ── AggRegressionModel tests ─────────────────────────────────────────
class
TestAggRegressionModel
:
def
_train_agg
(
self
,
model
:
AggRegressionModel
)
->
None
:
def
_train_agg
(
self
,
model
)
:
for
p
,
d
in
[(
100
,
1000
),
(
200
,
2000
),
(
300
,
3000
),
(
400
,
4000
),
(
500
,
5000
)]:
fpm
=
_make_fpm
(
sum_prefill_tokens
=
p
,
...
...
@@ -458,27 +410,19 @@ class TestAggRegressionModel:
max_num_fpm_samples
=
50
,
min_observations
=
3
,
bucket_count
=
16
)
self
.
_train_agg
(
model
)
assert
model
.
has_sufficient_data
()
ttft
=
model
.
estimate_next_ttft
(
queued_prefill_tokens
=
0
,
max_num_batched_tokens
=
2048
,
current_decode_kv
=
3000
,
queued_prefill_tokens
=
0
,
max_num_batched_tokens
=
2048
,
current_decode_kv
=
3000
)
assert
ttft
is
not
None
assert
ttft
>
0
assert
ttft
is
not
None
and
ttft
>
0
itl
=
model
.
estimate_next_itl
(
scheduled_decode_kv
=
3000
,
queued_decode_kv
=
0
)
assert
itl
is
not
None
assert
itl
>
0
assert
itl
is
not
None
and
itl
>
0
def
test_find_best_engine_agg_rps
(
self
):
model
=
AggRegressionModel
(
max_num_fpm_samples
=
50
,
min_observations
=
3
,
bucket_count
=
16
)
self
.
_train_agg
(
model
)
thpt
,
actual_ttft
,
actual_itl
=
model
.
find_best_engine_agg_rps
(
isl
=
2048.0
,
osl
=
150.0
,
...
...
@@ -486,10 +430,7 @@ class TestAggRegressionModel:
ttft_sla
=
500.0
,
itl_sla
=
50.0
,
)
assert
isinstance
(
thpt
,
float
)
assert
thpt
>
0
assert
actual_ttft
>=
0
assert
actual_itl
>=
0
assert
thpt
>
0
and
actual_ttft
>=
0
and
actual_itl
>=
0
def
test_find_best_engine_agg_rps_insufficient_data
(
self
):
model
=
AggRegressionModel
(
...
...
@@ -503,221 +444,3 @@ class TestAggRegressionModel:
itl_sla
=
50.0
,
)
assert
thpt
==
0.0
# ── Planner integration tests (with mocked FPM subscriber) ──────────
@
pytest
.
fixture
(
autouse
=
True
)
def
mock_prometheus_metrics
():
with
patch
(
"dynamo.planner.monitoring.planner_metrics.Gauge"
)
as
mock_gauge
:
mock_gauge
.
return_value
=
Mock
()
yield
def
_build_load_config
(
**
overrides
)
->
PlannerConfig
:
defaults
=
dict
(
throughput_adjustment_interval
=
60
,
prefill_engine_num_gpu
=
1
,
decode_engine_num_gpu
=
1
,
min_endpoint
=
1
,
max_gpu_budget
=-
1
,
ttft
=
500.0
,
itl
=
50.0
,
backend
=
"vllm"
,
no_operation
=
True
,
metric_pulling_prometheus_endpoint
=
"http://localhost:9090"
,
metric_reporting_prometheus_port
=
0
,
load_predictor
=
"constant"
,
profile_results_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
,
"data"
,
"profiling_results"
,
"H200_TP1P_TP1D"
,
),
environment
=
"kubernetes"
,
namespace
=
"test-namespace"
,
mode
=
"disagg"
,
enable_load_scaling
=
True
,
enable_throughput_scaling
=
True
,
load_adjustment_interval
=
5
,
max_num_fpm_samples
=
50
,
fpm_sample_bucket_size
=
16
,
load_scaling_down_sensitivity
=
80
,
load_metric_samples
=
10
,
load_min_observations
=
5
,
)
defaults
.
update
(
overrides
)
return
PlannerConfig
.
model_construct
(
**
defaults
)
def
_mock_fpm_subscriber
(
fpm_stats
:
dict
[
tuple
[
str
,
int
],
ForwardPassMetrics
]):
"""Create a mock FPM subscriber that returns encoded FPM stats."""
mock
=
Mock
()
encoded
=
{
k
:
encode
(
v
)
for
k
,
v
in
fpm_stats
.
items
()}
mock
.
get_recent_stats
.
return_value
=
encoded
return
mock
class
TestPrefillFpmScaling
:
def
test_scale_up_all_engines_above_sla
(
self
):
"""All engines have high queued prefill -> estimated TTFT > SLA -> scale up."""
config
=
_build_load_config
(
ttft
=
5.0
)
# 5ms SLA (easy to exceed)
shared_state
=
PlannerSharedState
()
shared_state
.
num_p_workers
=
2
planner
=
PrefillPlanner
(
None
,
config
,
shared_state
=
shared_state
)
planner
.
model_name
=
"test-model"
planner
.
prefill_worker_info
=
WorkerInfo
(
max_num_batched_tokens
=
2048
)
for
tokens
in
range
(
200
,
1200
,
100
):
fpm
=
_make_fpm
(
sum_prefill_tokens
=
tokens
,
num_prefill_requests
=
1
,
wall_time
=
0.001
*
tokens
,
)
planner
.
ttft_regression
.
add_observation
(
fpm
)
stats
=
{
(
"w1"
,
0
):
_make_fpm
(
worker_id
=
"w1"
,
queued_prefill_tokens
=
10000
,
sum_prefill_tokens
=
500
,
num_prefill_requests
=
1
,
wall_time
=
0.5
,
),
(
"w2"
,
0
):
_make_fpm
(
worker_id
=
"w2"
,
queued_prefill_tokens
=
8000
,
sum_prefill_tokens
=
600
,
num_prefill_requests
=
1
,
wall_time
=
0.6
,
),
}
planner
.
fpm_subscriber
=
_mock_fpm_subscriber
(
stats
)
result
=
planner
.
load_plan_adjustment
()
assert
result
==
3
def
test_scale_down_all_engines_below_sla
(
self
):
"""All engines have low queued prefill -> estimated TTFT < SLA * sensitivity."""
config
=
_build_load_config
(
ttft
=
500.0
,
load_scaling_down_sensitivity
=
100
)
shared_state
=
PlannerSharedState
()
shared_state
.
num_p_workers
=
3
planner
=
PrefillPlanner
(
None
,
config
,
shared_state
=
shared_state
)
planner
.
model_name
=
"test-model"
planner
.
prefill_worker_info
=
WorkerInfo
(
max_num_batched_tokens
=
2048
)
for
tokens
in
range
(
100
,
600
,
50
):
fpm
=
_make_fpm
(
sum_prefill_tokens
=
tokens
,
num_prefill_requests
=
1
,
wall_time
=
0.001
*
tokens
,
)
planner
.
ttft_regression
.
add_observation
(
fpm
)
stats
=
{
(
f
"w
{
i
}
"
,
0
):
_make_fpm
(
worker_id
=
f
"w
{
i
}
"
,
queued_prefill_tokens
=
0
,
sum_prefill_tokens
=
100
,
num_prefill_requests
=
1
,
wall_time
=
0.1
,
)
for
i
in
range
(
3
)
}
planner
.
fpm_subscriber
=
_mock_fpm_subscriber
(
stats
)
result
=
planner
.
load_plan_adjustment
()
assert
result
==
2
def
test_cold_start_returns_none
(
self
):
config
=
_build_load_config
()
shared_state
=
PlannerSharedState
()
shared_state
.
num_p_workers
=
2
planner
=
PrefillPlanner
(
None
,
config
,
shared_state
=
shared_state
)
planner
.
model_name
=
"test-model"
planner
.
prefill_worker_info
=
WorkerInfo
(
max_num_batched_tokens
=
2048
)
for
tokens
in
[
100
,
200
]:
fpm
=
_make_fpm
(
sum_prefill_tokens
=
tokens
,
wall_time
=
0.01
)
planner
.
ttft_regression
.
add_observation
(
fpm
)
stats
=
{(
"w1"
,
0
):
_make_fpm
(
queued_prefill_tokens
=
5000
,
wall_time
=
0.5
)}
planner
.
fpm_subscriber
=
_mock_fpm_subscriber
(
stats
)
result
=
planner
.
load_plan_adjustment
()
assert
result
is
None
class TestDecodeFpmScaling:
    """Exercise ``DecodePlanner.load_plan_adjustment`` against mocked FPM stats."""

    @staticmethod
    def _decode_planner(config):
        """Return a runtime-less DecodePlanner whose shared state reports two decode workers."""
        state = PlannerSharedState()
        state.num_d_workers = 2
        decode = DecodePlanner(None, config, shared_state=state)
        decode.model_name = "test-model"
        return decode

    def test_scale_up_all_engines_above_sla(self):
        """All engines have high decode load -> estimated ITL > SLA -> scale up."""
        planner = self._decode_planner(_build_load_config(itl=5.0))  # 5ms SLA

        # 2D regression: vary both num_decode_requests and sum_decode_kv_tokens
        training_points = ((5, 1000), (10, 2000), (15, 3000), (20, 4000), (25, 5000))
        for n_req, kv in training_points:
            planner.itl_regression.add_observation(
                _make_fpm(
                    sum_decode_kv_tokens=kv,
                    num_decode_requests=n_req,
                    wall_time=0.0001 * kv + 0.0005 * n_req + 0.001,
                )
            )

        # One live snapshot per worker; each entry is keyed by (worker id, 0).
        live_load = (
            ("w1", 5000, 3000, 20, 0.6),
            ("w2", 4500, 2500, 18, 0.55),
        )
        stats = {
            (worker, 0): _make_fpm(
                worker_id=worker,
                sum_decode_kv_tokens=running_kv,
                queued_decode_kv_tokens=queued_kv,
                num_decode_requests=n_req,
                wall_time=wall,
            )
            for worker, running_kv, queued_kv, n_req, wall in live_load
        }
        planner.fpm_subscriber = _mock_fpm_subscriber(stats)

        # Starting from two workers, the planner is expected to ask for a third.
        assert planner.load_plan_adjustment() == 3

    def test_cold_start_returns_none(self):
        """With a single regression observation recorded, no adjustment is returned."""
        planner = self._decode_planner(_build_load_config())

        planner.itl_regression.add_observation(
            _make_fpm(sum_decode_kv_tokens=1000, num_decode_requests=5, wall_time=0.01)
        )

        snapshot = _make_fpm(
            sum_decode_kv_tokens=5000, num_decode_requests=10, wall_time=0.5
        )
        planner.fpm_subscriber = _mock_fpm_subscriber({("w1", 0): snapshot})

        assert planner.load_plan_adjustment() is None
components/src/dynamo/planner/tests/unit/test_replica_calculation.py
deleted
100644 → 0
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Unit tests for SLA planner replica calculation logic.
These tests focus specifically on the replica calculation formulas without
testing load prediction or regression internals.
"""
import
asyncio
import
math
import
os
from
unittest.mock
import
Mock
,
patch
import
pytest
from
dynamo.planner.config.planner_config
import
PlannerConfig
from
dynamo.planner.core.budget
import
_apply_global_gpu_budget
from
dynamo.planner.core.decode
import
DecodePlanner
from
dynamo.planner.core.prefill
import
PrefillPlanner
from
dynamo.planner.core.state
import
PlannerSharedState
from
dynamo.planner.monitoring.traffic_metrics
import
Metrics
from
dynamo.planner.monitoring.worker_info
import
WorkerInfo
# Module-level pytest markers; pytest attaches each of them to every test
# collected from this file.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.gpu_0,
    pytest.mark.unit,
    pytest.mark.planner,
]
class PlannerHarness:
    """Drive a prefill planner and a decode planner as if they were one object.

    Attribute reads and writes on the harness are routed to the wrapped
    planners (or to the shared state) so that a test can configure both sides
    through a single handle.
    """

    # Writes to these names are stored on the harness AND copied to both planners.
    _BROADCAST_ON_WRITE = frozenset(
        {
            "num_req_predictor",
            "isl_predictor",
            "osl_predictor",
            "connector",
            "prometheus_traffic_client",
            "config",
            "get_workers_info",
        }
    )
    # Reads that fall through to __getattr__ and are answered by the prefill planner.
    _READ_FROM_PREFILL = frozenset(
        {
            "num_req_predictor",
            "isl_predictor",
            "osl_predictor",
            "connector",
            "prometheus_traffic_client",
            "config",
            "get_workers_info",
            "ttft_regression",
            "prefill_worker_info",
        }
    )
    # Reads that fall through to __getattr__ and are answered by the decode planner.
    _READ_FROM_DECODE = frozenset({"itl_regression", "decode_worker_info"})

    def __init__(self, prefill_planner, decode_planner, shared_state):
        """Wrap the two planners and the state object they share."""
        self.prefill_planner = prefill_planner
        self.decode_planner = decode_planner
        self.shared_state = shared_state
        # Replica targets produced by the most recent make_adjustments() call.
        self.last_target_replicas = []

    async def make_adjustments(self):
        """Run one planning pass and record the resulting replica targets.

        Returns early (leaving ``last_target_replicas`` untouched) when the
        last metrics are invalid or when either planner has no proposal.
        """
        prefill = self.prefill_planner
        decode = self.decode_planner
        state = self.shared_state

        if not state.last_metrics.is_valid():
            return

        # The third element of the tuple is not used by the harness.
        num_p, num_d, _is_stable = await prefill.get_workers_info()
        state.num_p_workers = num_p
        state.num_d_workers = num_d

        wanted_p = prefill.plan_adjustment()
        wanted_d = decode.plan_adjustment()
        if wanted_p is None or wanted_d is None:
            return

        wanted_p, wanted_d = _apply_global_gpu_budget(wanted_p, wanted_d, prefill.config)
        prefill.update_predicted_replicas_metric(wanted_p)
        decode.update_predicted_replicas_metric(wanted_d)

        # Both component names are read from the prefill planner's worker info.
        prefill_target = {
            "sub_component_type": "prefill",
            "component_name": prefill.prefill_worker_info.k8s_name,
            "desired_replicas": wanted_p,
        }
        decode_target = {
            "sub_component_type": "decode",
            "component_name": prefill.decode_worker_info.k8s_name,
            "desired_replicas": wanted_d,
        }
        targets = [prefill_target, decode_target]
        self.last_target_replicas = targets

        if prefill.config.no_operation:
            return
        await prefill.connector.set_component_replicas(targets, blocking=False)

    def __getattr__(self, name):
        """Resolve attributes missing on the harness via shared state or a planner."""
        if name == "last_metrics":
            return self.shared_state.last_metrics
        if name in self._READ_FROM_PREFILL:
            return getattr(self.prefill_planner, name)
        if name in self._READ_FROM_DECODE:
            return getattr(self.decode_planner, name)
        raise AttributeError(name)

    def __setattr__(self, name, value):
        """Route an assignment to shared state, one planner, both planners, or the harness."""
        if name == "last_metrics":
            self.shared_state.last_metrics = value
        elif name in self._BROADCAST_ON_WRITE:
            # Store locally to support patch.object lifecycle (set/del).
            object.__setattr__(self, name, value)
            for target in (self.prefill_planner, self.decode_planner):
                setattr(target, name, value)
        elif name == "ttft_regression":
            setattr(self.prefill_planner, name, value)
        elif name == "itl_regression":
            setattr(self.decode_planner, name, value)
        else:
            # The wrapped objects themselves and any other name live on the harness.
            super().__setattr__(name, value)
def
_replica_count
(
target_replicas
,
component_name
,
default
=
1
):
for
replica
in
target_replicas
:
if
replica
.
get
(
"component_name"
)
==
component_name
:
return
replica
.
get
(
"desired_replicas"
,
default
)
return
default
@pytest.fixture
def planner():
    """Yield a PlannerHarness whose planners, predictors and clients are mocks."""
    profile_dir = os.path.join(
        os.path.dirname(__file__),
        "..",
        "data",
        "profiling_results",
        "H200_TP1P_TP1D",
    )
    settings = {
        "throughput_adjustment_interval": 60,
        "prefill_engine_num_gpu": 1,
        "decode_engine_num_gpu": 1,
        "min_endpoint": 1,
        "max_gpu_budget": 10,
        "ttft": 80.0,
        "itl": 10.0,
        "backend": "vllm",
        "no_operation": True,
        "metric_pulling_prometheus_endpoint": "http://localhost:9090",
        "metric_reporting_prometheus_port": 0,
        "load_predictor": "constant",
        "profile_results_dir": profile_dir,
        "environment": "kubernetes",
        "namespace": "test-namespace",
        "enable_throughput_scaling": True,
        "enable_load_scaling": False,
        "load_predictor_warmup_trace": None,
        "load_predictor_log1p": False,
        "max_num_fpm_samples": 50,
        "fpm_sample_bucket_size": 16,
        "load_min_observations": 5,
    }
    config = PlannerConfig.model_construct(**settings)

    mock_runtime = Mock()

    with patch("dynamo.planner.monitoring.planner_metrics.Gauge") as mock_gauge:
        mock_gauge.return_value = Mock()

        shared_state = PlannerSharedState()
        prefill_planner = PrefillPlanner(mock_runtime, config, shared_state=shared_state)
        decode_planner = DecodePlanner(mock_runtime, config, shared_state=shared_state)
        harness = PlannerHarness(prefill_planner, decode_planner, shared_state)

        # Set up WorkerInfo for both planners: one pair of objects, shared by both.
        prefill_info = WorkerInfo(
            k8s_name="VllmPrefillWorker",
            component_name="prefill",
            endpoint="generate",
        )
        decode_info = WorkerInfo(
            k8s_name="VllmDecodeWorker",
            component_name="backend",
            endpoint="generate",
        )
        prefill_planner.prefill_worker_info = prefill_info
        prefill_planner.decode_worker_info = decode_info
        decode_planner.prefill_worker_info = prefill_info
        decode_planner.decode_worker_info = decode_info

        # Default: 40000 tokens/s at isl=3000 → 40000/3000 rps
        ttft_regression = Mock()
        ttft_regression.find_best_engine_prefill_rps.return_value = (
            40000.0 / 3000.0,
            75.0,
        )
        ttft_regression.has_sufficient_data.return_value = True
        harness.ttft_regression = ttft_regression

        # Default: 10000 tokens/s at osl=150 → 10000/150 rps
        itl_regression = Mock()
        itl_regression.find_best_engine_decode_rps.return_value = (
            10000.0 / 150.0,
            9.5,
        )
        itl_regression.has_sufficient_data.return_value = True
        harness.itl_regression = itl_regression

        # Mock the predictors to return fixed values
        harness.num_req_predictor = Mock()
        harness.isl_predictor = Mock()
        harness.osl_predictor = Mock()

        # Mock the connector since we're not testing actual scaling
        harness.connector = Mock()

        # Mock prometheus client
        harness.prometheus_traffic_client = Mock()

        harness.config = config

        yield harness
class TestReplicaCalculation:
    """Test replica calculation formulas in isolation."""

    @staticmethod
    def _forecast(planner, num_req, isl, osl):
        """Pin what each mocked load predictor returns from ``predict_next``."""
        pinned = (
            (planner.num_req_predictor, num_req),
            (planner.isl_predictor, isl),
            (planner.osl_predictor, osl),
        )
        for predictor, value in pinned:
            predictor.predict_next.return_value = value

    @staticmethod
    def _engine_capacity(planner, prefill, decode):
        """Pin the tuples returned by the mocked TTFT / ITL regressions.

        The first element of each tuple is the per-engine request rate that
        the tests plug into their expected-replica formulas.
        """
        planner.ttft_regression.find_best_engine_prefill_rps.return_value = prefill
        planner.itl_regression.find_best_engine_decode_rps.return_value = decode

    @staticmethod
    def _observe(planner, num_req, isl, osl):
        """Install last-interval metrics with fixed ttft/itl/request_duration values."""
        planner.last_metrics = Metrics(
            num_req=num_req,
            isl=isl,
            osl=osl,
            ttft=80.0,
            itl=10.0,
            request_duration=100.0,
        )

    @staticmethod
    def _one_worker_each(planner):
        """Make ``get_workers_info`` report ``(1, 1, True)``."""

        async def mock_get_workers_info(*args, **kwargs):
            return (1, 1, True)

        planner.get_workers_info = mock_get_workers_info

    @staticmethod
    def _plan(planner):
        """Run one planning pass; return ``(prefill, decode)`` desired replicas."""
        asyncio.run(planner.make_adjustments())
        targets = planner.last_target_replicas
        return (
            _replica_count(targets, "VllmPrefillWorker"),
            _replica_count(targets, "VllmDecodeWorker"),
        )

    @pytest.mark.nightly
    @pytest.mark.gpu_2
    @pytest.mark.performance
    def test_prefill_replica_calculation_basic(self, planner):
        """Test basic prefill replica calculation."""
        next_num_req, next_isl = 10, 3000
        engine_rps = 40000.0 / next_isl

        self._forecast(planner, next_num_req, next_isl, 150)
        self._engine_capacity(planner, (engine_rps, 75.0), (10000.0 / 150.0, 9.5))

        # Formula: ceil(num_req / interval / engine_rps)
        pred_prefill_demand = (
            next_num_req / planner.config.throughput_adjustment_interval
        )
        expected_prefill_replicas = math.ceil(pred_prefill_demand / engine_rps)

        self._observe(planner, 10, 3000, 150)
        self._one_worker_each(planner)
        calculated_prefill_replicas, _ = self._plan(planner)

        print(f"Expected prefill replicas: {expected_prefill_replicas}")
        print(f"Calculated prefill replicas: {calculated_prefill_replicas}")

        assert (
            max(expected_prefill_replicas, planner.config.min_endpoint)
            == calculated_prefill_replicas
        )

    @pytest.mark.nightly
    @pytest.mark.gpu_2
    @pytest.mark.performance
    def test_decode_replica_calculation_basic(self, planner):
        """Test basic decode replica calculation."""
        next_num_req, next_osl = 10, 150
        engine_rps = 10000.0 / next_osl

        self._forecast(planner, next_num_req, 3000, next_osl)
        self._engine_capacity(planner, (40000.0 / 3000.0, 75.0), (engine_rps, 9.5))

        # Formula: ceil(num_req / interval / engine_rps)
        expected_decode_replicas = math.ceil(
            next_num_req / planner.config.throughput_adjustment_interval / engine_rps
        )

        self._observe(planner, 10, 3000, 150)
        self._one_worker_each(planner)
        _, calculated_decode_replicas = self._plan(planner)

        print(f"Expected decode replicas: {expected_decode_replicas}")
        print(f"Calculated decode replicas: {calculated_decode_replicas}")

        assert (
            max(expected_decode_replicas, planner.config.min_endpoint)
            == calculated_decode_replicas
        )

    @pytest.mark.parametrize(
        "num_req,decode_rps,expected_p,expected_d",
        [
            (10, 10000.0 / 150.0, 1, 1),  # low_load_10_req_per_second
            (
                500,
                1000.0 / 150.0,
                1,
                2,
            ),  # high_load_500_req_per_second (lower decode rps)
        ],
    )
    @pytest.mark.nightly
    @pytest.mark.gpu_2
    @pytest.mark.performance
    def test_scaling_scenario_low_to_high_load(
        self, planner, num_req, decode_rps, expected_p, expected_d
    ):
        """Test scaling from low to high load scenarios."""
        self._forecast(planner, num_req, 3000, 150)
        self._engine_capacity(planner, (40000.0 / 3000.0, 75.0), (decode_rps, 9.5))
        self._observe(planner, num_req, 3000, 150)
        self._one_worker_each(planner)

        planner.connector.reset_mock()
        prefill_replicas, decode_replicas = self._plan(planner)

        print(f"Load {num_req} req/s: P={prefill_replicas}, D={decode_replicas}")

        assert (
            prefill_replicas == expected_p
        ), f"Prefill replicas mismatch: expected {expected_p}, got {prefill_replicas}"
        assert (
            decode_replicas == expected_d
        ), f"Decode replicas mismatch: expected {expected_d}, got {decode_replicas}"

    @pytest.mark.nightly
    @pytest.mark.gpu_2
    @pytest.mark.performance
    def test_gpu_budget_constraint(self, planner):
        """Test that GPU budget constraints are properly applied."""
        planner.config.max_gpu_budget = 3

        self._forecast(planner, 50, 3000, 150)
        self._engine_capacity(
            planner, (40000.0 / 3000.0, 75.0), (10000.0 / 150.0, 9.5)
        )
        self._observe(planner, 50, 3000, 150)
        self._one_worker_each(planner)

        prefill_replicas, decode_replicas = self._plan(planner)

        cfg = planner.config
        total_gpus = (
            prefill_replicas * cfg.prefill_engine_num_gpu
            + decode_replicas * cfg.decode_engine_num_gpu
        )

        print(
            f"GPU budget test: P={prefill_replicas}, D={decode_replicas}, Total GPUs={total_gpus}"
        )

        assert (
            total_gpus <= planner.config.max_gpu_budget
        ), "Total GPU usage exceeds budget"

    @pytest.mark.nightly
    @pytest.mark.gpu_2
    @pytest.mark.performance
    def test_min_endpoint_constraint(self, planner):
        """Test that minimum endpoint constraints are respected."""
        planner.config.min_endpoint = 2

        self._forecast(planner, 1, 100, 10)
        self._engine_capacity(planner, (40000.0 / 100.0, 75.0), (10000.0 / 10.0, 9.5))
        self._observe(planner, 1, 100, 10)
        self._one_worker_each(planner)

        prefill_replicas, decode_replicas = self._plan(planner)

        print(f"Min endpoint test: P={prefill_replicas}, D={decode_replicas}")

        floor = planner.config.min_endpoint
        assert prefill_replicas >= floor, "Prefill replicas below minimum"
        assert decode_replicas >= floor, "Decode replicas below minimum"

    @pytest.mark.nightly
    @pytest.mark.gpu_2
    @pytest.mark.performance
    def test_multi_gpu_engines(self, planner):
        """Test replica calculation with multi-GPU engines."""
        planner.config.prefill_engine_num_gpu = 2
        planner.config.decode_engine_num_gpu = 4

        self._forecast(planner, 20, 3000, 150)

        # Engine-level request rate (already accounts for multi-GPU)
        prefill_engine_rps = 40000.0 / 3000.0
        decode_engine_rps = 5000.0 / 150.0
        self._engine_capacity(
            planner, (prefill_engine_rps, 75.0), (decode_engine_rps, 9.5)
        )

        self._observe(planner, 20, 3000, 150)
        self._one_worker_each(planner)

        # No engine_num_gpu division — regression returns engine-level rps
        interval = planner.config.throughput_adjustment_interval
        expected_prefill_replicas = math.ceil(20 / interval / prefill_engine_rps)
        expected_decode_replicas = math.ceil(20 / interval / decode_engine_rps)

        prefill_replicas, decode_replicas = self._plan(planner)

        print(
            f"Multi-GPU test: P={prefill_replicas} (expected ~{expected_prefill_replicas}), "
            f"D={decode_replicas} (expected ~{expected_decode_replicas})"
        )

        assert prefill_replicas == max(
            expected_prefill_replicas, planner.config.min_endpoint
        )
        assert decode_replicas == max(
            expected_decode_replicas, planner.config.min_endpoint
        )

    @pytest.mark.weekly
    @pytest.mark.gpu_2
    @pytest.mark.performance
    def test_complex_gpu_budget_scaling(self, planner):
        """Test complex GPU budget scaling with proportional reduction."""
        planner.config.max_gpu_budget = 5
        planner.config.prefill_engine_num_gpu = 2
        planner.config.decode_engine_num_gpu = 2
        planner.config.min_endpoint = 1

        self._forecast(planner, 100, 3000, 150)
        self._engine_capacity(
            planner, (10000.0 / 3000.0, 300.0), (1000.0 / 150.0, 9.5)
        )
        self._observe(planner, 100, 3000, 150)
        self._one_worker_each(planner)

        prefill_replicas, decode_replicas = self._plan(planner)

        cfg = planner.config
        total_gpus = (
            prefill_replicas * cfg.prefill_engine_num_gpu
            + decode_replicas * cfg.decode_engine_num_gpu
        )

        print(
            f"Complex GPU budget test: P={prefill_replicas}, D={decode_replicas}, "
            f"Total GPUs={total_gpus}"
        )

        assert (
            total_gpus <= planner.config.max_gpu_budget
        ), "Total GPU usage should not exceed budget"
        assert (
            prefill_replicas >= planner.config.min_endpoint
        ), "Should respect min_endpoint for prefill"
        assert (
            decode_replicas >= planner.config.min_endpoint
        ), "Should respect min_endpoint for decode"
components/src/dynamo/planner/tests/unit/test_sla_planner_scaling.py
deleted
100644 → 0
View file @
39a6a240
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
math
import
os
from
unittest.mock
import
MagicMock
,
Mock
,
patch
import
pytest
from
dynamo.planner.config.planner_config
import
PlannerConfig
from
dynamo.planner.core.budget
import
_initialize_gpu_counts
from
dynamo.planner.core.decode
import
DecodePlanner
from
dynamo.planner.core.prefill
import
PrefillPlanner
from
dynamo.planner.core.state
import
PlannerSharedState
from
dynamo.planner.errors
import
DeploymentValidationError
# Module-level pytest markers; pytest attaches each of them to every test
# collected from this file.
pytestmark = [
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.planner,
]

# Canned results for the mocked regressions built in _build_planners():
# the prefill regression returns (PREFILL_ENGINE_RPS, 75.0) and the decode
# regression returns (DECODE_ENGINE_RPS, DECODE_ACTUAL_ITL_MS).
PREFILL_ENGINE_RPS = 10.0
DECODE_ENGINE_RPS = 5.0
DECODE_ACTUAL_ITL_MS = 40.0
@pytest.fixture(autouse=True)
def mock_prometheus_metrics():
    """Replace ``planner_metrics.Gauge`` with a mock factory for every test in the module."""
    gauge_patcher = patch("dynamo.planner.monitoring.planner_metrics.Gauge")
    # Every Gauge(...) call made while the patch is active yields the same Mock.
    gauge_patcher.start().return_value = Mock()
    try:
        yield
    finally:
        gauge_patcher.stop()
def _build_config():
    """Return the PlannerConfig used by the disaggregated scaling tests.

    Throughput scaling is enabled, load scaling is disabled, and
    ``max_gpu_budget`` is set to -1.
    """
    profile_dir = os.path.join(
        os.path.dirname(__file__),
        "..",
        "data",
        "profiling_results",
        "H200_TP1P_TP1D",
    )
    settings = {
        "throughput_adjustment_interval": 60,
        "prefill_engine_num_gpu": 1,
        "decode_engine_num_gpu": 1,
        "min_endpoint": 1,
        "max_gpu_budget": -1,
        "ttft": 500.0,
        "itl": 50.0,
        "backend": "vllm",
        "no_operation": True,
        "metric_pulling_prometheus_endpoint": "http://localhost:9090",
        "metric_reporting_prometheus_port": 0,
        "load_predictor": "constant",
        "load_predictor_warmup_trace": None,
        "load_predictor_log1p": False,
        "profile_results_dir": profile_dir,
        "environment": "kubernetes",
        "namespace": "test-namespace",
        "mode": "disagg",
        "enable_throughput_scaling": True,
        "enable_load_scaling": False,
    }
    return PlannerConfig.model_construct(**settings)
def
_build_prometheus_client
(
samples
):
client
=
Mock
()
client
.
get_avg_time_to_first_token
.
side_effect
=
[
s
[
"ttft_ms"
]
/
1000
for
s
in
samples
]
client
.
get_avg_inter_token_latency
.
side_effect
=
[
s
[
"itl_ms"
]
/
1000
for
s
in
samples
]
client
.
get_avg_request_count
.
side_effect
=
[
s
[
"num_req"
]
for
s
in
samples
]
client
.
get_avg_request_duration
.
side_effect
=
[
s
[
"request_duration"
]
for
s
in
samples
]
client
.
get_avg_input_sequence_tokens
.
side_effect
=
[
s
[
"isl"
]
for
s
in
samples
]
client
.
get_avg_output_sequence_tokens
.
side_effect
=
[
s
[
"osl"
]
for
s
in
samples
]
return
client
def _build_planners(config, prometheus_client):
    """Create a prefill/decode planner pair wired to mocks.

    Both planners share one ``PlannerSharedState`` and the given Prometheus
    client; their regressions are MagicMocks returning the module constants,
    and ``get_workers_info`` is replaced by a stub.

    Returns ``(prefill_planner, decode_planner, shared_state)``.
    """
    shared_state = PlannerSharedState()
    prefill_planner = PrefillPlanner(None, config, shared_state=shared_state)
    decode_planner = DecodePlanner(None, config, shared_state=shared_state)

    ttft_regression = MagicMock()
    ttft_regression.find_best_engine_prefill_rps.return_value = (
        PREFILL_ENGINE_RPS,
        75.0,
    )
    ttft_regression.has_sufficient_data.return_value = True

    itl_regression = MagicMock()
    itl_regression.find_best_engine_decode_rps.return_value = (
        DECODE_ENGINE_RPS,
        DECODE_ACTUAL_ITL_MS,
    )
    itl_regression.has_sufficient_data.return_value = True

    async def mock_get_workers_info(require_prefill=True, require_decode=True):
        # One worker for each required role, zero otherwise.
        num_prefill = int(bool(require_prefill))
        num_decode = int(bool(require_decode))
        is_stable = True
        return (num_prefill, num_decode, is_stable)

    for each in (prefill_planner, decode_planner):
        each.prometheus_traffic_client = prometheus_client
        each.model_name = "test-model"
        each.get_workers_info = mock_get_workers_info

    prefill_planner.ttft_regression = ttft_regression
    decode_planner.itl_regression = itl_regression

    return prefill_planner, decode_planner, shared_state
def
_expected_prefill
(
config
,
prefill_planner
,
sample
):
demand_rps
=
sample
[
"num_req"
]
/
config
.
throughput_adjustment_interval
engine_rps
,
_
=
prefill_planner
.
ttft_regression
.
find_best_engine_prefill_rps
(
ttft_sla
=
config
.
ttft
,
isl
=
sample
[
"isl"
]
)
expected
=
math
.
ceil
(
demand_rps
/
engine_rps
)
return
max
(
expected
,
config
.
min_endpoint
)
def
_expected_decode
(
config
,
decode_planner
,
sample
):
demand_rps
=
sample
[
"num_req"
]
/
config
.
throughput_adjustment_interval
engine_rps
,
_
=
decode_planner
.
itl_regression
.
find_best_engine_decode_rps
(
itl
=
config
.
itl
,
context_length
=
sample
[
"isl"
]
+
sample
[
"osl"
]
/
2
)
expected
=
math
.
ceil
(
demand_rps
/
engine_rps
)
return
max
(
expected
,
config
.
min_endpoint
)
def
_run_interval
(
prefill_planner
,
decode_planner
,
shared_state
):
asyncio
.
run
(
prefill_planner
.
observe_traffic_stats
(
require_prefill
=
True
,
require_decode
=
True
)
)
decode_planner
.
update_predictors_from_metrics
(
shared_state
.
last_metrics
)
next_num_p
=
prefill_planner
.
plan_adjustment
()
next_num_d
=
decode_planner
.
plan_adjustment
()
return
next_num_p
,
next_num_d
def test_disagg_scale_up():
    """Replica targets grow when the request count jumps from 10 to 5000."""
    config = _build_config()

    # Only num_req differs between the two intervals.
    steady = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    samples = [{"num_req": num_req, **steady} for num_req in (10, 5000)]
    low_sample, high_sample = samples

    planners = _build_planners(config, _build_prometheus_client(samples))
    prefill_planner, decode_planner, _shared_state = planners

    low_p, low_d = _run_interval(*planners)
    high_p, high_d = _run_interval(*planners)

    assert low_p == _expected_prefill(config, prefill_planner, low_sample)
    assert low_d == _expected_decode(config, decode_planner, low_sample)
    assert high_p == _expected_prefill(config, prefill_planner, high_sample)
    assert high_d == _expected_decode(config, decode_planner, high_sample)
    assert high_p > low_p
    assert high_d > low_d
def test_disagg_scale_down():
    """Replica targets shrink when the request count drops from 5000 to 10."""
    config = _build_config()

    # Only num_req differs between the two intervals.
    steady = {
        "isl": 3000,
        "osl": 150,
        "ttft_ms": 400.0,
        "itl_ms": 30.0,
        "request_duration": 20.0,
    }
    samples = [{"num_req": num_req, **steady} for num_req in (5000, 10)]
    high_sample, low_sample = samples

    planners = _build_planners(config, _build_prometheus_client(samples))
    prefill_planner, decode_planner, _shared_state = planners

    high_p, high_d = _run_interval(*planners)
    low_p, low_d = _run_interval(*planners)

    assert high_p == _expected_prefill(config, prefill_planner, high_sample)
    assert high_d == _expected_decode(config, decode_planner, high_sample)
    assert low_p == _expected_prefill(config, prefill_planner, low_sample)
    assert low_d == _expected_decode(config, decode_planner, low_sample)
    assert low_p < high_p
    assert low_d < high_d
class TestInitializeGpuCounts:
    """Tests for ``_initialize_gpu_counts`` in Kubernetes and virtual modes."""

    @staticmethod
    def _make_config(**overrides):
        """Return an unvalidated PlannerConfig; both GPU counts default to None."""
        fields = {
            "prefill_engine_num_gpu": None,
            "decode_engine_num_gpu": None,
            **overrides,
        }
        return PlannerConfig.model_construct(**fields)

    def test_kubernetes_mode_reads_from_dgd(self):
        """GPU counts reported by the DGD end up on the config (Kubernetes mode)."""
        cfg = self._make_config()
        dgd_lookup = Mock(return_value=(2, 4))
        k8s_connector = Mock()
        k8s_connector.get_gpu_counts = dgd_lookup
        flags = {"require_prefill": True, "require_decode": True}

        _initialize_gpu_counts(cfg, k8s_connector, **flags)

        assert cfg.prefill_engine_num_gpu == 2
        assert cfg.decode_engine_num_gpu == 4
        k8s_connector.get_gpu_counts.assert_called_once_with(**flags)

    def test_kubernetes_mode_prefill_only(self):
        """Prefill-only initialization passes require_decode=False through."""
        cfg = self._make_config()
        dgd_lookup = Mock(return_value=(2, 0))
        k8s_connector = Mock()
        k8s_connector.get_gpu_counts = dgd_lookup
        flags = {"require_prefill": True, "require_decode": False}

        _initialize_gpu_counts(cfg, k8s_connector, **flags)

        assert cfg.prefill_engine_num_gpu == 2
        assert cfg.decode_engine_num_gpu == 0
        k8s_connector.get_gpu_counts.assert_called_once_with(**flags)

    def test_virtual_mode_uses_cli_args(self):
        """In virtual mode the GPU counts already on the config are kept."""
        cfg = self._make_config(prefill_engine_num_gpu=2, decode_engine_num_gpu=4)
        # spec=[] yields a connector with no attributes at all.
        bare_connector = Mock(spec=[])

        _initialize_gpu_counts(
            cfg, bare_connector, require_prefill=True, require_decode=True
        )

        assert cfg.prefill_engine_num_gpu == 2
        assert cfg.decode_engine_num_gpu == 4

    def test_virtual_mode_missing_prefill_raises_error(self):
        """A missing prefill GPU count is reported in virtual mode."""
        cfg = self._make_config(decode_engine_num_gpu=4)
        bare_connector = Mock(spec=[])

        with pytest.raises(DeploymentValidationError) as raised:
            _initialize_gpu_counts(
                cfg, bare_connector, require_prefill=True, require_decode=True
            )

        message = str(raised.value)
        assert "prefill_engine_num_gpu" in message

    def test_virtual_mode_missing_decode_raises_error(self):
        """A missing decode GPU count is reported in virtual mode."""
        cfg = self._make_config(prefill_engine_num_gpu=2)
        bare_connector = Mock(spec=[])

        with pytest.raises(DeploymentValidationError) as raised:
            _initialize_gpu_counts(
                cfg, bare_connector, require_prefill=True, require_decode=True
            )

        message = str(raised.value)
        assert "decode_engine_num_gpu" in message

    def test_virtual_mode_missing_both_raises_error_with_both_messages(self):
        """With both GPU counts missing, the raised error carries two entries."""
        cfg = self._make_config()
        bare_connector = Mock(spec=[])

        with pytest.raises(DeploymentValidationError) as raised:
            _initialize_gpu_counts(
                cfg, bare_connector, require_prefill=True, require_decode=True
            )

        collected = raised.value.errors
        assert len(collected) == 2

    def test_virtual_mode_decode_only_no_prefill_error(self):
        """Decode-only initialization does not demand a prefill GPU count."""
        cfg = self._make_config(decode_engine_num_gpu=4)
        bare_connector = Mock(spec=[])

        _initialize_gpu_counts(
            cfg, bare_connector, require_prefill=False, require_decode=True
        )

        assert cfg.decode_engine_num_gpu == 4

    def test_kubernetes_mode_fallback_to_cli_on_dgd_error(self):
        """When the DGD lookup raises ValueError, the config values are kept."""
        cfg = self._make_config(prefill_engine_num_gpu=2, decode_engine_num_gpu=4)
        failing_lookup = Mock(side_effect=ValueError("No GPU count specified"))
        k8s_connector = Mock()
        k8s_connector.get_gpu_counts = failing_lookup

        _initialize_gpu_counts(
            cfg, k8s_connector, require_prefill=True, require_decode=True
        )

        assert cfg.prefill_engine_num_gpu == 2
        assert cfg.decode_engine_num_gpu == 4

    def test_kubernetes_mode_fallback_missing_cli_flags_raises_error(self):
        """A failed DGD lookup plus an empty config yields two errors."""
        cfg = self._make_config()
        failing_lookup = Mock(side_effect=ValueError("No GPU count specified"))
        k8s_connector = Mock()
        k8s_connector.get_gpu_counts = failing_lookup

        with pytest.raises(DeploymentValidationError) as raised:
            _initialize_gpu_counts(
                cfg, k8s_connector, require_prefill=True, require_decode=True
            )

        collected = raised.value.errors
        assert len(collected) == 2

    def test_kubernetes_mode_fallback_partial_cli_flags(self):
        """A failed DGD lookup with only the prefill count set reports decode."""
        cfg = self._make_config(prefill_engine_num_gpu=2)
        failing_lookup = Mock(side_effect=ValueError("No GPU count specified"))
        k8s_connector = Mock()
        k8s_connector.get_gpu_counts = failing_lookup

        with pytest.raises(DeploymentValidationError) as raised:
            _initialize_gpu_counts(
                cfg, k8s_connector, require_prefill=True, require_decode=True
            )

        message = str(raised.value)
        assert "decode_engine_num_gpu" in message
components/src/dynamo/planner/tests/unit/test_state_machine.py
0 → 100644
View file @
f7e0b3fd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Core-only planner tests: TickInput -> PlannerEffects, no mocks."""
import
pytest
try
:
import
msgspec
# noqa: F401
except
ImportError
:
pytest
.
skip
(
"msgspec required for FPM tests"
,
allow_module_level
=
True
)
from
dynamo.common.forward_pass_metrics
import
(
ForwardPassMetrics
,
QueuedRequestMetrics
,
ScheduledRequestMetrics
,
)
from
dynamo.planner.config.planner_config
import
PlannerConfig
from
dynamo.planner.core.state_machine
import
PlannerStateMachine
from
dynamo.planner.core.types
import
(
EngineCapabilities
,
FpmObservations
,
ScheduledTick
,
TickInput
,
TrafficObservation
,
WorkerCapabilities
,
WorkerCounts
,
)
def _tick_for(tick_input: TickInput) -> ScheduledTick:
    """Derive the ScheduledTick that corresponds to a given TickInput.

    Load scaling / FPM collection is requested only when the input carries
    FPM observations; throughput scaling / traffic collection only when it
    carries a traffic observation. Worker states are always requested.
    """
    fpm_present = tick_input.fpm_observations is not None
    traffic_present = tick_input.traffic is not None

    # Without traffic there is no window to report, so fall back to 0.0.
    if traffic_present:
        window_s = tick_input.traffic.duration_s
    else:
        window_s = 0.0

    return ScheduledTick(
        at_s=tick_input.now_s,
        run_load_scaling=fpm_present,
        run_throughput_scaling=traffic_present,
        need_worker_states=True,
        need_worker_fpm=fpm_present,
        need_traffic_metrics=traffic_present,
        traffic_metrics_duration_s=window_s,
    )
# Module-level pytest marks: pytest applies every mark in ``pytestmark`` to
# all tests collected from this file.
pytestmark = [
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.planner,
]
def _make_fpm(
    *,
    sum_prefill_tokens: int = 0,
    num_prefill_requests: int = 0,
    sum_decode_kv_tokens: int = 0,
    num_decode_requests: int = 0,
    queued_prefill_tokens: int = 0,
    queued_decode_kv_tokens: int = 0,
    wall_time: float = 0.01,
    worker_id: str = "w1",
    dp_rank: int = 0,
) -> ForwardPassMetrics:
    """Assemble a ForwardPassMetrics sample from flat keyword arguments.

    The ``sum_*`` / ``num_*`` arguments populate the scheduled-request
    metrics, while the ``queued_*`` arguments populate the queued-request
    metrics. Every argument is keyword-only.
    """
    scheduled = ScheduledRequestMetrics(
        sum_prefill_tokens=sum_prefill_tokens,
        num_prefill_requests=num_prefill_requests,
        sum_decode_kv_tokens=sum_decode_kv_tokens,
        num_decode_requests=num_decode_requests,
    )
    queued = QueuedRequestMetrics(
        sum_prefill_tokens=queued_prefill_tokens,
        sum_decode_kv_tokens=queued_decode_kv_tokens,
    )
    return ForwardPassMetrics(
        worker_id=worker_id,
        dp_rank=dp_rank,
        wall_time=wall_time,
        scheduled_requests=scheduled,
        queued_requests=queued,
    )
def _make_config(**overrides) -> PlannerConfig:
    """Return an unvalidated disagg PlannerConfig for tests.

    Keyword arguments replace the baseline values below; the result is built
    with ``model_construct``, so no field validation runs.
    """
    settings = {
        "mode": "disagg",
        "ttft": 500.0,
        "itl": 50.0,
        "min_endpoint": 1,
        "max_gpu_budget": -1,
        "throughput_adjustment_interval": 60,
        "load_adjustment_interval": 5,
        "load_scaling_down_sensitivity": 80,
        "max_num_fpm_samples": 50,
        "fpm_sample_bucket_size": 16,
        "load_min_observations": 5,
        "enable_load_scaling": True,
        "enable_throughput_scaling": True,
        "load_predictor": "constant",
        "no_operation": True,
        "backend": "vllm",
        "metric_pulling_prometheus_endpoint": "http://localhost:9090",
        "metric_reporting_prometheus_port": 0,
        **overrides,
    }
    return PlannerConfig.model_construct(**settings)
def _default_caps() -> WorkerCapabilities:
    """Capabilities with a 1-GPU, 2048-batched-token prefill and decode engine."""
    prefill_engine = EngineCapabilities(num_gpu=1, max_num_batched_tokens=2048)
    decode_engine = EngineCapabilities(num_gpu=1, max_num_batched_tokens=2048)
    return WorkerCapabilities(prefill=prefill_engine, decode=decode_engine)
def _agg_caps() -> WorkerCapabilities:
    """Capabilities for agg tests: only a decode engine entry is supplied."""
    engine = EngineCapabilities(num_gpu=1, max_num_batched_tokens=2048)
    return WorkerCapabilities(decode=engine)
def _agg_config(**overrides) -> PlannerConfig:
    """Return the test PlannerConfig with ``mode`` set to ``"agg"``.

    ``mode`` is passed explicitly, so supplying it again in ``overrides``
    raises ``TypeError`` for the duplicate keyword.
    """
    agg_cfg = _make_config(mode="agg", **overrides)
    return agg_cfg
def _make_core(config=None, caps=None, **config_overrides) -> PlannerStateMachine:
    """Build a PlannerStateMachine, filling in test defaults where needed.

    A falsy ``config`` is replaced by ``_make_config(**config_overrides)``;
    a falsy ``caps`` is replaced by ``_default_caps()``. When ``config`` is
    supplied, ``config_overrides`` are not applied.
    """
    cfg = config if config else _make_config(**config_overrides)
    capabilities = caps if caps else _default_caps()
    return PlannerStateMachine(cfg, capabilities)
def _make_agg_core(config=None, caps=None, **config_overrides) -> PlannerStateMachine:
    """Build an agg-mode PlannerStateMachine, filling in test defaults.

    A falsy ``config`` is replaced by ``_agg_config(**config_overrides)``;
    a falsy ``caps`` is replaced by ``_agg_caps()``. When ``config`` is
    supplied, ``config_overrides`` are not applied.
    """
    cfg = config if config else _agg_config(**config_overrides)
    capabilities = caps if caps else _agg_caps()
    return PlannerStateMachine(cfg, capabilities)
def _train_prefill_regression(core: PlannerStateMachine) -> None:
    """Load five synthetic prefill benchmark samples into ``core``.

    Each sample has one prefill request and a wall time that is linear in
    the prefill token count (0.001 * tokens + 0.002).
    """
    benchmark = []
    for tokens in (500, 1000, 1500, 2000, 2500):
        sample = _make_fpm(
            sum_prefill_tokens=tokens,
            num_prefill_requests=1,
            wall_time=0.001 * tokens + 0.002,
        )
        benchmark.append(sample)
    core.load_benchmark_fpms(prefill_fpms=benchmark)
def _train_decode_regression(core: PlannerStateMachine) -> None:
    """Load five synthetic decode benchmark samples into ``core``.

    Each sample pairs a decode request count with a KV token total, with a
    wall time that is linear in the KV tokens (0.00001 * kv + 0.001).
    """
    # (num_decode_requests, sum_decode_kv_tokens) pairs
    operating_points = ((5, 5000), (10, 10000), (20, 20000), (30, 30000), (40, 40000))
    benchmark = []
    for num_requests, kv_tokens in operating_points:
        sample = _make_fpm(
            sum_decode_kv_tokens=kv_tokens,
            num_decode_requests=num_requests,
            wall_time=0.00001 * kv_tokens + 0.001,
        )
        benchmark.append(sample)
    core.load_benchmark_fpms(decode_fpms=benchmark)
# ── Initial ticks ─────────────────────────────────────────────────────
class TestInitialTick:
    """Checks on the first tick returned by ``initial_tick``."""

    def test_both_enabled_returns_earliest(self):
        """With both scaling modes enabled, the first tick is the load tick."""
        first = _make_core().initial_tick(start_s=100.0)
        # Load interval (5s) < throughput interval (60s), so load tick first
        assert first.at_s == 105.0
        assert first.need_worker_fpm
        assert not first.need_traffic_metrics

    def test_load_only(self):
        """With throughput scaling off, the first tick asks for FPM only."""
        machine = _make_core(enable_throughput_scaling=False)
        first = machine.initial_tick(start_s=0.0)
        assert first.at_s == 5.0
        assert first.need_worker_fpm
        assert not first.need_traffic_metrics

    def test_throughput_only(self):
        """With load scaling off, the first tick still requests FPM at 5s."""
        machine = _make_core(enable_load_scaling=False)
        first = machine.initial_tick(start_s=0.0)
        # Load tick is still scheduled (feeds regression) at 5s < 60s
        assert first.at_s == 5.0
        assert first.need_worker_fpm
# ── Load benchmark bootstrapping ──────────────────────────────────────
class TestBenchmarkBootstrap:
    """Loading benchmark FPMs should leave the regression with enough data."""

    def test_prefill_regression_bootstrapped(self):
        """Prefill benchmark samples make the prefill regression usable."""
        machine = _make_core(mode="prefill")
        _train_prefill_regression(machine)
        regression = machine.prefill_regression
        assert regression.has_sufficient_data()

    def test_decode_regression_bootstrapped(self):
        """Decode benchmark samples make the decode regression usable."""
        machine = _make_core(mode="decode")
        _train_decode_regression(machine)
        regression = machine.decode_regression
        assert regression.has_sufficient_data()
# ── FPM observation via on_tick ───────────────────────────────────────
class TestFpmObservation:
    """FPM data delivered through ``on_tick`` is recorded and reschedules."""

    def test_fpm_feeds_regression(self):
        """One prefill FPM in a tick adds one observation to the regression."""
        machine = _make_core(mode="prefill")
        assert machine.prefill_regression.num_observations == 0

        sample = _make_fpm(
            sum_prefill_tokens=500, num_prefill_requests=1, wall_time=0.5
        )
        observations = FpmObservations(prefill={("w1", 0): sample})
        workers = WorkerCounts(ready_num_prefill=1)
        tick_input = TickInput(
            now_s=5.0, fpm_observations=observations, worker_counts=workers
        )

        machine.on_tick(_tick_for(tick_input), tick_input)

        assert machine.prefill_regression.num_observations == 1

    def test_next_tick_scheduled_after_fpm(self):
        """An FPM tick at t=10s schedules the next FPM tick at t=15s."""
        machine = _make_core(mode="prefill")
        sample = _make_fpm(
            sum_prefill_tokens=500,
            num_prefill_requests=1,
            wall_time=0.5,
        )
        observations = FpmObservations(prefill={("w1", 0): sample})
        workers = WorkerCounts(ready_num_prefill=1)
        tick_input = TickInput(
            now_s=10.0, fpm_observations=observations, worker_counts=workers
        )

        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        assert outcome.next_tick is not None
        assert outcome.next_tick.at_s == 15.0
        assert outcome.next_tick.need_worker_fpm
# ── Load-based scaling (prefill) ──────────────────────────────────────
class TestPrefillLoadScaling:
    """Load-based scaling decisions for a prefill-mode state machine."""

    def test_scale_up_when_all_above_sla(self):
        """A trained core with a deep prefill queue asks for more than 1 prefill."""
        machine = _make_core(mode="prefill", ttft=5.0)
        _train_prefill_regression(machine)

        sample = _make_fpm(
            worker_id="w1",
            queued_prefill_tokens=10000,
            sum_prefill_tokens=500,
            num_prefill_requests=1,
            wall_time=0.5,
        )
        observations = FpmObservations(prefill={("w1", 0): sample})
        workers = WorkerCounts(ready_num_prefill=1)
        tick_input = TickInput(
            now_s=5.0, fpm_observations=observations, worker_counts=workers
        )

        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        assert outcome.scale_to is not None
        assert outcome.scale_to.num_prefill is not None
        assert outcome.scale_to.num_prefill > 1

    def test_no_scaling_when_insufficient_data(self):
        """Without benchmark training, a single FPM produces no scale decision."""
        machine = _make_core(mode="prefill")

        sample = _make_fpm(
            queued_prefill_tokens=5000, sum_prefill_tokens=100, wall_time=0.1
        )
        observations = FpmObservations(prefill={("w1", 0): sample})
        workers = WorkerCounts(ready_num_prefill=1)
        tick_input = TickInput(
            now_s=5.0, fpm_observations=observations, worker_counts=workers
        )

        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        assert outcome.scale_to is None

    def test_no_scaling_when_load_disabled(self):
        """With load scaling disabled, an FPM tick produces no scale decision."""
        machine = _make_core(mode="prefill", enable_load_scaling=False)
        _train_prefill_regression(machine)

        sample = _make_fpm(
            queued_prefill_tokens=10000,
            sum_prefill_tokens=500,
            num_prefill_requests=1,
            wall_time=0.5,
        )
        observations = FpmObservations(prefill={("w1", 0): sample})
        workers = WorkerCounts(ready_num_prefill=1)
        tick_input = TickInput(
            now_s=5.0, fpm_observations=observations, worker_counts=workers
        )

        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        assert outcome.scale_to is None
# ── Load-based scaling (decode) ───────────────────────────────────────
class TestDecodeLoadScaling:
    """Load-based scaling decisions for a decode-mode state machine."""

    def test_scale_up_when_all_above_sla(self):
        """A trained core with heavy decode load asks for more than 1 decode."""
        machine = _make_core(mode="decode", itl=5.0)
        _train_decode_regression(machine)

        sample = _make_fpm(
            worker_id="w1",
            sum_decode_kv_tokens=30000,
            queued_decode_kv_tokens=20000,
            num_decode_requests=30,
            wall_time=0.3,
        )
        observations = FpmObservations(decode={("w1", 0): sample})
        workers = WorkerCounts(ready_num_decode=1)
        tick_input = TickInput(
            now_s=5.0, fpm_observations=observations, worker_counts=workers
        )

        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        assert outcome.scale_to is not None
        assert outcome.scale_to.num_decode is not None
        assert outcome.scale_to.num_decode > 1
# ── Disagg load scaling ───────────────────────────────────────────────
class TestDisaggLoadScaling:
    """Load-based scaling with both prefill and decode FPMs in one tick."""

    def test_disagg_scale_up(self):
        """A trained disagg core under load returns some scale decision."""
        machine = _make_core(ttft=5.0, itl=5.0)
        _train_prefill_regression(machine)
        _train_decode_regression(machine)

        prefill_sample = _make_fpm(
            worker_id="w1",
            queued_prefill_tokens=10000,
            sum_prefill_tokens=500,
            num_prefill_requests=1,
            wall_time=0.5,
        )
        decode_sample = _make_fpm(
            worker_id="w1",
            sum_decode_kv_tokens=5000,
            queued_decode_kv_tokens=3000,
            num_decode_requests=20,
            wall_time=0.6,
        )
        observations = FpmObservations(
            prefill={("w1", 0): prefill_sample},
            decode={("w1", 0): decode_sample},
        )
        workers = WorkerCounts(ready_num_prefill=1, ready_num_decode=1)
        tick_input = TickInput(
            now_s=5.0, fpm_observations=observations, worker_counts=workers
        )

        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        assert outcome.scale_to is not None
# ── Throughput scaling ────────────────────────────────────────────────
class TestThroughputScaling:
    """Throughput-driven ticks: decisions, lower bounds and rescheduling."""

    def test_throughput_only_returns_decision(self):
        """With load scaling off, a traffic tick yields a prefill count >= 1."""
        machine = _make_core(
            mode="prefill", enable_load_scaling=False, enable_throughput_scaling=True
        )
        _train_prefill_regression(machine)

        window = {"duration_s": 60, "num_req": 100, "isl": 1000, "osl": 150}
        # Warm predictor with traffic
        machine._observe_traffic(TrafficObservation(**window))

        # A fresh observation object is built for the tick itself.
        tick_input = TickInput(
            now_s=60.0,
            traffic=TrafficObservation(**window),
            worker_counts=WorkerCounts(ready_num_prefill=1),
        )
        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        assert outcome.scale_to is not None
        assert outcome.scale_to.num_prefill is not None
        assert outcome.scale_to.num_prefill >= 1

    def test_throughput_sets_lower_bound_when_load_enabled(self):
        """With both modes on, a traffic tick only records lower bounds."""
        machine = _make_core(enable_load_scaling=True, enable_throughput_scaling=True)
        _train_prefill_regression(machine)
        _train_decode_regression(machine)

        window = {"duration_s": 60, "num_req": 100, "isl": 1000, "osl": 150}
        machine._observe_traffic(TrafficObservation(**window))

        tick_input = TickInput(
            now_s=60.0,
            traffic=TrafficObservation(**window),
            worker_counts=WorkerCounts(ready_num_prefill=1, ready_num_decode=1),
        )
        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        # When both modes enabled, throughput tick returns None (just sets lower bound)
        assert outcome.scale_to is None
        assert machine._throughput_lower_bound_p >= 1
        assert machine._throughput_lower_bound_d >= 1

    def test_next_tick_scheduled_after_traffic(self):
        """A traffic tick at t=60s schedules the next traffic tick at t=120s."""
        machine = _make_core(mode="prefill")
        idle_traffic = TrafficObservation(duration_s=60, num_req=0, isl=0, osl=0)
        tick_input = TickInput(now_s=60.0, traffic=idle_traffic)

        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        assert outcome.next_tick is not None
        assert outcome.next_tick.need_traffic_metrics
        assert outcome.next_tick.at_s == 120.0
# ── FPM reconciliation ───────────────────────────────────────────────
class TestFpmReconciliation:
    """Behaviour when FPM coverage disagrees with the ready worker count."""

    def test_mismatch_skips_scaling(self):
        """Two FPM entries against three ready workers gives no scale decision."""
        machine = _make_core(mode="prefill", ttft=5.0)
        _train_prefill_regression(machine)

        first_worker = _make_fpm(
            queued_prefill_tokens=10000,
            sum_prefill_tokens=500,
            num_prefill_requests=1,
            wall_time=0.5,
        )
        second_worker = _make_fpm(
            worker_id="w2",
            queued_prefill_tokens=8000,
            sum_prefill_tokens=500,
            num_prefill_requests=1,
            wall_time=0.5,
        )
        observations = FpmObservations(
            prefill={("w1", 0): first_worker, ("w2", 0): second_worker}
        )
        workers = WorkerCounts(ready_num_prefill=3)
        tick_input = TickInput(
            now_s=5.0, fpm_observations=observations, worker_counts=workers
        )

        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        # FPM reports 2 workers but ready count is 3 -> skip scaling
        assert outcome.scale_to is None
# ── Agg planner core ──────────────────────────────────────────────────
class TestAggPlannerStateMachine:
    """State-machine checks for a core configured in agg mode."""

    def _train_agg(self, core: PlannerStateMachine) -> None:
        """Load five synthetic agg benchmark samples into ``core``.

        Each sample mixes one prefill request with ten decode requests; wall
        time is linear in both token counts.
        """
        # (sum_prefill_tokens, sum_decode_kv_tokens) pairs
        operating_points = (
            (100, 1000),
            (200, 2000),
            (300, 3000),
            (400, 4000),
            (500, 5000),
        )
        benchmark = []
        for prefill_tokens, kv_tokens in operating_points:
            sample = _make_fpm(
                sum_prefill_tokens=prefill_tokens,
                num_prefill_requests=1,
                sum_decode_kv_tokens=kv_tokens,
                num_decode_requests=10,
                wall_time=0.001 * prefill_tokens + 0.0001 * kv_tokens + 0.001,
            )
            benchmark.append(sample)
        core.load_benchmark_fpms(agg_fpms=benchmark)

    def test_initial_tick(self):
        """The first agg tick lands at 5s and requests worker FPM."""
        first = _make_agg_core().initial_tick(start_s=0.0)
        assert first.at_s == 5.0
        assert first.need_worker_fpm

    def test_fpm_feeds_regression(self):
        """One FPM supplied under ``decode`` adds one regression observation."""
        machine = _make_agg_core()
        assert machine.regression.num_observations == 0

        sample = _make_fpm(
            sum_prefill_tokens=200,
            num_prefill_requests=1,
            sum_decode_kv_tokens=2000,
            num_decode_requests=10,
            wall_time=0.3,
        )
        observations = FpmObservations(decode={("w1", 0): sample})
        workers = WorkerCounts(ready_num_decode=1)
        tick_input = TickInput(
            now_s=5.0, fpm_observations=observations, worker_counts=workers
        )

        machine.on_tick(_tick_for(tick_input), tick_input)

        assert machine.regression.num_observations == 1

    def test_throughput_only_returns_decision(self):
        """With load scaling off, a traffic tick yields a decode count >= 1."""
        machine = _make_agg_core(
            enable_load_scaling=False, enable_throughput_scaling=True
        )
        self._train_agg(machine)

        window = {"duration_s": 60, "num_req": 100, "isl": 1000, "osl": 150}
        machine._observe_traffic(TrafficObservation(**window))

        # A fresh observation object is built for the tick itself.
        tick_input = TickInput(
            now_s=60.0,
            traffic=TrafficObservation(**window),
            worker_counts=WorkerCounts(ready_num_decode=1),
        )
        outcome = machine.on_tick(_tick_for(tick_input), tick_input)

        assert outcome.scale_to is not None
        assert outcome.scale_to.num_decode is not None
        assert outcome.scale_to.num_decode >= 1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment