// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Public handle for driving an offline replay with the planner in the loop.
//!
//! Supports both aggregated and disaggregated topologies (see the internal
//! `RuntimeKind` enum). The Python planner adapter calls
//! [`PlannerReplayHandle::advance_to`] to step the simulation and collect
//! per-tick metrics, calls [`PlannerReplayHandle::drain_traffic`] on
//! throughput-scaling ticks, and calls [`PlannerReplayHandle::apply_scaling`]
//! to resize worker pools.

use std::path::Path;
use std::time::Instant;

use anyhow::Result;
use dynamo_kv_router::config::KvRouterConfig;

use super::offline::agg::AggRuntime;
use super::offline::components::{ReplayMode, TrafficStats};
use super::offline::disagg::DisaggRuntime;
use super::{
    OfflineDisaggReplayConfig, ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport,
};
use crate::common::protocols::{ForwardPassSnapshot, MockEngineArgs};
use crate::loadgen::Trace;

/// Snapshot of metrics collected between planner ticks.
///
/// For aggregated mode, prefill fields are 0 and all data is in decode fields
/// (matching how the planner treats agg as a single decode-stage engine).
30
31
32
33
34
///
/// Traffic metrics are NOT included here — they accumulate across ticks and
/// must be drained explicitly via [`PlannerReplayHandle::drain_traffic`] on
/// throughput-scaling ticks only. Draining on every tick would discard data
/// between the more frequent load-scaling ticks.
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
pub struct PlannerTickData {
    /// Current simulated time in milliseconds.
    pub now_ms: f64,
    /// Whether the replay has finished (no more work).
    pub is_done: bool,
    /// Prefill FPM snapshots since last tick: (worker_id, snapshot).
    pub prefill_fpm_snapshots: Vec<(usize, ForwardPassSnapshot)>,
    /// Decode (or agg) FPM snapshots since last tick: (worker_id, snapshot).
    pub decode_fpm_snapshots: Vec<(usize, ForwardPassSnapshot)>,
    /// Active prefill workers (0 for agg mode).
    pub active_prefill_count: usize,
    /// Active decode workers (or total active for agg mode).
    pub active_decode_count: usize,
    /// Total prefill workers including pending removal (0 for agg mode).
    pub total_prefill_count: usize,
    /// Total decode workers including pending removal (or total for agg mode).
    pub total_decode_count: usize,
}

/// Topology-specific offline runtime driven by a [`PlannerReplayHandle`].
// The two runtimes differ enough in size to trip clippy's `large_enum_variant`.
// Only a single value is held per `PlannerReplayHandle`, so the lint is
// silenced instead of boxing one variant.
#[allow(clippy::large_enum_variant)]
enum RuntimeKind {
    /// Aggregated topology: a single worker pool, resized with one target count.
    Agg(AggRuntime),
    /// Disaggregated topology: separate prefill and decode worker pools.
    Disagg(DisaggRuntime),
}

/// Drives a single offline trace replay, step by step, on behalf of the planner.
///
/// Wraps either an aggregated or a disaggregated runtime. It also records the
/// wall-clock instant at construction, so that
/// [`PlannerReplayHandle::finalize`] can attach the elapsed real time to the
/// final report.
pub struct PlannerReplayHandle {
    /// The topology-specific replay engine being driven.
    runtime: RuntimeKind,
    /// Wall-clock instant captured when the handle was built.
    started_at: Instant,
}

impl PlannerReplayHandle {
    /// Create a handle for an aggregated trace-file replay.
    #[allow(clippy::too_many_arguments)]
    pub fn from_trace_file(
        args: MockEngineArgs,
        router_config: Option<KvRouterConfig>,
        prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
        trace_path: &Path,
        trace_block_size: usize,
        num_workers: usize,
        arrival_speedup_ratio: f64,
        router_mode: ReplayRouterMode,
    ) -> Result<Self> {
        let args = args.normalized()?;
        let trace = Trace::from_mooncake(trace_path, trace_block_size)?
            .normalize_session_starts()?
            .speed_up_timing(arrival_speedup_ratio)?;
        let runtime = AggRuntime::new_workload(
            &args,
            router_config,
            prefill_load_estimator,
            trace.into_trace_driver_with_block_size(args.block_size)?,
            num_workers,
            ReplayMode::Trace,
            router_mode,
        )?;
        Ok(Self {
            runtime: RuntimeKind::Agg(runtime),
            started_at: Instant::now(),
        })
    }

    /// Create a handle for a disaggregated trace-file replay.
    pub fn from_trace_file_disagg(
        config: OfflineDisaggReplayConfig,
        router_config: Option<KvRouterConfig>,
        prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
        trace_path: &Path,
        trace_block_size: usize,
        arrival_speedup_ratio: f64,
        router_mode: ReplayRouterMode,
    ) -> Result<Self> {
        let config = config.normalized()?;
        let trace = Trace::from_mooncake(trace_path, trace_block_size)?
            .normalize_session_starts()?
            .speed_up_timing(arrival_speedup_ratio)?;
        let runtime = DisaggRuntime::new_workload(
            &config,
            router_config,
            prefill_load_estimator,
            trace.into_trace_driver_with_block_size(config.decode_args.block_size)?,
            ReplayMode::Trace,
            router_mode,
        )?;
        Ok(Self {
            runtime: RuntimeKind::Disagg(runtime),
            started_at: Instant::now(),
        })
    }

    /// Advance the simulation up to `until_ms`, collect metrics, return tick data.
126
127
128
    ///
    /// Traffic metrics are NOT drained here — call [`drain_traffic`] explicitly
    /// on throughput-scaling ticks so the accumulator covers the full interval.
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
    pub fn advance_to(&mut self, until_ms: f64) -> Result<PlannerTickData> {
        match &mut self.runtime {
            RuntimeKind::Agg(rt) => {
                let is_done = rt.advance_to(until_ms)?;
                let fpm = rt.drain_fpm();
                Ok(PlannerTickData {
                    now_ms: rt.now_ms(),
                    is_done,
                    prefill_fpm_snapshots: Vec::new(),
                    decode_fpm_snapshots: fpm,
                    active_prefill_count: 0,
                    active_decode_count: rt.active_worker_count(),
                    total_prefill_count: 0,
                    total_decode_count: rt.total_worker_count(),
                })
            }
            RuntimeKind::Disagg(rt) => {
                let is_done = rt.advance_to(until_ms)?;
                let prefill_fpm = rt.drain_prefill_fpm();
                let decode_fpm = rt.drain_decode_fpm();
                Ok(PlannerTickData {
                    now_ms: rt.now_ms(),
                    is_done,
                    prefill_fpm_snapshots: prefill_fpm,
                    decode_fpm_snapshots: decode_fpm,
                    active_prefill_count: rt.active_prefill_count(),
                    active_decode_count: rt.active_decode_count(),
                    total_prefill_count: rt.total_prefill_count(),
                    total_decode_count: rt.total_decode_count(),
                })
            }
        }
    }

163
164
    /// Drain accumulated traffic metrics since the last drain.
    ///
165
166
167
168
169
170
171
    /// Call this only on throughput-scaling ticks so the window covers the
    /// full `throughput_adjustment_interval`, not just the gap between load
    /// ticks. The returned [`TrafficStats::avg_kv_hit_rate`] is the
    /// arithmetic mean of per-request ``overlap / isl`` ratios across
    /// admissions in the window — matching the real router's per-request
    /// Prometheus histogram, where each request contributes one sample
    /// regardless of ISL size.
172
    pub fn drain_traffic(&mut self) -> TrafficStats {
173
174
175
176
177
178
        match &mut self.runtime {
            RuntimeKind::Agg(rt) => rt.drain_traffic(),
            RuntimeKind::Disagg(rt) => rt.drain_traffic(),
        }
    }

179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
    /// Apply a scaling decision with separate prefill and decode targets.
    /// For agg mode, `target_prefill` is ignored.
    pub fn apply_scaling(&mut self, target_prefill: usize, target_decode: usize) -> Result<()> {
        match &mut self.runtime {
            RuntimeKind::Agg(rt) => rt.apply_scaling(target_decode),
            RuntimeKind::Disagg(rt) => rt.apply_scaling(target_prefill, target_decode),
        }
    }

    /// Finalize the replay and return the report.
    pub fn finalize(self) -> TraceSimulationReport {
        let report = match self.runtime {
            RuntimeKind::Agg(rt) => rt.finalize_report(),
            RuntimeKind::Disagg(rt) => rt.finalize_report(),
        };
        report.with_wall_time_ms(self.started_at.elapsed().as_secs_f64() * 1000.0)
    }
}