"components/vscode:/vscode.git/clone" did not exist on "3f53a78e036721d367f8cbf9b3087de8b8666059"
Unverified Commit 4ea21079 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(replay): add agg/disagg offline replay optimization [DYN-2566] (#7774)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent b55277c9
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
from collections import Counter
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from unittest.mock import patch
import pandas as pd
import pytest
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.profiler.utils import replay_optimize
from dynamo.profiler.utils.replay_optimize import (
DenseAggReplayState,
SyntheticReplayWorkload,
TraceReplayWorkload,
compare_agg_and_disagg_with_replay,
optimize_dense_agg_with_replay,
optimize_dense_disagg_with_replay,
)
pytestmark = [
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
pytest.mark.parallel,
]
_AIC_MODEL = "Qwen/Qwen3-32B"
_AIC_SYSTEM = "h200_sxm"
def _base_prefill_args() -> MockEngineArgs:
return MockEngineArgs(
engine_type="vllm",
num_gpu_blocks=128,
block_size=64,
max_num_seqs=16,
max_num_batched_tokens=4096,
enable_prefix_caching=True,
enable_chunked_prefill=False,
worker_type="prefill",
)
def _base_decode_args() -> MockEngineArgs:
return MockEngineArgs(
engine_type="vllm",
num_gpu_blocks=192,
block_size=64,
max_num_seqs=32,
max_num_batched_tokens=4096,
enable_prefix_caching=True,
enable_chunked_prefill=False,
worker_type="decode",
)
def _base_agg_args() -> MockEngineArgs:
return MockEngineArgs(
engine_type="vllm",
num_gpu_blocks=160,
block_size=64,
max_num_seqs=24,
max_num_batched_tokens=4096,
enable_prefix_caching=True,
enable_chunked_prefill=False,
worker_type="aggregated",
)
def _write_trace(tmp_path: Path) -> Path:
trace_path = tmp_path / "optimizer_trace.jsonl"
records = [
{
"timestamp": 1000.0,
"input_length": 32,
"output_length": 8,
"hash_ids": [1, 2, 3, 4],
},
{
"timestamp": 1001.0,
"input_length": 48,
"output_length": 6,
"hash_ids": [1, 2, 3, 5],
},
]
trace_path.write_text(
"\n".join(json.dumps(record) for record in records) + "\n",
encoding="utf-8",
)
return trace_path
def test_enumerate_dense_tp_candidates_filters_to_tp_only(monkeypatch) -> None:
common = SimpleNamespace(BackendName=SimpleNamespace(vllm="vllm"))
task = SimpleNamespace(
build_disagg_parallel_lists=lambda **_: (
{
"num_gpu_per_worker": [1, 2, 4],
"tp_list": [1, 2, 4],
"pp_list": [1],
"dp_list": [1],
"moe_tp_list": [1],
"moe_ep_list": [1],
},
{
"num_gpu_per_worker": [1, 2, 4],
"tp_list": [1, 2, 4],
"pp_list": [1],
"dp_list": [1],
"moe_tp_list": [1],
"moe_ep_list": [1],
},
)
)
utils = SimpleNamespace(
enumerate_parallel_config=lambda **_: [
[1, 1, 1, 1, 1],
[2, 1, 1, 1, 1],
[2, 2, 1, 1, 1],
[4, 1, 2, 1, 1],
[4, 1, 1, 1, 1],
]
)
monkeypatch.setattr(
replay_optimize.aic,
"_load_aiconfigurator_modules",
lambda: (common, task, utils),
)
prefill_tps, decode_tps = replay_optimize._enumerate_dense_tp_candidates(
"vllm", "h200_sxm"
)
assert prefill_tps == [1, 2, 4]
assert decode_tps == [1, 2, 4]
def test_iter_tp_states_with_equal_workers_respects_gpu_budget() -> None:
states = replay_optimize._iter_tp_states_with_equal_workers(
prefill_tps=[1, 2, 4, 8],
decode_tps=[1, 2, 4, 8],
router_mode="round_robin",
overlap_score_weight=1.0,
max_total_gpus=8,
)
states_by_tp = {
(state.prefill_tp, state.decode_tp): (
state.prefill_workers,
state.decode_workers,
)
for state in states
}
assert (8, 8) not in states_by_tp
assert states_by_tp[(1, 1)] == (4, 4)
assert states_by_tp[(2, 1)] == (2, 2)
assert states_by_tp[(4, 4)] == (1, 1)
assert all(state.total_gpus_used <= 8 for state in states)
def test_iter_agg_tp_states_with_max_workers_respects_gpu_budget() -> None:
states = replay_optimize._iter_agg_tp_states_with_max_workers(
tps=[1, 2, 4, 8],
router_mode="round_robin",
overlap_score_weight=0.0,
max_total_gpus=8,
)
states_by_tp = {state.tp: state.workers for state in states}
assert states_by_tp == {1: 8, 2: 4, 4: 2, 8: 1}
assert all(state.total_gpus_used <= 8 for state in states)
assert set(state.router_mode for state in states) == {"round_robin"}
def test_mock_engine_args_dump_json_round_trips_explicit_none_fields() -> None:
base_args = MockEngineArgs(
engine_type="vllm",
num_gpu_blocks=128,
block_size=64,
max_num_seqs=None,
max_num_batched_tokens=None,
enable_prefix_caching=True,
worker_type="decode",
)
restored = MockEngineArgs.from_json(base_args.dump_json())
assert restored.worker_type == "decode"
assert restored.max_num_seqs is None
assert restored.max_num_batched_tokens is None
def test_iter_agg_worker_states_collapses_round_robin_overlap() -> None:
states = replay_optimize._iter_agg_worker_states(
tp=2,
router_mode="round_robin",
overlap_score_weight=0.0,
max_total_gpus=8,
)
assert [(state.tp, state.workers) for state in states] == [
(2, 1),
(2, 2),
(2, 3),
(2, 4),
]
assert set(state.router_mode for state in states) == {"round_robin"}
assert set(state.overlap_score_weight for state in states) == {0.0}
def test_optimizer_finds_coordinate_optimum_and_reuses_cache(monkeypatch) -> None:
call_counter: Counter = Counter()
target_state = replay_optimize.DenseReplayState(2, 4, 2, 1, 2.0)
def fake_run(**kwargs):
state = kwargs["state"]
call_counter[state] += 1
desired_score = (
1000.0
- 100.0 * abs(state.prefill_tp - target_state.prefill_tp)
- 100.0 * abs(state.decode_tp - target_state.decode_tp)
- 50.0 * abs(state.prefill_workers - target_state.prefill_workers)
- 50.0 * abs(state.decode_workers - target_state.decode_workers)
- 10.0 * abs(state.overlap_score_weight - target_state.overlap_score_weight)
)
return {
"output_throughput_tok_s": desired_score,
"mean_ttft_ms": 100.0,
"p95_ttft_ms": 120.0,
"mean_tpot_ms": 10.0,
"p95_tpot_ms": 12.0,
"mean_e2e_latency_ms": 200.0,
"p95_e2e_latency_ms": 220.0,
}
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2, 4], [1, 2, 4]),
)
monkeypatch.setattr(replay_optimize.evaluate, "_run_replay_for_state", fake_run)
result = optimize_dense_disagg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
max_total_gpus=8,
constraints={"mean_e2e_latency_ms": 500.0},
overlap_score_weights=[0.0, 1.0, 2.0],
max_parallel_evals=1,
)
assert result.best_feasible is not None
assert result.best_feasible["prefill_tp"] == 2
assert result.best_feasible["decode_tp"] == 4
assert result.best_feasible["prefill_workers"] == 2
assert result.best_feasible["decode_workers"] == 1
assert result.best_feasible["overlap_score_weight"] == 2.0
assert sum(call_counter.values()) == len(call_counter)
assert len(call_counter) == len(result.evaluated_df)
def test_agg_optimizer_finds_coordinate_optimum_and_reuses_cache(monkeypatch) -> None:
call_counter: Counter = Counter()
target_state = DenseAggReplayState(2, 3, "kv_router", 2.0)
def fake_run(**kwargs):
state = kwargs["state"]
call_counter[state] += 1
desired_score = (
1000.0
- 100.0 * abs(state.tp - target_state.tp)
- 50.0 * abs(state.workers - target_state.workers)
- 100.0 * (state.router_mode != target_state.router_mode)
- 10.0 * abs(state.overlap_score_weight - target_state.overlap_score_weight)
)
return {
"output_throughput_tok_s": desired_score,
"mean_ttft_ms": 100.0,
"p95_ttft_ms": 120.0,
"mean_tpot_ms": 10.0,
"p95_tpot_ms": 12.0,
"mean_e2e_latency_ms": 200.0,
"p95_e2e_latency_ms": 220.0,
}
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2, 4], [1, 2, 4]),
)
monkeypatch.setattr(replay_optimize.evaluate, "_run_agg_replay_for_state", fake_run)
result = optimize_dense_agg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_engine_args=_base_agg_args(),
max_total_gpus=8,
constraints={"mean_e2e_latency_ms": 500.0},
router_mode="both",
overlap_score_weights=[0.0, 1.0, 2.0],
max_parallel_evals=1,
)
assert result.best_feasible is not None
assert result.best_feasible["tp"] == 2
assert result.best_feasible["workers"] == 3
assert result.best_feasible["router_mode"] == "kv_router"
assert result.best_feasible["overlap_score_weight"] == 2.0
assert sum(call_counter.values()) == len(call_counter)
assert len(call_counter) == len(result.evaluated_df)
def test_optimizer_uses_violation_penalty_when_no_state_is_feasible(
monkeypatch,
) -> None:
target_state = replay_optimize.DenseReplayState(1, 2, 2, 2, 1.0)
def fake_run(**kwargs):
state = kwargs["state"]
latency = (
60.0
+ 10.0 * abs(state.prefill_tp - target_state.prefill_tp)
+ 10.0 * abs(state.decode_tp - target_state.decode_tp)
+ 5.0 * abs(state.prefill_workers - target_state.prefill_workers)
+ 5.0 * abs(state.decode_workers - target_state.decode_workers)
+ abs(state.overlap_score_weight - target_state.overlap_score_weight)
)
return {
"output_throughput_tok_s": 1000.0,
"mean_ttft_ms": latency,
"p95_ttft_ms": latency,
"mean_tpot_ms": 10.0,
"p95_tpot_ms": 10.0,
"mean_e2e_latency_ms": latency,
"p95_e2e_latency_ms": latency,
}
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
monkeypatch.setattr(replay_optimize.evaluate, "_run_replay_for_state", fake_run)
result = optimize_dense_disagg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
max_total_gpus=6,
constraints={"mean_e2e_latency_ms": 50.0},
overlap_score_weights=[0.0, 1.0],
max_parallel_evals=1,
)
assert result.best_feasible is None
assert result.best_infeasible is not None
assert result.best_infeasible["prefill_tp"] == 1
assert result.best_infeasible["decode_tp"] == 2
assert result.best_infeasible["prefill_workers"] == 2
assert result.best_infeasible["decode_workers"] == 2
assert result.best_infeasible["overlap_score_weight"] == 1.0
def test_agg_optimizer_uses_violation_penalty_when_no_state_is_feasible(
monkeypatch,
) -> None:
target_state = DenseAggReplayState(2, 3, "kv_router", 1.0)
def fake_run(**kwargs):
state = kwargs["state"]
latency = (
60.0
+ 10.0 * abs(state.tp - target_state.tp)
+ 5.0 * abs(state.workers - target_state.workers)
+ 3.0 * (state.router_mode != target_state.router_mode)
+ abs(state.overlap_score_weight - target_state.overlap_score_weight)
)
return {
"output_throughput_tok_s": 1000.0,
"mean_ttft_ms": latency,
"p95_ttft_ms": latency,
"mean_tpot_ms": 10.0,
"p95_tpot_ms": 10.0,
"mean_e2e_latency_ms": latency,
"p95_e2e_latency_ms": latency,
}
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
monkeypatch.setattr(replay_optimize.evaluate, "_run_agg_replay_for_state", fake_run)
result = optimize_dense_agg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_engine_args=_base_agg_args(),
max_total_gpus=8,
constraints={"mean_e2e_latency_ms": 50.0},
router_mode="both",
overlap_score_weights=[0.0, 1.0],
max_parallel_evals=1,
)
assert result.best_feasible is None
assert result.best_infeasible is not None
assert result.best_infeasible["tp"] == 2
assert result.best_infeasible["workers"] == 3
assert result.best_infeasible["router_mode"] == "kv_router"
assert result.best_infeasible["overlap_score_weight"] == 1.0
def test_optimizer_supports_round_robin_router_mode(monkeypatch) -> None:
seen_router_modes: list[str] = []
seen_weights: list[float] = []
def fake_run(**kwargs):
seen_router_modes.append(kwargs["state"].router_mode)
seen_weights.append(kwargs["state"].overlap_score_weight)
return {
"output_throughput_tok_s": 1000.0,
"mean_ttft_ms": 100.0,
"p95_ttft_ms": 120.0,
"mean_tpot_ms": 10.0,
"p95_tpot_ms": 12.0,
"mean_e2e_latency_ms": 200.0,
"p95_e2e_latency_ms": 220.0,
}
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
monkeypatch.setattr(replay_optimize.evaluate, "_run_replay_for_state", fake_run)
result = optimize_dense_disagg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
max_total_gpus=4,
constraints={"mean_e2e_latency_ms": 500.0},
router_mode="round_robin",
overlap_score_weights=[0.0, 1.0, 2.0],
max_parallel_evals=1,
)
assert result.best_feasible is not None
assert set(seen_router_modes) == {"round_robin"}
assert set(seen_weights) == {0.0}
def test_disagg_optimizer_supports_router_mode_search(monkeypatch) -> None:
seen_router_modes: list[str] = []
seen_weights: list[float] = []
def fake_run(**kwargs):
state = kwargs["state"]
seen_router_modes.append(state.router_mode)
seen_weights.append(state.overlap_score_weight)
return {
"output_throughput_tok_s": 1000.0 * state.total_gpus_used,
"mean_ttft_ms": 100.0,
"p95_ttft_ms": 120.0,
"mean_tpot_ms": 10.0,
"p95_tpot_ms": 12.0,
"mean_e2e_latency_ms": 200.0,
"p95_e2e_latency_ms": 220.0,
}
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
monkeypatch.setattr(replay_optimize.evaluate, "_run_replay_for_state", fake_run)
result = optimize_dense_disagg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
max_total_gpus=4,
constraints={"mean_e2e_latency_ms": 500.0},
router_mode="both",
overlap_score_weights=[0.0, 1.0, 2.0],
max_parallel_evals=1,
)
assert result.best_feasible is not None
assert "round_robin" in seen_router_modes
assert "kv_router" in seen_router_modes
assert 0.0 in seen_weights
assert 1.0 in seen_weights
assert 2.0 in seen_weights
def test_agg_optimizer_supports_router_mode_search(monkeypatch) -> None:
seen_router_modes: list[str] = []
seen_weights: list[float] = []
def fake_run(**kwargs):
state = kwargs["state"]
seen_router_modes.append(state.router_mode)
seen_weights.append(state.overlap_score_weight)
return {
"output_throughput_tok_s": 1000.0 * state.workers,
"mean_ttft_ms": 100.0,
"p95_ttft_ms": 120.0,
"mean_tpot_ms": 10.0,
"p95_tpot_ms": 12.0,
"mean_e2e_latency_ms": 200.0,
"p95_e2e_latency_ms": 220.0,
}
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
monkeypatch.setattr(replay_optimize.evaluate, "_run_agg_replay_for_state", fake_run)
result = optimize_dense_agg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_engine_args=_base_agg_args(),
max_total_gpus=4,
constraints={"mean_e2e_latency_ms": 500.0},
router_mode="both",
overlap_score_weights=[0.0, 1.0, 2.0],
max_parallel_evals=1,
)
assert result.best_feasible is not None
assert "round_robin" in seen_router_modes
assert "kv_router" in seen_router_modes
assert 0.0 in seen_weights
assert 1.0 in seen_weights
assert 2.0 in seen_weights
def test_compare_agg_and_disagg_with_replay_picks_expected_mode(monkeypatch) -> None:
agg_result = replay_optimize.DenseReplayOptimizationResult(
best_feasible={
"tp": 2,
"workers": 3,
"router_mode": "kv_router",
"overlap_score_weight": 1.0,
"total_gpus_used": 6,
"output_throughput_tok_s": 3000.0,
"score": 500.0,
"feasible": True,
"violation_penalty": 0.0,
"mean_e2e_latency_ms": 100.0,
},
best_infeasible=None,
evaluated_df=pd.DataFrame(),
feasible_df=pd.DataFrame(),
)
disagg_result = replay_optimize.DenseReplayOptimizationResult(
best_feasible={
"prefill_tp": 1,
"decode_tp": 1,
"prefill_workers": 2,
"decode_workers": 2,
"overlap_score_weight": 0.0,
"total_gpus_used": 4,
"output_throughput_tok_s": 1200.0,
"score": 300.0,
"feasible": True,
"violation_penalty": 0.0,
"mean_e2e_latency_ms": 150.0,
},
best_infeasible=None,
evaluated_df=pd.DataFrame(),
feasible_df=pd.DataFrame(),
)
monkeypatch.setattr(
replay_optimize.bench, "optimize_dense_agg_with_replay", lambda **_: agg_result
)
monkeypatch.setattr(
replay_optimize.bench,
"optimize_dense_disagg_with_replay",
lambda **_: disagg_result,
)
comparison = compare_agg_and_disagg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_engine_args=_base_agg_args(),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
max_total_gpus=8,
constraints={"mean_e2e_latency_ms": 500.0},
)
assert comparison["chosen_mode"] == "agg"
assert comparison["chosen_best"] == agg_result.best_feasible
def test_evaluate_state_prefers_normalized_metrics_over_report_payload() -> None:
state = replay_optimize.DenseReplayState(
prefill_tp=1,
decode_tp=1,
prefill_workers=1,
decode_workers=1,
overlap_score_weight=0.0,
router_mode="round_robin",
)
cache: dict[replay_optimize.DenseReplayState, dict[str, Any]] = {}
with patch(
"dynamo.profiler.utils.replay_optimize.evaluate._run_replay_for_state",
return_value={
"output_throughput_tok_s": "11.0",
"score": -1.0,
"feasible": False,
"violation_penalty": 7.0,
"mean_e2e_latency_ms": 100.0,
},
):
record = replay_optimize.evaluate._evaluate_state(
state=state,
workload=SyntheticReplayWorkload(
isl=128,
osl=32,
request_count=16,
replay_concurrency=4,
),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
base_router_config=None,
model="meta-llama/Llama-3.1-8B-Instruct",
backend="vllm",
system="h100_sxm",
constraints={"mean_e2e_latency_ms": 1000.0},
cache=cache,
)
assert record["output_throughput_tok_s"] == 11.0
assert record["score"] == 11.0
assert record["feasible"] is True
assert record["violation_penalty"] == 0.0
def test_evaluate_agg_state_prefers_normalized_metrics_over_report_payload() -> None:
state = DenseAggReplayState(
tp=2,
workers=2,
router_mode="round_robin",
overlap_score_weight=0.0,
)
cache: dict[DenseAggReplayState, dict[str, Any]] = {}
with patch(
"dynamo.profiler.utils.replay_optimize.evaluate._run_agg_replay_for_state",
return_value={
"output_throughput_tok_s": "24.0",
"score": -1.0,
"feasible": False,
"violation_penalty": 9.0,
"mean_e2e_latency_ms": 200.0,
},
):
record = replay_optimize.evaluate._evaluate_agg_state(
state=state,
workload=SyntheticReplayWorkload(
isl=128,
osl=32,
request_count=16,
replay_concurrency=4,
),
base_engine_args=_base_agg_args(),
base_router_config=None,
model="meta-llama/Llama-3.1-8B-Instruct",
backend="vllm",
system="h100_sxm",
constraints={"mean_e2e_latency_ms": 1000.0},
cache=cache,
)
assert record["output_throughput_tok_s"] == 24.0
assert record["score"] == 24.0
assert record["feasible"] is True
assert record["violation_penalty"] == 0.0
def test_kv_router_config_rejects_negative_overlap_weight() -> None:
config = KvRouterConfig(overlap_score_weight=1.0)
with pytest.raises(ValueError, match="overlap_score_weight must be non-negative"):
config.overlap_score_weight = -1.0
with pytest.raises(ValueError, match="overlap_score_weight must be non-negative"):
config.with_overrides(overlap_score_weight=-1.0)
@pytest.mark.timeout(30)
def test_agg_optimizer_synthetic_replay_smoke(monkeypatch) -> None:
pytest.importorskip("aiconfigurator")
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
result = optimize_dense_agg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=128,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_engine_args=_base_agg_args(),
max_total_gpus=4,
constraints={
"mean_ttft_ms": 100000.0,
"mean_tpot_ms": 100000.0,
"mean_e2e_latency_ms": 100000.0,
},
router_mode="both",
overlap_score_weights=[0.0, 1.0],
max_parallel_evals=1,
)
assert not result.evaluated_df.empty
assert result.best_feasible is not None
@pytest.mark.timeout(30)
def test_agg_optimizer_timed_trace_smoke(tmp_path, monkeypatch) -> None:
pytest.importorskip("aiconfigurator")
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
result = optimize_dense_agg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=TraceReplayWorkload(
trace_file=_write_trace(tmp_path),
arrival_speedup_ratio=100.0,
),
base_engine_args=_base_agg_args(),
max_total_gpus=4,
constraints={
"mean_ttft_ms": 100000.0,
"mean_tpot_ms": 100000.0,
"mean_e2e_latency_ms": 100000.0,
},
router_mode="both",
overlap_score_weights=[0.0, 1.0],
max_parallel_evals=1,
)
assert not result.evaluated_df.empty
assert result.best_feasible is not None
@pytest.mark.timeout(30)
def test_optimizer_synthetic_replay_smoke(tmp_path, monkeypatch) -> None:
pytest.importorskip("aiconfigurator")
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
result = optimize_dense_disagg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=128,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
max_total_gpus=4,
constraints={
"mean_ttft_ms": 100000.0,
"mean_tpot_ms": 100000.0,
"mean_e2e_latency_ms": 100000.0,
},
overlap_score_weights=[0.0, 1.0],
max_parallel_evals=1,
)
assert not result.evaluated_df.empty
assert result.best_feasible is not None
@pytest.mark.timeout(30)
def test_optimizer_timed_trace_smoke(tmp_path, monkeypatch) -> None:
pytest.importorskip("aiconfigurator")
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
result = optimize_dense_disagg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=TraceReplayWorkload(
trace_file=_write_trace(tmp_path),
arrival_speedup_ratio=100.0,
),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
max_total_gpus=4,
constraints={
"mean_ttft_ms": 100000.0,
"mean_tpot_ms": 100000.0,
"mean_e2e_latency_ms": 100000.0,
},
overlap_score_weights=[0.0, 1.0],
max_parallel_evals=1,
)
assert not result.evaluated_df.empty
assert result.best_feasible is not None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from . import aic, bench, engine_args, evaluate, scoring, search
from .aic import _enumerate_dense_tp_candidates, _load_aiconfigurator_modules
from .bench import compare_agg_and_disagg_with_replay, compare_aic_and_replay_disagg
from .engine_args import (
_build_agg_candidate_engine_args,
_build_candidate_engine_args,
_build_router_config,
)
from .models import (
DenseAggReplayState,
DenseReplayOptimizationResult,
DenseReplayState,
SyntheticReplayWorkload,
TraceReplayWorkload,
)
from .scoring import _pick_best_record
from .search import (
_iter_agg_tp_states_with_max_workers,
_iter_agg_worker_states,
_iter_budget_edge_worker_states,
_iter_tp_states_with_equal_workers,
optimize_dense_agg_with_replay,
optimize_dense_disagg_with_replay,
)
__all__ = [
"_build_agg_candidate_engine_args",
"_build_candidate_engine_args",
"_build_router_config",
"_enumerate_dense_tp_candidates",
"_iter_agg_tp_states_with_max_workers",
"_iter_agg_worker_states",
"_iter_budget_edge_worker_states",
"_iter_tp_states_with_equal_workers",
"_load_aiconfigurator_modules",
"_pick_best_record",
"compare_agg_and_disagg_with_replay",
"compare_aic_and_replay_disagg",
"DenseAggReplayState",
"DenseReplayOptimizationResult",
"DenseReplayState",
"SyntheticReplayWorkload",
"TraceReplayWorkload",
"aic",
"bench",
"engine_args",
"evaluate",
"optimize_dense_agg_with_replay",
"optimize_dense_disagg_with_replay",
"scoring",
"search",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import importlib
from typing import Any
def _load_aiconfigurator_modules() -> tuple[Any, Any, Any]:
try:
common = importlib.import_module("aiconfigurator.sdk.common")
task = importlib.import_module("aiconfigurator.sdk.task")
utils = importlib.import_module("aiconfigurator.sdk.utils")
except ModuleNotFoundError as exc:
raise RuntimeError(
"aiconfigurator is required to enumerate dense TP candidates for replay optimization"
) from exc
return common, task, utils
def _enumerate_dense_tp_candidates(
backend: str, system: str
) -> tuple[list[int], list[int]]:
common, task, utils = _load_aiconfigurator_modules()
backend_enum = getattr(common.BackendName, backend)
prefill_cfg, decode_cfg = task.build_disagg_parallel_lists(
backend_name=backend,
prefill_system=system,
decode_system=system,
is_moe=False,
should_enable_pp=False,
)
prefill_parallel = utils.enumerate_parallel_config(
num_gpu_list=prefill_cfg["num_gpu_per_worker"],
tp_list=prefill_cfg["tp_list"],
pp_list=prefill_cfg["pp_list"],
dp_list=prefill_cfg["dp_list"],
moe_tp_list=prefill_cfg["moe_tp_list"],
moe_ep_list=prefill_cfg["moe_ep_list"],
is_moe=False,
backend=backend_enum,
)
decode_parallel = utils.enumerate_parallel_config(
num_gpu_list=decode_cfg["num_gpu_per_worker"],
tp_list=decode_cfg["tp_list"],
pp_list=decode_cfg["pp_list"],
dp_list=decode_cfg["dp_list"],
moe_tp_list=decode_cfg["moe_tp_list"],
moe_ep_list=decode_cfg["moe_ep_list"],
is_moe=False,
backend=backend_enum,
)
def extract_tp(parallel_configs: list[list[int]]) -> list[int]:
return sorted(
{
tp
for tp, pp, dp, moe_tp, moe_ep in parallel_configs
if pp == 1 and dp == 1 and moe_tp == 1 and moe_ep == 1
}
)
return extract_tp(prefill_parallel), extract_tp(decode_parallel)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
import pandas as pd
from aiconfigurator.sdk.task import TaskConfig, TaskRunner
from dynamo.llm import MockEngineArgs
from .models import SyntheticReplayWorkload, TraceReplayWorkload
from .scoring import _pick_best_record
from .search import optimize_dense_agg_with_replay, optimize_dense_disagg_with_replay
def compare_aic_and_replay_disagg(
*,
model: str,
backend: str,
system: str,
isl: int,
osl: int,
max_total_gpus: int,
replay_request_count: int,
replay_concurrency: int,
base_prefill_engine_args: MockEngineArgs,
base_decode_engine_args: MockEngineArgs,
constraints: Mapping[str, float] | None = None,
max_parallel_evals: int = 1,
) -> dict[str, Any]:
ttft_constraint = None if constraints is None else constraints.get("mean_ttft_ms")
tpot_constraint = None if constraints is None else constraints.get("mean_tpot_ms")
request_latency_constraint = (
None if constraints is None else constraints.get("mean_e2e_latency_ms")
)
aic_task = TaskConfig(
serving_mode="disagg",
model_path=model,
system_name=system,
backend_name=backend,
total_gpus=max_total_gpus,
isl=isl,
osl=osl,
ttft=None if ttft_constraint is None else float(ttft_constraint),
tpot=None if tpot_constraint is None else float(tpot_constraint),
request_latency=(
None
if request_latency_constraint is None
else float(request_latency_constraint)
),
)
aic_result = TaskRunner().run(aic_task)
aic_df = aic_result.get("pareto_df", pd.DataFrame())
replay_result = optimize_dense_disagg_with_replay(
model=model,
backend=backend,
system=system,
workload=SyntheticReplayWorkload(
isl=isl,
osl=osl,
request_count=replay_request_count,
replay_concurrency=replay_concurrency,
),
base_prefill_engine_args=base_prefill_engine_args,
base_decode_engine_args=base_decode_engine_args,
max_total_gpus=max_total_gpus,
constraints=constraints,
router_mode="round_robin",
max_parallel_evals=max_parallel_evals,
)
aic_best = None
if not aic_df.empty:
row = aic_df.iloc[0]
aic_best = {
"prefill_tp": int(row.get("(p)tp", 0)),
"decode_tp": int(row.get("(d)tp", 0)),
"prefill_workers": int(row.get("(p)workers", 0)),
"decode_workers": int(row.get("(d)workers", 0)),
"total_gpus_used": int(row.get("num_total_gpus", 0)),
"ttft": float(row.get("ttft", 0.0)),
"tpot": float(row.get("tpot", 0.0)),
"request_latency": float(row.get("request_latency", 0.0)),
"tokens_per_s": float(row.get("tokens/s", 0.0)),
"tokens_per_s_per_gpu": float(row.get("tokens/s/gpu", 0.0)),
}
replay_best = None
if replay_result.best_feasible is not None:
replay_best_record = replay_result.best_feasible
replay_best = {
"prefill_tp": int(replay_best_record["prefill_tp"]),
"decode_tp": int(replay_best_record["decode_tp"]),
"prefill_workers": int(replay_best_record["prefill_workers"]),
"decode_workers": int(replay_best_record["decode_workers"]),
"total_gpus_used": int(replay_best_record["total_gpus_used"]),
"mean_ttft_ms": float(replay_best_record.get("mean_ttft_ms", 0.0)),
"mean_tpot_ms": float(replay_best_record.get("mean_tpot_ms", 0.0)),
"mean_e2e_latency_ms": float(
replay_best_record.get("mean_e2e_latency_ms", 0.0)
),
"output_throughput_tok_s": float(
replay_best_record.get("output_throughput_tok_s", 0.0)
),
"score": float(replay_best_record.get("score", 0.0)),
}
return {
"aic_pareto_df": aic_df,
"aic_best": aic_best,
"replay_result": replay_result,
"replay_best": replay_best,
}
def compare_agg_and_disagg_with_replay(
*,
model: str,
backend: str,
system: str,
workload: SyntheticReplayWorkload | TraceReplayWorkload,
base_engine_args: MockEngineArgs,
base_prefill_engine_args: MockEngineArgs,
base_decode_engine_args: MockEngineArgs,
max_total_gpus: int,
constraints: Mapping[str, float] | None = None,
router_mode: str = "kv_router",
overlap_score_weights: tuple[float, ...] | list[float] | None = None,
max_parallel_evals: int = 1,
) -> dict[str, Any]:
agg_result = optimize_dense_agg_with_replay(
model=model,
backend=backend,
system=system,
workload=workload,
base_engine_args=base_engine_args,
max_total_gpus=max_total_gpus,
constraints=constraints,
router_mode=router_mode,
overlap_score_weights=overlap_score_weights,
max_parallel_evals=max_parallel_evals,
)
disagg_result = optimize_dense_disagg_with_replay(
model=model,
backend=backend,
system=system,
workload=workload,
base_prefill_engine_args=base_prefill_engine_args,
base_decode_engine_args=base_decode_engine_args,
max_total_gpus=max_total_gpus,
constraints=constraints,
router_mode=router_mode,
overlap_score_weights=overlap_score_weights,
max_parallel_evals=max_parallel_evals,
)
agg_best = agg_result.best_feasible
disagg_best = disagg_result.best_feasible
if agg_best is None and disagg_best is None:
candidates = [
result.best_infeasible
for result in (agg_result, disagg_result)
if result.best_infeasible is not None
]
chosen_best = None if not candidates else _pick_best_record(candidates)
elif agg_best is None:
chosen_best = disagg_best
elif disagg_best is None:
chosen_best = agg_best
else:
chosen_best = _pick_best_record([agg_best, disagg_best])
chosen_mode = None
if chosen_best is not None:
chosen_mode = (
"agg" if "tp" in chosen_best and "workers" in chosen_best else "disagg"
)
return {
"agg_result": agg_result,
"disagg_result": disagg_result,
"chosen_mode": chosen_mode,
"chosen_best": chosen_best,
}
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os
AIC_BACKEND_VERSIONS = {
"vllm": "0.12.0",
"sglang": "0.5.6.post2",
}
DEFAULT_OVERLAP_SCORE_WEIGHTS = (0.0, 0.25, 0.5, 1.0, 2.0, 4.0)
DEFAULT_MAX_PARALLEL_EVALS = min(4, os.cpu_count() or 1)
DEFAULT_SEARCH_ROUNDS = 3
SUPPORTED_CONSTRAINTS = frozenset(
{
"mean_ttft_ms",
"p95_ttft_ms",
"mean_tpot_ms",
"p95_tpot_ms",
"mean_e2e_latency_ms",
"p95_e2e_latency_ms",
"max_total_gpus",
}
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Literal
from dynamo.llm import KvRouterConfig, MockEngineArgs
from .constants import AIC_BACKEND_VERSIONS
def _build_candidate_engine_args(
*,
base_args: MockEngineArgs,
tp_size: int,
worker_type: Literal["prefill", "decode", "aggregated"],
backend: str,
system: str,
model: str,
) -> MockEngineArgs:
args = base_args.copy()
args.worker_type = worker_type
args.enable_prefix_caching = worker_type != "decode"
# Keep the base KV block capacity fixed across TP for now.
#
# TP does not have a simple, backend-agnostic relationship with
# effective KV capacity. In particular, MLA-style attention and other
# specialized cache layouts break the usual KV-head-sharding intuition.
# A future version should derive a TP-aware capacity estimate from the
# AIC SDK instead of applying a generic heuristic here.
args.num_gpu_blocks = base_args.num_gpu_blocks
args.aic_backend = backend
args.aic_system = system
args.aic_backend_version = AIC_BACKEND_VERSIONS[backend]
args.aic_tp_size = tp_size
args.aic_model_path = model
return args
def _build_agg_candidate_engine_args(
*,
base_args: MockEngineArgs,
tp_size: int,
backend: str,
system: str,
model: str,
) -> MockEngineArgs:
return _build_candidate_engine_args(
base_args=base_args,
tp_size=tp_size,
worker_type="aggregated",
backend=backend,
system=system,
model=model,
)
def _build_router_config(
base_router_config: KvRouterConfig | None,
overlap_score_weight: float,
) -> KvRouterConfig:
if base_router_config is None:
return KvRouterConfig(overlap_score_weight=overlap_score_weight)
router_config = base_router_config.copy()
router_config.overlap_score_weight = overlap_score_weight
return router_config
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Replay evaluation helpers for the budget-focused dense search heuristic.
The search in `search.py` assumes we prefer to consume the available GPU budget
and therefore ranks visited states by raw output throughput, subject to replay
constraints, rather than by throughput normalized per GPU.
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from concurrent.futures import Executor
from dataclasses import asdict
from pathlib import Path
from typing import Any
from dynamo.llm import KvRouterConfig, MockEngineArgs
from dynamo.replay import run_synthetic_trace_replay, run_trace_replay
from .engine_args import (
_build_agg_candidate_engine_args,
_build_candidate_engine_args,
_build_router_config,
)
from .models import (
DenseAggReplayState,
DenseReplayState,
SyntheticReplayWorkload,
TraceReplayWorkload,
)
from .scoring import _violation_penalty
def _run_replay_for_state(
*,
state: DenseReplayState,
workload: SyntheticReplayWorkload | TraceReplayWorkload,
prefill_engine_args: MockEngineArgs,
decode_engine_args: MockEngineArgs,
router_config: KvRouterConfig | None,
) -> dict[str, Any]:
if isinstance(workload, SyntheticReplayWorkload):
return run_synthetic_trace_replay(
workload.isl,
workload.osl,
workload.request_count,
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
router_config=router_config,
num_prefill_workers=state.prefill_workers,
num_decode_workers=state.decode_workers,
replay_concurrency=workload.replay_concurrency,
replay_mode="offline",
router_mode=state.router_mode,
arrival_interval_ms=workload.arrival_interval_ms,
turns_per_session=workload.turns_per_session,
shared_prefix_ratio=workload.shared_prefix_ratio,
num_prefix_groups=workload.num_prefix_groups,
inter_turn_delay_ms=workload.inter_turn_delay_ms,
)
return run_trace_replay(
Path(workload.trace_file),
prefill_engine_args=prefill_engine_args,
decode_engine_args=decode_engine_args,
router_config=router_config,
num_prefill_workers=state.prefill_workers,
num_decode_workers=state.decode_workers,
replay_mode="offline",
router_mode=state.router_mode,
arrival_speedup_ratio=workload.arrival_speedup_ratio,
)
def _run_agg_replay_for_state(
*,
state: DenseAggReplayState,
workload: SyntheticReplayWorkload | TraceReplayWorkload,
engine_args: MockEngineArgs,
router_config: KvRouterConfig | None,
) -> dict[str, Any]:
if isinstance(workload, SyntheticReplayWorkload):
return run_synthetic_trace_replay(
workload.isl,
workload.osl,
workload.request_count,
extra_engine_args=engine_args,
router_config=router_config,
num_workers=state.workers,
replay_concurrency=workload.replay_concurrency,
replay_mode="offline",
router_mode=state.router_mode,
arrival_interval_ms=workload.arrival_interval_ms,
turns_per_session=workload.turns_per_session,
shared_prefix_ratio=workload.shared_prefix_ratio,
num_prefix_groups=workload.num_prefix_groups,
inter_turn_delay_ms=workload.inter_turn_delay_ms,
)
return run_trace_replay(
Path(workload.trace_file),
extra_engine_args=engine_args,
router_config=router_config,
num_workers=state.workers,
replay_mode="offline",
router_mode=state.router_mode,
arrival_speedup_ratio=workload.arrival_speedup_ratio,
)
def _evaluate_state(
*,
state: DenseReplayState,
workload: SyntheticReplayWorkload | TraceReplayWorkload,
base_prefill_engine_args: MockEngineArgs,
base_decode_engine_args: MockEngineArgs,
base_router_config: KvRouterConfig | None,
model: str,
backend: str,
system: str,
constraints: Mapping[str, float],
cache: dict[DenseReplayState, dict[str, Any]],
) -> dict[str, Any]:
cached = cache.get(state)
if cached is not None:
return cached
prefill_args = _build_candidate_engine_args(
base_args=base_prefill_engine_args,
tp_size=state.prefill_tp,
worker_type="prefill",
backend=backend,
system=system,
model=model,
)
decode_args = _build_candidate_engine_args(
base_args=base_decode_engine_args,
tp_size=state.decode_tp,
worker_type="decode",
backend=backend,
system=system,
model=model,
)
router_config = None
if state.router_mode == "kv_router":
router_config = _build_router_config(
base_router_config, state.overlap_score_weight
)
report = _run_replay_for_state(
state=state,
workload=workload,
prefill_engine_args=prefill_args,
decode_engine_args=decode_args,
router_config=router_config,
)
total_gpus_used = state.total_gpus_used
throughput = float(report["output_throughput_tok_s"])
score = throughput
penalty = _violation_penalty(report, constraints, total_gpus_used)
feasible = penalty == 0.0
record = {
**report,
**asdict(state),
"total_gpus_used": total_gpus_used,
"output_throughput_tok_s": throughput,
"score": score,
"feasible": feasible,
"violation_penalty": penalty,
}
cache[state] = record
return record
def _evaluate_agg_state(
*,
state: DenseAggReplayState,
workload: SyntheticReplayWorkload | TraceReplayWorkload,
base_engine_args: MockEngineArgs,
base_router_config: KvRouterConfig | None,
model: str,
backend: str,
system: str,
constraints: Mapping[str, float],
cache: dict[DenseAggReplayState, dict[str, Any]],
) -> dict[str, Any]:
cached = cache.get(state)
if cached is not None:
return cached
engine_args = _build_agg_candidate_engine_args(
base_args=base_engine_args,
tp_size=state.tp,
backend=backend,
system=system,
model=model,
)
router_config = None
if state.router_mode == "kv_router":
router_config = _build_router_config(
base_router_config, state.overlap_score_weight
)
report = _run_agg_replay_for_state(
state=state,
workload=workload,
engine_args=engine_args,
router_config=router_config,
)
total_gpus_used = state.total_gpus_used
throughput = float(report["output_throughput_tok_s"])
score = throughput
penalty = _violation_penalty(report, constraints, total_gpus_used)
feasible = penalty == 0.0
record = {
**report,
**asdict(state),
"total_gpus_used": total_gpus_used,
"output_throughput_tok_s": throughput,
"score": score,
"feasible": feasible,
"violation_penalty": penalty,
}
cache[state] = record
return record
def _evaluate_state_from_json_payloads(payload: Mapping[str, Any]) -> dict[str, Any]:
return _evaluate_state(
state=payload["state"],
workload=payload["workload"],
base_prefill_engine_args=MockEngineArgs.from_json(
payload["base_prefill_engine_args_json"]
),
base_decode_engine_args=MockEngineArgs.from_json(
payload["base_decode_engine_args_json"]
),
base_router_config=(
KvRouterConfig.from_json(payload["base_router_config_json"])
if payload["base_router_config_json"] is not None
else None
),
model=payload["model"],
backend=payload["backend"],
system=payload["system"],
constraints=payload["constraints"],
cache={},
)
def _evaluate_agg_state_from_json_payloads(
payload: Mapping[str, Any]
) -> dict[str, Any]:
return _evaluate_agg_state(
state=payload["state"],
workload=payload["workload"],
base_engine_args=MockEngineArgs.from_json(payload["base_engine_args_json"]),
base_router_config=(
KvRouterConfig.from_json(payload["base_router_config_json"])
if payload["base_router_config_json"] is not None
else None
),
model=payload["model"],
backend=payload["backend"],
system=payload["system"],
constraints=payload["constraints"],
cache={},
)
def _evaluate_states(
*,
states: Sequence[DenseReplayState],
workload: SyntheticReplayWorkload | TraceReplayWorkload,
base_prefill_engine_args: MockEngineArgs,
base_decode_engine_args: MockEngineArgs,
base_router_config: KvRouterConfig | None,
model: str,
backend: str,
system: str,
constraints: Mapping[str, float],
cache: dict[DenseReplayState, dict[str, Any]],
max_parallel_evals: int,
executor: Executor | None = None,
) -> list[dict[str, Any]]:
records: list[dict[str, Any] | None] = [None] * len(states)
uncached_indices: list[int] = []
uncached_states: list[DenseReplayState] = []
for index, state in enumerate(states):
cached = cache.get(state)
if cached is not None:
records[index] = cached
continue
uncached_indices.append(index)
uncached_states.append(state)
if not uncached_states:
return [record for record in records if record is not None]
if max_parallel_evals <= 1 or len(uncached_states) == 1 or executor is None:
for index, state in zip(uncached_indices, uncached_states, strict=True):
records[index] = _evaluate_state(
state=state,
workload=workload,
base_prefill_engine_args=base_prefill_engine_args,
base_decode_engine_args=base_decode_engine_args,
base_router_config=base_router_config,
model=model,
backend=backend,
system=system,
constraints=constraints,
cache=cache,
)
return [record for record in records if record is not None]
base_prefill_engine_args_json = base_prefill_engine_args.dump_json()
base_decode_engine_args_json = base_decode_engine_args.dump_json()
base_router_config_json = (
None if base_router_config is None else base_router_config.dump_json()
)
payloads = [
{
"state": state,
"workload": workload,
"base_prefill_engine_args_json": base_prefill_engine_args_json,
"base_decode_engine_args_json": base_decode_engine_args_json,
"base_router_config_json": base_router_config_json,
"model": model,
"backend": backend,
"system": system,
"constraints": constraints,
}
for state in uncached_states
]
future_records = list(executor.map(_evaluate_state_from_json_payloads, payloads))
for index, state, record in zip(
uncached_indices,
uncached_states,
future_records,
strict=True,
):
cache[state] = record
records[index] = record
return [record for record in records if record is not None]
def _evaluate_agg_states(
*,
states: Sequence[DenseAggReplayState],
workload: SyntheticReplayWorkload | TraceReplayWorkload,
base_engine_args: MockEngineArgs,
base_router_config: KvRouterConfig | None,
model: str,
backend: str,
system: str,
constraints: Mapping[str, float],
cache: dict[DenseAggReplayState, dict[str, Any]],
max_parallel_evals: int,
executor: Executor | None = None,
) -> list[dict[str, Any]]:
records: list[dict[str, Any] | None] = [None] * len(states)
uncached_indices: list[int] = []
uncached_states: list[DenseAggReplayState] = []
for index, state in enumerate(states):
cached = cache.get(state)
if cached is not None:
records[index] = cached
continue
uncached_indices.append(index)
uncached_states.append(state)
if not uncached_states:
return [record for record in records if record is not None]
if max_parallel_evals <= 1 or len(uncached_states) == 1 or executor is None:
for index, state in zip(uncached_indices, uncached_states, strict=True):
records[index] = _evaluate_agg_state(
state=state,
workload=workload,
base_engine_args=base_engine_args,
base_router_config=base_router_config,
model=model,
backend=backend,
system=system,
constraints=constraints,
cache=cache,
)
return [record for record in records if record is not None]
base_engine_args_json = base_engine_args.dump_json()
base_router_config_json = (
None if base_router_config is None else base_router_config.dump_json()
)
payloads = [
{
"state": state,
"workload": workload,
"base_engine_args_json": base_engine_args_json,
"base_router_config_json": base_router_config_json,
"model": model,
"backend": backend,
"system": system,
"constraints": constraints,
}
for state in uncached_states
]
future_records = list(
executor.map(_evaluate_agg_state_from_json_payloads, payloads)
)
for index, state, record in zip(
uncached_indices,
uncached_states,
future_records,
strict=True,
):
cache[state] = record
records[index] = record
return [record for record in records if record is not None]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Any
import pandas as pd
@dataclass(frozen=True)
class SyntheticReplayWorkload:
isl: int
osl: int
request_count: int
replay_concurrency: int
arrival_interval_ms: float = 0.0
turns_per_session: int = 1
shared_prefix_ratio: float = 0.0
num_prefix_groups: int = 0
inter_turn_delay_ms: float = 0.0
@dataclass(frozen=True)
class TraceReplayWorkload:
trace_file: str | os.PathLike[str]
arrival_speedup_ratio: float = 1.0
@dataclass(frozen=True)
class DenseReplayState:
prefill_tp: int
decode_tp: int
prefill_workers: int
decode_workers: int
overlap_score_weight: float
router_mode: str = "kv_router"
@property
def total_gpus_used(self) -> int:
return (
self.prefill_tp * self.prefill_workers
+ self.decode_tp * self.decode_workers
)
@dataclass(frozen=True)
class DenseAggReplayState:
tp: int
workers: int
router_mode: str
overlap_score_weight: float
@property
def total_gpus_used(self) -> int:
return self.tp * self.workers
@dataclass(frozen=True)
class DenseReplayOptimizationResult:
best_feasible: dict[str, Any] | None
best_infeasible: dict[str, Any] | None
evaluated_df: pd.DataFrame
feasible_df: pd.DataFrame
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import math
from collections.abc import Mapping, Sequence
from typing import Any
def _metric_value(report: Mapping[str, Any], key: str, total_gpus_used: int) -> float:
if key == "max_total_gpus":
return float(total_gpus_used)
value = report.get(key)
if value is None:
return math.inf
return float(value)
def _violation_penalty(
report: Mapping[str, Any],
constraints: Mapping[str, float],
total_gpus_used: int,
) -> float:
penalty = 0.0
for key, bound in constraints.items():
if bound <= 0:
continue
metric = _metric_value(report, key, total_gpus_used)
penalty += max(metric / bound - 1.0, 0.0)
return penalty
def _rank_record(record: Mapping[str, Any]) -> tuple[float, float, float]:
return (
float(record["score"]),
float(record["output_throughput_tok_s"]),
-float(record.get("mean_e2e_latency_ms", math.inf)),
)
def _pick_best_record(records: Sequence[dict[str, Any]]) -> dict[str, Any]:
feasible_records = [record for record in records if record["feasible"]]
if feasible_records:
return max(
feasible_records,
key=lambda record: (
*_rank_record(record),
-float(record["total_gpus_used"]),
),
)
return min(
records,
key=lambda record: (
float(record["violation_penalty"]),
-float(record["output_throughput_tok_s"]),
float(record.get("mean_e2e_latency_ms", math.inf)),
),
)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Heuristic replay search over dense aggregated and disaggregated configs.
This module intentionally assumes the optimizer should try to consume as much of
`max_total_gpus` as possible once a TP family is under consideration.
Accordingly, the search prunes to near-budget-edge states instead of treating
throughput-per-GPU as the primary objective.
The descent dimensions are:
- Disaggregated replay:
1. TP shape: `(prefill_tp, decode_tp)` probed at equal worker counts that fit
the budget.
2. Worker split: `(prefill_workers, decode_workers)` probed only among states
that maximize GPU usage for the current TP shape.
3. Router settings: `(router_mode, overlap_score_weight)`.
- Aggregated replay:
1. TP size: `tp` probed at the maximum worker count that fits the budget.
2. Worker count: `workers` for the incumbent `tp`.
3. Router settings: `(router_mode, overlap_score_weight)`.
This is a budget-focused heuristic, not an exact optimizer over all feasible
replay states.
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from typing import Literal
import pandas as pd
from dynamo.llm import KvRouterConfig, MockEngineArgs
from . import aic, evaluate
from .constants import (
AIC_BACKEND_VERSIONS,
DEFAULT_MAX_PARALLEL_EVALS,
DEFAULT_OVERLAP_SCORE_WEIGHTS,
DEFAULT_SEARCH_ROUNDS,
SUPPORTED_CONSTRAINTS,
)
from .models import (
DenseAggReplayState,
DenseReplayOptimizationResult,
DenseReplayState,
SyntheticReplayWorkload,
TraceReplayWorkload,
)
from .scoring import _pick_best_record
def _validate_backend(backend: str) -> str:
if backend not in AIC_BACKEND_VERSIONS:
raise ValueError(
f"backend must be one of {sorted(AIC_BACKEND_VERSIONS)}, got {backend!r}"
)
return backend
def _normalize_constraints(
constraints: Mapping[str, float] | None,
max_total_gpus: int,
) -> dict[str, float]:
normalized = dict(constraints or {})
invalid_keys = sorted(set(normalized) - SUPPORTED_CONSTRAINTS)
if invalid_keys:
raise ValueError(
"unsupported constraints: "
+ ", ".join(invalid_keys)
+ f"; supported constraints are {sorted(SUPPORTED_CONSTRAINTS)}"
)
if (
"max_total_gpus" in normalized
and int(normalized["max_total_gpus"]) != max_total_gpus
):
raise ValueError(
"constraints['max_total_gpus'] must match max_total_gpus when both are provided"
)
normalized["max_total_gpus"] = float(max_total_gpus)
return normalized
def _normalize_overlap_score_weights(
overlap_score_weights: Sequence[float] | None,
) -> tuple[float, ...]:
if overlap_score_weights is None:
return DEFAULT_OVERLAP_SCORE_WEIGHTS
weights = tuple(float(weight) for weight in overlap_score_weights)
if not weights:
raise ValueError("overlap_score_weights must not be empty")
return weights
def _normalize_router_mode(
router_mode: str,
) -> Literal["kv_router", "round_robin", "both"]:
if router_mode not in {"kv_router", "round_robin", "both"}:
raise ValueError(
"router_mode must be one of ['kv_router', 'round_robin', 'both'], "
f"got {router_mode!r}"
)
return router_mode
def _router_states(
*,
router_mode: Literal["kv_router", "round_robin", "both"],
overlap_score_weights: Sequence[float],
) -> list[tuple[str, float]]:
if router_mode == "round_robin":
return [("round_robin", 0.0)]
if router_mode == "kv_router":
return [("kv_router", float(weight)) for weight in overlap_score_weights]
return [("round_robin", 0.0)] + [
("kv_router", float(weight)) for weight in overlap_score_weights
]
def _supports_agg_router_mode(*, workers: int, router_mode: str) -> bool:
return router_mode == "round_robin" or workers > 1
def _iter_budget_edge_worker_states(
*,
prefill_tp: int,
decode_tp: int,
router_mode: Literal["kv_router", "round_robin"],
overlap_score_weight: float,
max_total_gpus: int,
) -> list[DenseReplayState]:
states: list[DenseReplayState] = []
max_gpus_used = 0
for prefill_workers in range(1, max_total_gpus // prefill_tp + 1):
for decode_workers in range(1, max_total_gpus // decode_tp + 1):
total_gpus_used = prefill_tp * prefill_workers + decode_tp * decode_workers
if total_gpus_used > max_total_gpus:
continue
state = DenseReplayState(
prefill_tp=prefill_tp,
decode_tp=decode_tp,
prefill_workers=prefill_workers,
decode_workers=decode_workers,
overlap_score_weight=overlap_score_weight,
router_mode=router_mode,
)
if total_gpus_used > max_gpus_used:
max_gpus_used = total_gpus_used
states = [state]
continue
if total_gpus_used == max_gpus_used:
states.append(state)
return states
def _iter_agg_worker_states(
*,
tp: int,
router_mode: Literal["kv_router", "round_robin"],
overlap_score_weight: float,
max_total_gpus: int,
) -> list[DenseAggReplayState]:
return [
DenseAggReplayState(
tp=tp,
workers=workers,
router_mode=router_mode,
overlap_score_weight=overlap_score_weight,
)
for workers in range(1, max_total_gpus // tp + 1)
if _supports_agg_router_mode(workers=workers, router_mode=router_mode)
]
def _iter_tp_states_with_equal_workers(
*,
prefill_tps: Sequence[int],
decode_tps: Sequence[int],
router_mode: Literal["kv_router", "round_robin"],
overlap_score_weight: float,
max_total_gpus: int,
) -> list[DenseReplayState]:
states: list[DenseReplayState] = []
for prefill_tp in prefill_tps:
for decode_tp in decode_tps:
max_equal_workers = max_total_gpus // (prefill_tp + decode_tp)
if max_equal_workers < 1:
continue
states.append(
DenseReplayState(
prefill_tp=prefill_tp,
decode_tp=decode_tp,
prefill_workers=max_equal_workers,
decode_workers=max_equal_workers,
overlap_score_weight=overlap_score_weight,
router_mode=router_mode,
)
)
return states
def _iter_agg_tp_states_with_max_workers(
*,
tps: Sequence[int],
router_mode: Literal["kv_router", "round_robin"],
overlap_score_weight: float,
max_total_gpus: int,
) -> list[DenseAggReplayState]:
states: list[DenseAggReplayState] = []
for tp in tps:
workers = max_total_gpus // tp
if workers < 1:
continue
states.append(
DenseAggReplayState(
tp=tp,
workers=workers,
router_mode=router_mode,
overlap_score_weight=overlap_score_weight,
)
)
return states
def _select_initial_state(
*,
prefill_tps: Sequence[int],
decode_tps: Sequence[int],
overlap_score_weight: float,
max_total_gpus: int,
) -> DenseReplayState:
initial_states = _iter_tp_states_with_equal_workers(
prefill_tps=prefill_tps,
decode_tps=decode_tps,
router_mode="round_robin",
overlap_score_weight=overlap_score_weight,
max_total_gpus=max_total_gpus,
)
if initial_states:
return initial_states[0]
raise ValueError(
"no TP candidates fit within "
f"max_total_gpus={max_total_gpus} with equal prefill and decode workers"
)
def _select_initial_agg_state(
*,
tps: Sequence[int],
max_total_gpus: int,
) -> DenseAggReplayState:
states = _iter_agg_tp_states_with_max_workers(
tps=tps,
router_mode="round_robin",
overlap_score_weight=0.0,
max_total_gpus=max_total_gpus,
)
if states:
return states[0]
raise ValueError(
"no TP candidates fit within "
f"max_total_gpus={max_total_gpus} for aggregated replay"
)
def _record_to_state(record: Mapping[str, float | int]) -> DenseReplayState:
return DenseReplayState(
prefill_tp=int(record["prefill_tp"]),
decode_tp=int(record["decode_tp"]),
prefill_workers=int(record["prefill_workers"]),
decode_workers=int(record["decode_workers"]),
overlap_score_weight=float(record["overlap_score_weight"]),
router_mode=str(record.get("router_mode", "kv_router")),
)
def _record_to_agg_state(
record: Mapping[str, float | int | str]
) -> DenseAggReplayState:
return DenseAggReplayState(
tp=int(record["tp"]),
workers=int(record["workers"]),
router_mode=str(record["router_mode"]),
overlap_score_weight=float(record["overlap_score_weight"]),
)
def optimize_dense_disagg_with_replay(
*,
model: str,
backend: Literal["vllm", "sglang"],
system: str,
workload: SyntheticReplayWorkload | TraceReplayWorkload,
base_prefill_engine_args: MockEngineArgs,
base_decode_engine_args: MockEngineArgs,
base_router_config: KvRouterConfig | None = None,
max_total_gpus: int,
constraints: Mapping[str, float] | None = None,
router_mode: Literal["kv_router", "round_robin", "both"] = "kv_router",
overlap_score_weights: Sequence[float] | None = None,
max_parallel_evals: int = DEFAULT_MAX_PARALLEL_EVALS,
) -> DenseReplayOptimizationResult:
"""Run a heuristic block search over dense disaggregated offline replay configs.
This routine assumes we want to use as much of `max_total_gpus` as possible,
then ranks visited states by raw output throughput subject to replay
constraints. The descended dimensions are:
1. `(prefill_tp, decode_tp)` at equal worker counts that fit the budget.
2. `(prefill_workers, decode_workers)` on the budget edge for the incumbent TP
shape.
3. `(router_mode, overlap_score_weight)`.
Returned "best" records are best among visited states, not a global optimum.
"""
backend = _validate_backend(backend)
router_mode = _normalize_router_mode(router_mode)
if max_total_gpus < 2:
raise ValueError("max_total_gpus must be at least 2 for disaggregated replay")
normalized_constraints = _normalize_constraints(constraints, max_total_gpus)
overlap_weights = _normalize_overlap_score_weights(overlap_score_weights)
if router_mode == "round_robin":
overlap_weights = (0.0,)
max_parallel_evals = max(1, int(max_parallel_evals))
prefill_tps, decode_tps = aic._enumerate_dense_tp_candidates(backend, system)
if not prefill_tps or not decode_tps:
raise ValueError(
f"no dense TP candidates found for backend={backend!r}, system={system!r}"
)
cache: dict[DenseReplayState, dict[str, float | int | bool | str]] = {}
incumbent = _select_initial_state(
prefill_tps=prefill_tps,
decode_tps=decode_tps,
overlap_score_weight=overlap_weights[0],
max_total_gpus=max_total_gpus,
)
executor = (
ProcessPoolExecutor(max_workers=max_parallel_evals)
if max_parallel_evals > 1
else None
)
try:
for _ in range(DEFAULT_SEARCH_ROUNDS):
round_start = incumbent
tp_states = _iter_tp_states_with_equal_workers(
prefill_tps=prefill_tps,
decode_tps=decode_tps,
router_mode=incumbent.router_mode,
overlap_score_weight=incumbent.overlap_score_weight,
max_total_gpus=max_total_gpus,
)
tp_records = evaluate._evaluate_states(
states=tp_states,
workload=workload,
base_prefill_engine_args=base_prefill_engine_args,
base_decode_engine_args=base_decode_engine_args,
base_router_config=base_router_config,
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
)
incumbent = _record_to_state(_pick_best_record(tp_records))
worker_states = _iter_budget_edge_worker_states(
prefill_tp=incumbent.prefill_tp,
decode_tp=incumbent.decode_tp,
router_mode=incumbent.router_mode,
overlap_score_weight=incumbent.overlap_score_weight,
max_total_gpus=max_total_gpus,
)
worker_records = evaluate._evaluate_states(
states=worker_states,
workload=workload,
base_prefill_engine_args=base_prefill_engine_args,
base_decode_engine_args=base_decode_engine_args,
base_router_config=base_router_config,
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
)
incumbent = _record_to_state(_pick_best_record(worker_records))
router_records = evaluate._evaluate_states(
states=[
DenseReplayState(
prefill_tp=incumbent.prefill_tp,
decode_tp=incumbent.decode_tp,
prefill_workers=incumbent.prefill_workers,
decode_workers=incumbent.decode_workers,
overlap_score_weight=weight,
router_mode=mode,
)
for mode, weight in _router_states(
router_mode=router_mode,
overlap_score_weights=overlap_weights,
)
],
workload=workload,
base_prefill_engine_args=base_prefill_engine_args,
base_decode_engine_args=base_decode_engine_args,
base_router_config=base_router_config,
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
)
incumbent = _record_to_state(_pick_best_record(router_records))
if incumbent == round_start:
break
finally:
if executor is not None:
executor.shutdown()
evaluated_df = pd.DataFrame.from_records(list(cache.values()))
feasible_df = (
evaluated_df[evaluated_df["feasible"]]
if not evaluated_df.empty
else evaluated_df
)
if not feasible_df.empty:
feasible_df = feasible_df.sort_values(
by=[
"score",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
"total_gpus_used",
],
ascending=[False, False, True, True],
).reset_index(drop=True)
best_feasible = feasible_df.iloc[0].to_dict() if not feasible_df.empty else None
best_infeasible = None
if not evaluated_df.empty:
infeasible_df = evaluated_df[~evaluated_df["feasible"]]
if not infeasible_df.empty:
best_infeasible = (
infeasible_df.sort_values(
by=[
"violation_penalty",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
],
ascending=[True, False, True],
)
.iloc[0]
.to_dict()
)
return DenseReplayOptimizationResult(
best_feasible=best_feasible,
best_infeasible=best_infeasible,
evaluated_df=evaluated_df.reset_index(drop=True),
feasible_df=feasible_df,
)
def optimize_dense_agg_with_replay(
*,
model: str,
backend: Literal["vllm", "sglang"],
system: str,
workload: SyntheticReplayWorkload | TraceReplayWorkload,
base_engine_args: MockEngineArgs,
base_router_config: KvRouterConfig | None = None,
max_total_gpus: int,
constraints: Mapping[str, float] | None = None,
router_mode: Literal["kv_router", "round_robin", "both"] = "kv_router",
overlap_score_weights: Sequence[float] | None = None,
max_parallel_evals: int = DEFAULT_MAX_PARALLEL_EVALS,
) -> DenseReplayOptimizationResult:
"""Run a heuristic block search over dense aggregated offline replay configs.
This routine assumes we want to use as much of `max_total_gpus` as possible,
then ranks visited states by raw output throughput subject to replay
constraints. The descended dimensions are:
1. `tp` at the maximum worker count that fits the budget.
2. `workers` for the incumbent `tp`.
3. `(router_mode, overlap_score_weight)`.
Returned "best" records are best among visited states, not a global optimum.
"""
backend = _validate_backend(backend)
router_mode = _normalize_router_mode(router_mode)
if max_total_gpus < 1:
raise ValueError("max_total_gpus must be at least 1 for aggregated replay")
normalized_constraints = _normalize_constraints(constraints, max_total_gpus)
overlap_weights = _normalize_overlap_score_weights(overlap_score_weights)
if router_mode == "round_robin":
overlap_weights = (0.0,)
max_parallel_evals = max(1, int(max_parallel_evals))
tps, _ = aic._enumerate_dense_tp_candidates(backend, system)
if not tps:
raise ValueError(
f"no dense TP candidates found for backend={backend!r}, system={system!r}"
)
cache: dict[DenseAggReplayState, dict[str, float | int | bool | str]] = {}
incumbent = _select_initial_agg_state(tps=tps, max_total_gpus=max_total_gpus)
executor = (
ProcessPoolExecutor(max_workers=max_parallel_evals)
if max_parallel_evals > 1
else None
)
try:
for _ in range(DEFAULT_SEARCH_ROUNDS):
round_start = incumbent
tp_states = _iter_agg_tp_states_with_max_workers(
tps=tps,
router_mode=incumbent.router_mode,
overlap_score_weight=incumbent.overlap_score_weight,
max_total_gpus=max_total_gpus,
)
tp_records = evaluate._evaluate_agg_states(
states=tp_states,
workload=workload,
base_engine_args=base_engine_args,
base_router_config=base_router_config,
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
)
incumbent = _record_to_agg_state(_pick_best_record(tp_records))
worker_states = _iter_agg_worker_states(
tp=incumbent.tp,
router_mode=incumbent.router_mode,
overlap_score_weight=incumbent.overlap_score_weight,
max_total_gpus=max_total_gpus,
)
worker_records = evaluate._evaluate_agg_states(
states=worker_states,
workload=workload,
base_engine_args=base_engine_args,
base_router_config=base_router_config,
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
)
incumbent = _record_to_agg_state(_pick_best_record(worker_records))
router_records = evaluate._evaluate_agg_states(
states=[
DenseAggReplayState(
tp=incumbent.tp,
workers=incumbent.workers,
router_mode=mode,
overlap_score_weight=weight,
)
for mode, weight in _router_states(
router_mode=router_mode,
overlap_score_weights=overlap_weights,
)
if _supports_agg_router_mode(
workers=incumbent.workers,
router_mode=mode,
)
],
workload=workload,
base_engine_args=base_engine_args,
base_router_config=base_router_config,
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
)
if router_records:
incumbent = _record_to_agg_state(_pick_best_record(router_records))
if incumbent == round_start:
break
finally:
if executor is not None:
executor.shutdown()
evaluated_df = pd.DataFrame.from_records(list(cache.values()))
feasible_df = (
evaluated_df[evaluated_df["feasible"]]
if not evaluated_df.empty
else evaluated_df
)
if not feasible_df.empty:
feasible_df = feasible_df.sort_values(
by=[
"score",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
"total_gpus_used",
],
ascending=[False, False, True, True],
).reset_index(drop=True)
best_feasible = feasible_df.iloc[0].to_dict() if not feasible_df.empty else None
best_infeasible = None
if not evaluated_df.empty:
infeasible_df = evaluated_df[~evaluated_df["feasible"]]
if not infeasible_df.empty:
best_infeasible = (
infeasible_df.sort_values(
by=[
"violation_penalty",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
],
ascending=[True, False, True],
)
.iloc[0]
.to_dict()
)
return DenseReplayOptimizationResult(
best_feasible=best_feasible,
best_infeasible=best_infeasible,
evaluated_df=evaluated_df.reset_index(drop=True),
feasible_df=feasible_df,
)
...@@ -7,7 +7,7 @@ use std::path::PathBuf; ...@@ -7,7 +7,7 @@ use std::path::PathBuf;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use pyo3::{exceptions::PyException, prelude::*}; use pyo3::{exceptions::PyException, exceptions::PyValueError, prelude::*};
use pyo3_async_runtimes::TaskLocals; use pyo3_async_runtimes::TaskLocals;
use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig; use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig;
...@@ -113,6 +113,45 @@ impl KvRouterConfig { ...@@ -113,6 +113,45 @@ impl KvRouterConfig {
.map(|inner| KvRouterConfig { inner }) .map(|inner| KvRouterConfig { inner })
.map_err(|e| PyException::new_err(format!("Failed to parse KvRouterConfig JSON: {e}"))) .map_err(|e| PyException::new_err(format!("Failed to parse KvRouterConfig JSON: {e}")))
} }
fn dump_json(&self) -> PyResult<String> {
serde_json::to_string(&self.inner)
.map_err(|e| PyException::new_err(format!("Failed to serialize KvRouterConfig: {e}")))
}
fn copy(&self) -> Self {
self.clone()
}
#[getter]
fn overlap_score_weight(&self) -> f64 {
self.inner.overlap_score_weight
}
#[setter]
fn set_overlap_score_weight(&mut self, value: f64) -> PyResult<()> {
if value < 0.0 {
return Err(PyValueError::new_err(
"overlap_score_weight must be non-negative",
));
}
self.inner.overlap_score_weight = value;
Ok(())
}
#[pyo3(signature = (overlap_score_weight=None))]
fn with_overrides(&self, overlap_score_weight: Option<f64>) -> PyResult<Self> {
let mut inner = self.inner.clone();
if let Some(weight) = overlap_score_weight {
if weight < 0.0 {
return Err(PyValueError::new_err(
"overlap_score_weight must be non-negative",
));
}
inner.overlap_score_weight = weight;
}
Ok(Self { inner })
}
} }
#[pyclass] #[pyclass]
......
...@@ -282,12 +282,15 @@ impl MockEngineArgs { ...@@ -282,12 +282,15 @@ impl MockEngineArgs {
"preemption_mode": preemption_mode, "preemption_mode": preemption_mode,
"router_queue_policy": router_queue_policy, "router_queue_policy": router_queue_policy,
"sglang": self.inner.sglang, "sglang": self.inner.sglang,
"has_perf_model": true,
}); });
serde_json::to_string_pretty(&payload) serde_json::to_string_pretty(&payload)
.map_err(|e| PyException::new_err(format!("Failed to serialize MockEngineArgs: {e}"))) .map_err(|e| PyException::new_err(format!("Failed to serialize MockEngineArgs: {e}")))
} }
fn copy(&self) -> Self {
self.clone()
}
#[getter] #[getter]
fn block_size(&self) -> usize { fn block_size(&self) -> usize {
self.inner.block_size self.inner.block_size
...@@ -308,6 +311,16 @@ impl MockEngineArgs { ...@@ -308,6 +311,16 @@ impl MockEngineArgs {
self.inner.max_num_batched_tokens self.inner.max_num_batched_tokens
} }
#[getter]
fn enable_prefix_caching(&self) -> bool {
self.inner.enable_prefix_caching
}
#[setter]
fn set_enable_prefix_caching(&mut self, value: bool) {
self.inner.enable_prefix_caching = value;
}
#[getter] #[getter]
fn enable_local_indexer(&self) -> bool { fn enable_local_indexer(&self) -> bool {
self.inner.enable_local_indexer self.inner.enable_local_indexer
...@@ -323,6 +336,76 @@ impl MockEngineArgs { ...@@ -323,6 +336,76 @@ impl MockEngineArgs {
self.inner.bootstrap_port self.inner.bootstrap_port
} }
#[getter]
fn aic_backend(&self) -> Option<String> {
self.inner.aic_backend.clone()
}
#[setter]
fn set_aic_backend(&mut self, value: Option<String>) {
self.inner.aic_backend = value;
}
#[getter]
fn aic_system(&self) -> Option<String> {
self.inner.aic_system.clone()
}
#[setter]
fn set_aic_system(&mut self, value: Option<String>) {
self.inner.aic_system = value;
}
#[getter]
fn aic_backend_version(&self) -> Option<String> {
self.inner.aic_backend_version.clone()
}
#[setter]
fn set_aic_backend_version(&mut self, value: Option<String>) {
self.inner.aic_backend_version = value;
}
#[getter]
fn aic_tp_size(&self) -> Option<usize> {
self.inner.aic_tp_size
}
#[setter]
fn set_aic_tp_size(&mut self, value: Option<usize>) {
self.inner.aic_tp_size = value;
}
#[getter]
fn aic_model_path(&self) -> Option<String> {
self.inner.aic_model_path.clone()
}
#[setter]
fn set_aic_model_path(&mut self, value: Option<String>) {
self.inner.aic_model_path = value;
}
#[getter]
fn worker_type(&self) -> &'static str {
match self.inner.worker_type {
RsWorkerType::Aggregated => "aggregated",
RsWorkerType::Prefill => "prefill",
RsWorkerType::Decode => "decode",
}
}
#[setter]
fn set_worker_type(&mut self, value: &str) -> PyResult<()> {
self.inner.worker_type = parse_worker_type(value)?;
Ok(())
}
#[setter]
fn set_num_gpu_blocks(&mut self, value: usize) {
self.inner.num_gpu_blocks = value;
}
fn is_prefill(&self) -> bool { fn is_prefill(&self) -> bool {
self.inner.is_prefill() self.inner.is_prefill()
} }
...@@ -331,13 +414,22 @@ impl MockEngineArgs { ...@@ -331,13 +414,22 @@ impl MockEngineArgs {
self.inner.is_decode() self.inner.is_decode()
} }
#[pyo3(signature = (bootstrap_port=None, zmq_kv_events_port=None, zmq_replay_port=None, kv_bytes_per_token=None))] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (bootstrap_port=None, zmq_kv_events_port=None, zmq_replay_port=None, kv_bytes_per_token=None, num_gpu_blocks=None, aic_backend=None, aic_system=None, aic_backend_version=None, aic_tp_size=None, aic_model_path=None, enable_prefix_caching=None, worker_type=None))]
fn with_overrides( fn with_overrides(
&self, &self,
bootstrap_port: Option<u16>, bootstrap_port: Option<u16>,
zmq_kv_events_port: Option<u16>, zmq_kv_events_port: Option<u16>,
zmq_replay_port: Option<u16>, zmq_replay_port: Option<u16>,
kv_bytes_per_token: Option<usize>, kv_bytes_per_token: Option<usize>,
num_gpu_blocks: Option<usize>,
aic_backend: Option<String>,
aic_system: Option<String>,
aic_backend_version: Option<String>,
aic_tp_size: Option<usize>,
aic_model_path: Option<String>,
enable_prefix_caching: Option<bool>,
worker_type: Option<String>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let mut inner = self.inner.clone(); let mut inner = self.inner.clone();
if let Some(port) = bootstrap_port { if let Some(port) = bootstrap_port {
...@@ -352,6 +444,30 @@ impl MockEngineArgs { ...@@ -352,6 +444,30 @@ impl MockEngineArgs {
if let Some(bytes_per_token) = kv_bytes_per_token { if let Some(bytes_per_token) = kv_bytes_per_token {
inner.kv_bytes_per_token = Some(bytes_per_token); inner.kv_bytes_per_token = Some(bytes_per_token);
} }
if let Some(blocks) = num_gpu_blocks {
inner.num_gpu_blocks = blocks;
}
if let Some(backend) = aic_backend {
inner.aic_backend = Some(backend);
}
if let Some(system) = aic_system {
inner.aic_system = Some(system);
}
if let Some(version) = aic_backend_version {
inner.aic_backend_version = Some(version);
}
if let Some(tp_size) = aic_tp_size {
inner.aic_tp_size = Some(tp_size);
}
if let Some(model_path) = aic_model_path {
inner.aic_model_path = Some(model_path);
}
if let Some(enable_prefix_caching) = enable_prefix_caching {
inner.enable_prefix_caching = enable_prefix_caching;
}
if let Some(worker_type) = worker_type {
inner.worker_type = parse_worker_type(&worker_type)?;
}
inner.normalized().map(|inner| Self { inner }).map_err(|e| { inner.normalized().map(|inner| Self { inner }).map_err(|e| {
PyException::new_err(format!("Failed to normalize MockEngineArgs overrides: {e}")) PyException::new_err(format!("Failed to normalize MockEngineArgs overrides: {e}"))
}) })
......
...@@ -1221,6 +1221,21 @@ class KvRouterConfig: ...@@ -1221,6 +1221,21 @@ class KvRouterConfig:
def from_json(config_json: str) -> "KvRouterConfig": def from_json(config_json: str) -> "KvRouterConfig":
... ...
def dump_json(self) -> str: ...
def copy(self) -> "KvRouterConfig": ...
@property
def overlap_score_weight(self) -> float: ...
@overlap_score_weight.setter
def overlap_score_weight(self, value: float) -> None: ...
def with_overrides(
self,
overlap_score_weight: Optional[float] = None,
) -> "KvRouterConfig": ...
class ReasoningConfig: class ReasoningConfig:
def __init__( def __init__(
self, self,
...@@ -1280,6 +1295,8 @@ class MockEngineArgs: ...@@ -1280,6 +1295,8 @@ class MockEngineArgs:
def from_json(config_json: str) -> "MockEngineArgs": def from_json(config_json: str) -> "MockEngineArgs":
... ...
def copy(self) -> "MockEngineArgs": ...
def dump_json(self) -> str: ... def dump_json(self) -> str: ...
@property @property
...@@ -1288,12 +1305,21 @@ class MockEngineArgs: ...@@ -1288,12 +1305,21 @@ class MockEngineArgs:
@property @property
def num_gpu_blocks(self) -> int: ... def num_gpu_blocks(self) -> int: ...
@num_gpu_blocks.setter
def num_gpu_blocks(self, value: int) -> None: ...
@property @property
def max_num_seqs(self) -> Optional[int]: ... def max_num_seqs(self) -> Optional[int]: ...
@property @property
def max_num_batched_tokens(self) -> Optional[int]: ... def max_num_batched_tokens(self) -> Optional[int]: ...
@property
def enable_prefix_caching(self) -> bool: ...
@enable_prefix_caching.setter
def enable_prefix_caching(self, value: bool) -> None: ...
@property @property
def enable_local_indexer(self) -> bool: ... def enable_local_indexer(self) -> bool: ...
...@@ -1303,6 +1329,42 @@ class MockEngineArgs: ...@@ -1303,6 +1329,42 @@ class MockEngineArgs:
@property @property
def bootstrap_port(self) -> Optional[int]: ... def bootstrap_port(self) -> Optional[int]: ...
@property
def aic_backend(self) -> Optional[str]: ...
@aic_backend.setter
def aic_backend(self, value: Optional[str]) -> None: ...
@property
def aic_system(self) -> Optional[str]: ...
@aic_system.setter
def aic_system(self, value: Optional[str]) -> None: ...
@property
def aic_backend_version(self) -> Optional[str]: ...
@aic_backend_version.setter
def aic_backend_version(self, value: Optional[str]) -> None: ...
@property
def aic_tp_size(self) -> Optional[int]: ...
@aic_tp_size.setter
def aic_tp_size(self, value: Optional[int]) -> None: ...
@property
def aic_model_path(self) -> Optional[str]: ...
@aic_model_path.setter
def aic_model_path(self, value: Optional[str]) -> None: ...
@property
def worker_type(self) -> str: ...
@worker_type.setter
def worker_type(self, value: str) -> None: ...
def is_prefill(self) -> bool: ... def is_prefill(self) -> bool: ...
def is_decode(self) -> bool: ... def is_decode(self) -> bool: ...
...@@ -1313,6 +1375,14 @@ class MockEngineArgs: ...@@ -1313,6 +1375,14 @@ class MockEngineArgs:
zmq_kv_events_port: Optional[int] = None, zmq_kv_events_port: Optional[int] = None,
zmq_replay_port: Optional[int] = None, zmq_replay_port: Optional[int] = None,
kv_bytes_per_token: Optional[int] = None, kv_bytes_per_token: Optional[int] = None,
num_gpu_blocks: Optional[int] = None,
aic_backend: Optional[str] = None,
aic_system: Optional[str] = None,
aic_backend_version: Optional[str] = None,
aic_tp_size: Optional[int] = None,
aic_model_path: Optional[str] = None,
enable_prefix_caching: Optional[bool] = None,
worker_type: Optional[str] = None,
) -> "MockEngineArgs": ... ) -> "MockEngineArgs": ...
async def register_model( async def register_model(
......
...@@ -484,6 +484,7 @@ impl MockEngineArgs { ...@@ -484,6 +484,7 @@ impl MockEngineArgs {
"decode_speedup_ratio", "decode_speedup_ratio",
"dp_size", "dp_size",
"startup_time", "startup_time",
"worker_type",
"is_prefill", "is_prefill",
"is_decode", "is_decode",
"planner_profile_data", "planner_profile_data",
...@@ -502,6 +503,7 @@ impl MockEngineArgs { ...@@ -502,6 +503,7 @@ impl MockEngineArgs {
"preemption_mode", "preemption_mode",
"router_queue_policy", "router_queue_policy",
"sglang", "sglang",
"has_perf_model",
] ]
.iter() .iter()
.cloned() .cloned()
...@@ -551,16 +553,20 @@ impl MockEngineArgs { ...@@ -551,16 +553,20 @@ impl MockEngineArgs {
builder = builder.block_size(num as usize); builder = builder.block_size(num as usize);
} }
if let Some(value) = extra_args.get("max_num_seqs") if let Some(value) = extra_args.get("max_num_seqs") {
&& let Some(num) = value.as_u64() if value.is_null() {
{ builder = builder.max_num_seqs(None);
builder = builder.max_num_seqs(Some(num as usize)); } else if let Some(num) = value.as_u64() {
builder = builder.max_num_seqs(Some(num as usize));
}
} }
if let Some(value) = extra_args.get("max_num_batched_tokens") if let Some(value) = extra_args.get("max_num_batched_tokens") {
&& let Some(num) = value.as_u64() if value.is_null() {
{ builder = builder.max_num_batched_tokens(None);
builder = builder.max_num_batched_tokens(Some(num as usize)); } else if let Some(num) = value.as_u64() {
builder = builder.max_num_batched_tokens(Some(num as usize));
}
} }
if let Some(value) = extra_args.get("enable_prefix_caching") if let Some(value) = extra_args.get("enable_prefix_caching")
...@@ -623,7 +629,9 @@ impl MockEngineArgs { ...@@ -623,7 +629,9 @@ impl MockEngineArgs {
builder = builder.kv_transfer_bandwidth(Some(num)); builder = builder.kv_transfer_bandwidth(Some(num));
} }
if let Some(value) = extra_args.get("reasoning") { if let Some(value) = extra_args.get("reasoning")
&& !value.is_null()
{
let cfg: ReasoningConfig = serde_json::from_value(value.clone()) let cfg: ReasoningConfig = serde_json::from_value(value.clone())
.map_err(|e| anyhow::anyhow!("Failed to parse reasoning config: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to parse reasoning config: {}", e))?;
builder = builder.reasoning(Some(cfg)); builder = builder.reasoning(Some(cfg));
...@@ -664,31 +672,51 @@ impl MockEngineArgs { ...@@ -664,31 +672,51 @@ impl MockEngineArgs {
builder = builder.router_queue_policy(Some(policy)); builder = builder.router_queue_policy(Some(policy));
} }
if let Some(value) = extra_args.get("sglang") { if let Some(value) = extra_args.get("sglang")
&& !value.is_null()
{
let cfg: SglangArgs = serde_json::from_value(value.clone()) let cfg: SglangArgs = serde_json::from_value(value.clone())
.map_err(|e| anyhow::anyhow!("Failed to parse sglang config: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to parse sglang config: {}", e))?;
builder = builder.sglang(Some(cfg)); builder = builder.sglang(Some(cfg));
} }
// Parse worker type from is_prefill and is_decode flags let worker_type = if let Some(value) = extra_args.get("worker_type") {
let is_prefill = extra_args match value.as_str() {
.get("is_prefill") Some("aggregated") => WorkerType::Aggregated,
.and_then(|v| v.as_bool()) Some("prefill") => WorkerType::Prefill,
.unwrap_or(false); Some("decode") => WorkerType::Decode,
let is_decode = extra_args Some(other) => {
.get("is_decode") return Err(anyhow::anyhow!(
.and_then(|v| v.as_bool()) "Invalid worker_type '{}'. Must be 'aggregated', 'prefill', or 'decode'.",
.unwrap_or(false); other
));
// Determine worker type based on flags }
let worker_type = match (is_prefill, is_decode) { None => {
(false, false) => WorkerType::Aggregated, return Err(anyhow::anyhow!(
(true, false) => WorkerType::Prefill, "Invalid worker_type: expected string value."
(false, true) => WorkerType::Decode, ));
(true, true) => panic!( }
"Invalid worker configuration: is_prefill and is_decode cannot both be true. \ }
Worker must be either Aggregated (both false), Prefill (is_prefill=true), or Decode (is_decode=true)." } else {
), let is_prefill = extra_args
.get("is_prefill")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let is_decode = extra_args
.get("is_decode")
.and_then(|v| v.as_bool())
.unwrap_or(false);
match (is_prefill, is_decode) {
(false, false) => WorkerType::Aggregated,
(true, false) => WorkerType::Prefill,
(false, true) => WorkerType::Decode,
(true, true) => {
return Err(anyhow::anyhow!(
"Invalid worker configuration: is_prefill and is_decode cannot both be true."
));
}
}
}; };
builder = builder.worker_type(worker_type); builder = builder.worker_type(worker_type);
...@@ -756,6 +784,58 @@ mod tests { ...@@ -756,6 +784,58 @@ mod tests {
use super::*; use super::*;
use serde_json::json; use serde_json::json;
#[test]
fn test_mock_engine_args_json_round_trip_preserves_worker_type_and_nulls() {
let args = MockEngineArgs::builder()
.worker_type(WorkerType::Decode)
.max_num_seqs(None)
.max_num_batched_tokens(None)
.reasoning(None)
.sglang(None)
.build()
.unwrap()
.normalized()
.unwrap();
let payload = serde_json::json!({
"engine_type": "vllm",
"num_gpu_blocks": args.num_gpu_blocks,
"block_size": args.block_size,
"max_num_seqs": args.max_num_seqs,
"max_num_batched_tokens": args.max_num_batched_tokens,
"enable_prefix_caching": args.enable_prefix_caching,
"enable_chunked_prefill": args.enable_chunked_prefill,
"speedup_ratio": args.speedup_ratio,
"decode_speedup_ratio": args.decode_speedup_ratio,
"dp_size": args.dp_size,
"startup_time": args.startup_time,
"worker_type": "decode",
"planner_profile_data": args.planner_profile_data,
"aic_backend": args.aic_backend,
"aic_system": args.aic_system,
"aic_backend_version": args.aic_backend_version,
"aic_tp_size": args.aic_tp_size,
"aic_model_path": args.aic_model_path,
"enable_local_indexer": args.enable_local_indexer,
"bootstrap_port": args.bootstrap_port,
"kv_bytes_per_token": args.kv_bytes_per_token,
"kv_transfer_bandwidth": args.kv_transfer_bandwidth,
"reasoning": args.reasoning,
"zmq_kv_events_port": args.zmq_kv_events_port,
"zmq_replay_port": args.zmq_replay_port,
"preemption_mode": "lifo",
"router_queue_policy": args.router_queue_policy.map(|policy| policy.to_string()),
"sglang": args.sglang,
"has_perf_model": true,
});
let restored = MockEngineArgs::from_json_str(&payload.to_string()).unwrap();
assert_eq!(restored.worker_type, WorkerType::Decode);
assert_eq!(restored.max_num_seqs, None);
assert_eq!(restored.max_num_batched_tokens, None);
}
#[test] #[test]
fn test_unique_block_default_uniqueness() { fn test_unique_block_default_uniqueness() {
// Create 10 default UniqueBlock instances // Create 10 default UniqueBlock instances
......
...@@ -10,7 +10,7 @@ from types import SimpleNamespace ...@@ -10,7 +10,7 @@ from types import SimpleNamespace
import numpy as np import numpy as np
import pytest import pytest
from dynamo.llm import EngineType, EntrypointArgs from dynamo.llm import EngineType, EntrypointArgs, MockEngineArgs
MODULE_PATH = ( MODULE_PATH = (
Path(__file__).resolve().parents[2] / "components/src/dynamo/mocker/config.py" Path(__file__).resolve().parents[2] / "components/src/dynamo/mocker/config.py"
...@@ -230,5 +230,26 @@ def test_build_mocker_engine_args_preserves_cli_mapped_fields(tmp_path): ...@@ -230,5 +230,26 @@ def test_build_mocker_engine_args_preserves_cli_mapped_fields(tmp_path):
"clip_max_new_tokens": 1024, "clip_max_new_tokens": 1024,
"schedule_conservativeness": 0.8, "schedule_conservativeness": 0.8,
}, },
}
assert "has_perf_model" not in payload
def test_mock_engine_args_from_json_ignores_legacy_has_perf_model_field():
payload = {
"engine_type": "vllm",
"num_gpu_blocks": 2048,
"block_size": 128,
"max_num_seqs": None,
"max_num_batched_tokens": None,
"worker_type": "decode",
"has_perf_model": True, "has_perf_model": True,
} }
engine_args = MockEngineArgs.from_json(json.dumps(payload))
assert engine_args.num_gpu_blocks == 2048
assert engine_args.block_size == 128
assert engine_args.max_num_seqs is None
assert engine_args.max_num_batched_tokens is None
assert engine_args.worker_type == "decode"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment