"container/vscode:/vscode.git/clone" did not exist on "e041ccfcb14995827cc856b9bb181abf04c64cc7"
Unverified Commit 69823d72 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(profiler): add replay objective knob (#8518)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent bce060d2
......@@ -18,6 +18,8 @@ except ImportError:
from dynamo.profiler.utils import replay_optimize
from dynamo.profiler.utils.replay_optimize import (
DenseAggReplayState,
ReplayConstraints,
ReplayObjective,
SyntheticReplayWorkload,
TraceReplayWorkload,
compare_agg_and_disagg_with_replay,
......@@ -500,6 +502,84 @@ def test_optimizer_supports_round_robin_router_mode(monkeypatch) -> None:
assert set(seen_weights) == {0.0}
def test_disagg_optimizer_supports_latency_objective(monkeypatch) -> None:
def fake_run(**kwargs):
state = kwargs["state"]
if state.prefill_tp == 1 and state.decode_tp == 1:
return {
"output_throughput_tok_s": 1200.0,
"mean_ttft_ms": 140.0,
"p95_ttft_ms": 160.0,
"mean_tpot_ms": 10.0,
"p95_tpot_ms": 12.0,
"mean_e2e_latency_ms": 300.0,
"p95_e2e_latency_ms": 320.0,
}
return {
"output_throughput_tok_s": 1000.0,
"mean_ttft_ms": 100.0,
"p95_ttft_ms": 120.0,
"mean_tpot_ms": 10.0,
"p95_tpot_ms": 12.0,
"mean_e2e_latency_ms": 200.0,
"p95_e2e_latency_ms": 220.0,
}
monkeypatch.setattr(
replay_optimize.aic,
"_enumerate_dense_tp_candidates",
lambda backend, system: ([1, 2], [1, 2]),
)
monkeypatch.setattr(replay_optimize.evaluate, "_run_replay_for_state", fake_run)
result = optimize_dense_disagg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
max_total_gpus=4,
constraints={"mean_e2e_latency_ms": 500.0},
objective="mean_e2e_latency",
overlap_score_weights=[0.0],
max_parallel_evals=1,
)
assert result.best_feasible is not None
assert (
result.best_feasible["prefill_tp"],
result.best_feasible["decode_tp"],
) in {(1, 2), (2, 1), (2, 2)}
assert result.best_feasible["score"] == -200.0
assert result.best_feasible["objective"] == "mean_e2e_latency"
def test_disagg_optimizer_rejects_invalid_objective() -> None:
with pytest.raises(ValueError, match="not a valid ReplayObjective"):
optimize_dense_disagg_with_replay(
model=_AIC_MODEL,
backend="vllm",
system=_AIC_SYSTEM,
workload=SyntheticReplayWorkload(
isl=64,
osl=32,
request_count=8,
replay_concurrency=4,
),
base_prefill_engine_args=_base_prefill_args(),
base_decode_engine_args=_base_decode_args(),
max_total_gpus=4,
objective="bad_objective",
max_parallel_evals=1,
)
def test_disagg_optimizer_supports_router_mode_search(monkeypatch) -> None:
seen_router_modes: list[str] = []
seen_weights: list[float] = []
......@@ -705,7 +785,8 @@ def test_evaluate_state_prefers_normalized_metrics_over_report_payload() -> None
model="meta-llama/Llama-3.1-8B-Instruct",
backend="vllm",
system="h100_sxm",
constraints={"mean_e2e_latency_ms": 1000.0},
objective=ReplayObjective.THROUGHPUT,
constraints=ReplayConstraints(mean_e2e_latency_ms=1000.0),
cache=cache,
)
......@@ -747,7 +828,8 @@ def test_evaluate_agg_state_prefers_normalized_metrics_over_report_payload() ->
model="meta-llama/Llama-3.1-8B-Instruct",
backend="vllm",
system="h100_sxm",
constraints={"mean_e2e_latency_ms": 1000.0},
objective=ReplayObjective.THROUGHPUT,
constraints=ReplayConstraints(mean_e2e_latency_ms=1000.0),
cache=cache,
)
......
......@@ -15,6 +15,8 @@ from .models import (
DenseAggReplayState,
DenseReplayOptimizationResult,
DenseReplayState,
ReplayConstraints,
ReplayObjective,
SyntheticReplayWorkload,
TraceReplayWorkload,
)
......@@ -44,6 +46,8 @@ __all__ = [
"DenseAggReplayState",
"DenseReplayOptimizationResult",
"DenseReplayState",
"ReplayConstraints",
"ReplayObjective",
"SyntheticReplayWorkload",
"TraceReplayWorkload",
"aic",
......
......@@ -11,7 +11,7 @@ from aiconfigurator.sdk.task import TaskConfig, TaskRunner
from dynamo.llm import MockEngineArgs
from .models import SyntheticReplayWorkload, TraceReplayWorkload
from .models import ReplayConstraints, SyntheticReplayWorkload, TraceReplayWorkload
from .scoring import _pick_best_record
from .search import optimize_dense_agg_with_replay, optimize_dense_disagg_with_replay
......@@ -31,11 +31,7 @@ def compare_aic_and_replay_disagg(
constraints: Mapping[str, float] | None = None,
max_parallel_evals: int = 1,
) -> dict[str, Any]:
ttft_constraint = None if constraints is None else constraints.get("mean_ttft_ms")
tpot_constraint = None if constraints is None else constraints.get("mean_tpot_ms")
request_latency_constraint = (
None if constraints is None else constraints.get("mean_e2e_latency_ms")
)
aic_constraints = ReplayConstraints.from_mapping(constraints, max_total_gpus)
aic_task = TaskConfig(
serving_mode="disagg",
model_path=model,
......@@ -44,13 +40,7 @@ def compare_aic_and_replay_disagg(
total_gpus=max_total_gpus,
isl=isl,
osl=osl,
ttft=None if ttft_constraint is None else float(ttft_constraint),
tpot=None if tpot_constraint is None else float(tpot_constraint),
request_latency=(
None
if request_latency_constraint is None
else float(request_latency_constraint)
),
**aic_constraints.aic_task_kwargs(),
)
aic_result = TaskRunner().run(aic_task)
aic_df = aic_result.get("pareto_df", pd.DataFrame())
......
......@@ -24,20 +24,15 @@ from .engine_args import (
_build_candidate_engine_args,
_build_router_config,
)
from .logging import (
ensure_dynamo_logging,
log_agg_state_finish,
log_agg_state_start,
log_dense_state_finish,
log_dense_state_start,
)
from .logging import ensure_dynamo_logging, log_state_finish, log_state_start
from .models import (
DenseAggReplayState,
DenseReplayState,
ReplayConstraints,
ReplayObjective,
SyntheticReplayWorkload,
TraceReplayWorkload,
)
from .scoring import _violation_penalty
def _run_replay_for_state(
......@@ -127,7 +122,8 @@ def _evaluate_state(
model: str,
backend: str,
system: str,
constraints: Mapping[str, float],
objective: ReplayObjective,
constraints: ReplayConstraints,
cache: dict[DenseReplayState, dict[str, Any]],
) -> dict[str, Any]:
ensure_dynamo_logging()
......@@ -135,7 +131,7 @@ def _evaluate_state(
if cached is not None:
return cached
log_dense_state_start(state)
log_state_start(state)
prefill_args = _build_candidate_engine_args(
base_args=base_prefill_engine_args,
......@@ -168,8 +164,8 @@ def _evaluate_state(
total_gpus_used = state.total_gpus_used
throughput = float(report["output_throughput_tok_s"])
score = throughput
penalty = _violation_penalty(report, constraints, total_gpus_used)
score = objective.score(report)
penalty = constraints.violation_penalty(report, total_gpus_used)
feasible = penalty == 0.0
record = {
**report,
......@@ -177,10 +173,11 @@ def _evaluate_state(
"total_gpus_used": total_gpus_used,
"output_throughput_tok_s": throughput,
"score": score,
"objective": objective.value,
"feasible": feasible,
"violation_penalty": penalty,
}
log_dense_state_finish(
log_state_finish(
state=state,
report=report,
constraints=constraints,
......@@ -201,7 +198,8 @@ def _evaluate_agg_state(
model: str,
backend: str,
system: str,
constraints: Mapping[str, float],
objective: ReplayObjective,
constraints: ReplayConstraints,
cache: dict[DenseAggReplayState, dict[str, Any]],
) -> dict[str, Any]:
ensure_dynamo_logging()
......@@ -209,7 +207,7 @@ def _evaluate_agg_state(
if cached is not None:
return cached
log_agg_state_start(state)
log_state_start(state)
engine_args = _build_agg_candidate_engine_args(
base_args=base_engine_args,
......@@ -232,8 +230,8 @@ def _evaluate_agg_state(
total_gpus_used = state.total_gpus_used
throughput = float(report["output_throughput_tok_s"])
score = throughput
penalty = _violation_penalty(report, constraints, total_gpus_used)
score = objective.score(report)
penalty = constraints.violation_penalty(report, total_gpus_used)
feasible = penalty == 0.0
record = {
**report,
......@@ -241,10 +239,11 @@ def _evaluate_agg_state(
"total_gpus_used": total_gpus_used,
"output_throughput_tok_s": throughput,
"score": score,
"objective": objective.value,
"feasible": feasible,
"violation_penalty": penalty,
}
log_agg_state_finish(
log_state_finish(
state=state,
report=report,
constraints=constraints,
......@@ -274,6 +273,7 @@ def _evaluate_state_from_json_payloads(payload: Mapping[str, Any]) -> dict[str,
model=payload["model"],
backend=payload["backend"],
system=payload["system"],
objective=payload["objective"],
constraints=payload["constraints"],
cache={},
)
......@@ -294,6 +294,7 @@ def _evaluate_agg_state_from_json_payloads(
model=payload["model"],
backend=payload["backend"],
system=payload["system"],
objective=payload["objective"],
constraints=payload["constraints"],
cache={},
)
......@@ -309,7 +310,8 @@ def _evaluate_states(
model: str,
backend: str,
system: str,
constraints: Mapping[str, float],
objective: ReplayObjective,
constraints: ReplayConstraints,
cache: dict[DenseReplayState, dict[str, Any]],
max_parallel_evals: int,
executor: Executor | None = None,
......@@ -340,6 +342,7 @@ def _evaluate_states(
model=model,
backend=backend,
system=system,
objective=objective,
constraints=constraints,
cache=cache,
)
......@@ -360,6 +363,7 @@ def _evaluate_states(
"model": model,
"backend": backend,
"system": system,
"objective": objective,
"constraints": constraints,
}
for state in uncached_states
......@@ -388,7 +392,8 @@ def _evaluate_agg_states(
model: str,
backend: str,
system: str,
constraints: Mapping[str, float],
objective: ReplayObjective,
constraints: ReplayConstraints,
cache: dict[DenseAggReplayState, dict[str, Any]],
max_parallel_evals: int,
executor: Executor | None = None,
......@@ -418,6 +423,7 @@ def _evaluate_agg_states(
model=model,
backend=backend,
system=system,
objective=objective,
constraints=constraints,
cache=cache,
)
......@@ -436,6 +442,7 @@ def _evaluate_agg_states(
"model": model,
"backend": backend,
"system": system,
"objective": objective,
"constraints": constraints,
}
for state in uncached_states
......
......@@ -9,7 +9,7 @@ from typing import Any
from dynamo.runtime.logging import configure_dynamo_logging
from .models import DenseAggReplayState, DenseReplayState
from .models import DenseAggReplayState, DenseReplayState, ReplayConstraints
logger = logging.getLogger(__name__)
_LOGGING_CONFIGURED = False
......@@ -23,95 +23,24 @@ def ensure_dynamo_logging() -> None:
_LOGGING_CONFIGURED = True
def format_dense_state(state: DenseReplayState) -> str:
return (
"prefill_tp=%s decode_tp=%s prefill_workers=%s decode_workers=%s "
"router_mode=%s overlap_score_weight=%s total_gpus=%s"
) % (
state.prefill_tp,
state.decode_tp,
state.prefill_workers,
state.decode_workers,
state.router_mode,
state.overlap_score_weight,
state.total_gpus_used,
)
def format_agg_state(state: DenseAggReplayState) -> str:
return ("tp=%s workers=%s router_mode=%s overlap_score_weight=%s total_gpus=%s") % (
state.tp,
state.workers,
state.router_mode,
state.overlap_score_weight,
state.total_gpus_used,
)
def summarize_constraints(
report: Mapping[str, Any],
constraints: Mapping[str, float],
total_gpus_used: int,
) -> str:
if not constraints:
return "constraints=none"
statuses: list[str] = []
for key, bound in constraints.items():
if bound <= 0:
continue
value = total_gpus_used if key == "max_total_gpus" else report.get(key)
if value is None:
statuses.append(f"{key}=missing<={bound:g} unsatisfied")
continue
metric = float(value)
state = "satisfied" if metric <= bound else "unsatisfied"
statuses.append(f"{key}={metric:.3f}<={bound:g} {state}")
return "constraints=" + ", ".join(statuses) if statuses else "constraints=none"
def log_dense_state_start(state: DenseReplayState) -> None:
logger.info("Replay optimize evaluating %s", format_dense_state(state))
def log_dense_state_finish(
*,
state: DenseReplayState,
report: Mapping[str, Any],
constraints: Mapping[str, float],
score: float,
feasible: bool,
violation_penalty: float,
) -> None:
logger.info(
"Replay optimize finished %s score=%.3f feasible=%s violation_penalty=%.6f %s",
format_dense_state(state),
score,
feasible,
violation_penalty,
summarize_constraints(report, constraints, state.total_gpus_used),
)
def log_agg_state_start(state: DenseAggReplayState) -> None:
logger.info("Replay optimize evaluating %s", format_agg_state(state))
def log_state_start(state: DenseReplayState | DenseAggReplayState) -> None:
logger.info("Replay optimize evaluating %s", state.format_summary())
def log_agg_state_finish(
def log_state_finish(
*,
state: DenseAggReplayState,
state: DenseReplayState | DenseAggReplayState,
report: Mapping[str, Any],
constraints: Mapping[str, float],
constraints: ReplayConstraints,
score: float,
feasible: bool,
violation_penalty: float,
) -> None:
logger.info(
"Replay optimize finished %s score=%.3f feasible=%s violation_penalty=%.6f %s",
format_agg_state(state),
state.format_summary(),
score,
feasible,
violation_penalty,
summarize_constraints(report, constraints, state.total_gpus_used),
constraints.summarize(report, state.total_gpus_used),
)
......@@ -3,12 +3,122 @@
from __future__ import annotations
import math
import os
from dataclasses import dataclass
from collections.abc import Iterator, Mapping
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any
import pandas as pd
from .constants import SUPPORTED_CONSTRAINTS
@dataclass(frozen=True)
class ReplayConstraints:
mean_ttft_ms: float | None = None
p95_ttft_ms: float | None = None
mean_tpot_ms: float | None = None
p95_tpot_ms: float | None = None
mean_e2e_latency_ms: float | None = None
p95_e2e_latency_ms: float | None = None
max_total_gpus: int | None = None
@classmethod
def from_mapping(
cls,
mapping: Mapping[str, float] | None,
max_total_gpus: int,
) -> ReplayConstraints:
raw = dict(mapping or {})
unknown = sorted(set(raw) - SUPPORTED_CONSTRAINTS)
if unknown:
raise ValueError(
"unsupported constraints: "
+ ", ".join(unknown)
+ f"; supported constraints are {sorted(SUPPORTED_CONSTRAINTS)}"
)
raw_gpus = raw.get("max_total_gpus")
if raw_gpus is not None and int(raw_gpus) != max_total_gpus:
raise ValueError(
"constraints['max_total_gpus'] must match max_total_gpus when both are provided"
)
def _bound(key: str) -> float | None:
value = raw.get(key)
return None if value is None or value <= 0 else float(value)
return cls(
mean_ttft_ms=_bound("mean_ttft_ms"),
p95_ttft_ms=_bound("p95_ttft_ms"),
mean_tpot_ms=_bound("mean_tpot_ms"),
p95_tpot_ms=_bound("p95_tpot_ms"),
mean_e2e_latency_ms=_bound("mean_e2e_latency_ms"),
p95_e2e_latency_ms=_bound("p95_e2e_latency_ms"),
max_total_gpus=int(max_total_gpus),
)
def _active(
self, report: Mapping[str, Any], total_gpus_used: int
) -> Iterator[tuple[str, float | None, float]]:
for field in fields(self):
if field.name == "max_total_gpus":
continue
bound = getattr(self, field.name)
if bound is None:
continue
value = report.get(field.name)
yield field.name, None if value is None else float(value), bound
if self.max_total_gpus is not None:
yield (
"max_total_gpus",
float(total_gpus_used),
float(self.max_total_gpus),
)
def violation_penalty(
self, report: Mapping[str, Any], total_gpus_used: int
) -> float:
penalty = 0.0
for _, metric, bound in self._active(report, total_gpus_used):
if metric is None:
penalty += math.inf
continue
penalty += max(metric / bound - 1.0, 0.0)
return penalty
def summarize(self, report: Mapping[str, Any], total_gpus_used: int) -> str:
statuses: list[str] = []
for name, metric, bound in self._active(report, total_gpus_used):
if metric is None:
statuses.append(f"{name}=missing<={bound:g} unsatisfied")
continue
state = "satisfied" if metric <= bound else "unsatisfied"
statuses.append(f"{name}={metric:.3f}<={bound:g} {state}")
return "constraints=" + ", ".join(statuses) if statuses else "constraints=none"
def aic_task_kwargs(self) -> dict[str, float | None]:
return {
"ttft": self.mean_ttft_ms,
"tpot": self.mean_tpot_ms,
"request_latency": self.mean_e2e_latency_ms,
}
class ReplayObjective(str, Enum):
THROUGHPUT = "throughput"
MEAN_TTFT = "mean_ttft"
MEAN_E2E_LATENCY = "mean_e2e_latency"
def score(self, report: Mapping[str, Any]) -> float:
if self is ReplayObjective.THROUGHPUT:
return float(report["output_throughput_tok_s"])
if self is ReplayObjective.MEAN_TTFT:
return -float(report["mean_ttft_ms"])
return -float(report["mean_e2e_latency_ms"])
@dataclass(frozen=True)
class SyntheticReplayWorkload:
......@@ -45,6 +155,14 @@ class DenseReplayState:
+ self.decode_tp * self.decode_workers
)
def format_summary(self) -> str:
return (
f"prefill_tp={self.prefill_tp} decode_tp={self.decode_tp} "
f"prefill_workers={self.prefill_workers} decode_workers={self.decode_workers} "
f"router_mode={self.router_mode} overlap_score_weight={self.overlap_score_weight} "
f"total_gpus={self.total_gpus_used}"
)
@dataclass(frozen=True)
class DenseAggReplayState:
......@@ -57,6 +175,13 @@ class DenseAggReplayState:
def total_gpus_used(self) -> int:
return self.tp * self.workers
def format_summary(self) -> str:
return (
f"tp={self.tp} workers={self.workers} "
f"router_mode={self.router_mode} overlap_score_weight={self.overlap_score_weight} "
f"total_gpus={self.total_gpus_used}"
)
@dataclass(frozen=True)
class DenseReplayOptimizationResult:
......
......@@ -7,28 +7,9 @@ import math
from collections.abc import Mapping, Sequence
from typing import Any
import pandas as pd
def _metric_value(report: Mapping[str, Any], key: str, total_gpus_used: int) -> float:
if key == "max_total_gpus":
return float(total_gpus_used)
value = report.get(key)
if value is None:
return math.inf
return float(value)
def _violation_penalty(
report: Mapping[str, Any],
constraints: Mapping[str, float],
total_gpus_used: int,
) -> float:
penalty = 0.0
for key, bound in constraints.items():
if bound <= 0:
continue
metric = _metric_value(report, key, total_gpus_used)
penalty += max(metric / bound - 1.0, 0.0)
return penalty
from .models import DenseReplayOptimizationResult
def _rank_record(record: Mapping[str, Any]) -> tuple[float, float, float]:
......@@ -58,3 +39,48 @@ def _pick_best_record(records: Sequence[dict[str, Any]]) -> dict[str, Any]:
float(record.get("mean_e2e_latency_ms", math.inf)),
),
)
def _finalize_result(
cache: Mapping[Any, dict[str, Any]],
) -> DenseReplayOptimizationResult:
evaluated_df = pd.DataFrame.from_records(list(cache.values()))
feasible_df = (
evaluated_df[evaluated_df["feasible"]]
if not evaluated_df.empty
else evaluated_df
)
if not feasible_df.empty:
feasible_df = feasible_df.sort_values(
by=[
"score",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
"total_gpus_used",
],
ascending=[False, False, True, True],
).reset_index(drop=True)
best_feasible = feasible_df.iloc[0].to_dict() if not feasible_df.empty else None
best_infeasible = None
if not evaluated_df.empty:
infeasible_df = evaluated_df[~evaluated_df["feasible"]]
if not infeasible_df.empty:
best_infeasible = (
infeasible_df.sort_values(
by=[
"violation_penalty",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
],
ascending=[True, False, True],
)
.iloc[0]
.to_dict()
)
return DenseReplayOptimizationResult(
best_feasible=best_feasible,
best_infeasible=best_infeasible,
evaluated_df=evaluated_df.reset_index(drop=True),
feasible_df=feasible_df,
)
......@@ -30,8 +30,6 @@ from collections.abc import Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from typing import Literal
import pandas as pd
from dynamo.llm import KvRouterConfig, MockEngineArgs
from . import aic, evaluate
......@@ -40,16 +38,17 @@ from .constants import (
DEFAULT_MAX_PARALLEL_EVALS,
DEFAULT_OVERLAP_SCORE_WEIGHTS,
DEFAULT_SEARCH_ROUNDS,
SUPPORTED_CONSTRAINTS,
)
from .models import (
DenseAggReplayState,
DenseReplayOptimizationResult,
DenseReplayState,
ReplayConstraints,
ReplayObjective,
SyntheticReplayWorkload,
TraceReplayWorkload,
)
from .scoring import _pick_best_record
from .scoring import _finalize_result, _pick_best_record
def _validate_backend(backend: str) -> str:
......@@ -60,31 +59,6 @@ def _validate_backend(backend: str) -> str:
return backend
def _normalize_constraints(
constraints: Mapping[str, float] | None,
max_total_gpus: int,
) -> dict[str, float]:
normalized = dict(constraints or {})
invalid_keys = sorted(set(normalized) - SUPPORTED_CONSTRAINTS)
if invalid_keys:
raise ValueError(
"unsupported constraints: "
+ ", ".join(invalid_keys)
+ f"; supported constraints are {sorted(SUPPORTED_CONSTRAINTS)}"
)
if (
"max_total_gpus" in normalized
and int(normalized["max_total_gpus"]) != max_total_gpus
):
raise ValueError(
"constraints['max_total_gpus'] must match max_total_gpus when both are provided"
)
normalized["max_total_gpus"] = float(max_total_gpus)
return normalized
def _normalize_overlap_score_weights(
overlap_score_weights: Sequence[float] | None,
) -> tuple[float, ...]:
......@@ -303,6 +277,7 @@ def optimize_dense_disagg_with_replay(
base_router_config: KvRouterConfig | None = None,
max_total_gpus: int,
constraints: Mapping[str, float] | None = None,
objective: Literal["throughput", "mean_e2e_latency", "mean_ttft"] = "throughput",
router_mode: Literal["kv_router", "round_robin", "both"] = "kv_router",
overlap_score_weights: Sequence[float] | None = None,
max_parallel_evals: int = DEFAULT_MAX_PARALLEL_EVALS,
......@@ -310,8 +285,10 @@ def optimize_dense_disagg_with_replay(
"""Run a heuristic block search over dense disaggregated offline replay configs.
This routine assumes we want to use as much of `max_total_gpus` as possible,
then ranks visited states by raw output throughput subject to replay
constraints. The descended dimensions are:
then ranks visited states by the selected `objective` subject to replay
constraints. Supported objectives: `"throughput"` (default, maximize
`output_throughput_tok_s`), `"mean_e2e_latency"` and `"mean_ttft"` (minimize
the corresponding report metric). The descended dimensions are:
1. `(prefill_tp, decode_tp)` at equal worker counts that fit the budget.
2. `(prefill_workers, decode_workers)` on the budget edge for the incumbent TP
shape.
......@@ -321,10 +298,11 @@ def optimize_dense_disagg_with_replay(
"""
backend = _validate_backend(backend)
router_mode = _normalize_router_mode(router_mode)
typed_objective = ReplayObjective(objective)
if max_total_gpus < 2:
raise ValueError("max_total_gpus must be at least 2 for disaggregated replay")
normalized_constraints = _normalize_constraints(constraints, max_total_gpus)
typed_constraints = ReplayConstraints.from_mapping(constraints, max_total_gpus)
overlap_weights = _normalize_overlap_score_weights(overlap_score_weights)
if router_mode == "round_robin":
overlap_weights = (0.0,)
......@@ -368,7 +346,8 @@ def optimize_dense_disagg_with_replay(
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
objective=typed_objective,
constraints=typed_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
......@@ -391,7 +370,8 @@ def optimize_dense_disagg_with_replay(
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
objective=typed_objective,
constraints=typed_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
......@@ -420,7 +400,8 @@ def optimize_dense_disagg_with_replay(
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
objective=typed_objective,
constraints=typed_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
......@@ -433,46 +414,7 @@ def optimize_dense_disagg_with_replay(
if executor is not None:
executor.shutdown()
evaluated_df = pd.DataFrame.from_records(list(cache.values()))
feasible_df = (
evaluated_df[evaluated_df["feasible"]]
if not evaluated_df.empty
else evaluated_df
)
if not feasible_df.empty:
feasible_df = feasible_df.sort_values(
by=[
"score",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
"total_gpus_used",
],
ascending=[False, False, True, True],
).reset_index(drop=True)
best_feasible = feasible_df.iloc[0].to_dict() if not feasible_df.empty else None
best_infeasible = None
if not evaluated_df.empty:
infeasible_df = evaluated_df[~evaluated_df["feasible"]]
if not infeasible_df.empty:
best_infeasible = (
infeasible_df.sort_values(
by=[
"violation_penalty",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
],
ascending=[True, False, True],
)
.iloc[0]
.to_dict()
)
return DenseReplayOptimizationResult(
best_feasible=best_feasible,
best_infeasible=best_infeasible,
evaluated_df=evaluated_df.reset_index(drop=True),
feasible_df=feasible_df,
)
return _finalize_result(cache)
def optimize_dense_agg_with_replay(
......@@ -485,6 +427,7 @@ def optimize_dense_agg_with_replay(
base_router_config: KvRouterConfig | None = None,
max_total_gpus: int,
constraints: Mapping[str, float] | None = None,
objective: Literal["throughput", "mean_e2e_latency", "mean_ttft"] = "throughput",
router_mode: Literal["kv_router", "round_robin", "both"] = "kv_router",
overlap_score_weights: Sequence[float] | None = None,
max_parallel_evals: int = DEFAULT_MAX_PARALLEL_EVALS,
......@@ -492,8 +435,10 @@ def optimize_dense_agg_with_replay(
"""Run a heuristic block search over dense aggregated offline replay configs.
This routine assumes we want to use as much of `max_total_gpus` as possible,
then ranks visited states by raw output throughput subject to replay
constraints. The descended dimensions are:
then ranks visited states by the selected `objective` subject to replay
constraints. Supported objectives: `"throughput"` (default, maximize
`output_throughput_tok_s`), `"mean_e2e_latency"` and `"mean_ttft"` (minimize
the corresponding report metric). The descended dimensions are:
1. `tp` at the maximum worker count that fits the budget.
2. `workers` for the incumbent `tp`.
3. `(router_mode, overlap_score_weight)`.
......@@ -502,10 +447,11 @@ def optimize_dense_agg_with_replay(
"""
backend = _validate_backend(backend)
router_mode = _normalize_router_mode(router_mode)
typed_objective = ReplayObjective(objective)
if max_total_gpus < 1:
raise ValueError("max_total_gpus must be at least 1 for aggregated replay")
normalized_constraints = _normalize_constraints(constraints, max_total_gpus)
typed_constraints = ReplayConstraints.from_mapping(constraints, max_total_gpus)
overlap_weights = _normalize_overlap_score_weights(overlap_score_weights)
if router_mode == "round_robin":
overlap_weights = (0.0,)
......@@ -542,7 +488,8 @@ def optimize_dense_agg_with_replay(
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
objective=typed_objective,
constraints=typed_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
......@@ -563,7 +510,8 @@ def optimize_dense_agg_with_replay(
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
objective=typed_objective,
constraints=typed_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
......@@ -593,7 +541,8 @@ def optimize_dense_agg_with_replay(
model=model,
backend=backend,
system=system,
constraints=normalized_constraints,
objective=typed_objective,
constraints=typed_constraints,
cache=cache,
max_parallel_evals=max_parallel_evals,
executor=executor,
......@@ -607,43 +556,4 @@ def optimize_dense_agg_with_replay(
if executor is not None:
executor.shutdown()
evaluated_df = pd.DataFrame.from_records(list(cache.values()))
feasible_df = (
evaluated_df[evaluated_df["feasible"]]
if not evaluated_df.empty
else evaluated_df
)
if not feasible_df.empty:
feasible_df = feasible_df.sort_values(
by=[
"score",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
"total_gpus_used",
],
ascending=[False, False, True, True],
).reset_index(drop=True)
best_feasible = feasible_df.iloc[0].to_dict() if not feasible_df.empty else None
best_infeasible = None
if not evaluated_df.empty:
infeasible_df = evaluated_df[~evaluated_df["feasible"]]
if not infeasible_df.empty:
best_infeasible = (
infeasible_df.sort_values(
by=[
"violation_penalty",
"output_throughput_tok_s",
"mean_e2e_latency_ms",
],
ascending=[True, False, True],
)
.iloc[0]
.to_dict()
)
return DenseReplayOptimizationResult(
best_feasible=best_feasible,
best_infeasible=best_infeasible,
evaluated_df=evaluated_df.reset_index(drop=True),
feasible_df=feasible_df,
)
return _finalize_result(cache)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment