// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::VecDeque;

use anyhow::{Result, anyhow, bail};
use dynamo_kv_router::config::KvRouterConfig;

use crate::common::protocols::{DirectRequest, MockEngineArgs};
use crate::loadgen::{Trace, WorkloadDriver};
use crate::replay::{
    ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport, normalize_trace_requests,
};

use super::live_runtime::LiveRuntime;
use super::state::{LiveReplayMode, LiveRuntimeStats};

/// Returns the number of turns in `trace`, summed over all of its sessions.
fn total_turns(trace: &Trace) -> usize {
    let mut turns = 0;
    for session in trace.sessions.iter() {
        turns += session.turns.len();
    }
    turns
}

fn run_live_runtime(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
29
    prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
30
31
32
33
34
35
36
37
38
39
40
    pending: VecDeque<DirectRequest>,
    num_workers: usize,
    mode: LiveReplayMode,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;

    runtime.block_on(async move {
41
42
43
44
45
46
47
48
49
50
51
        LiveRuntime::new(
            args,
            router_config,
            prefill_load_estimator,
            pending,
            num_workers,
            mode,
            router_mode,
        )?
        .run()
        .await
52
53
54
    })
}

55
#[allow(clippy::too_many_arguments)]
56
57
58
fn run_live_workload_runtime(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
59
    prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    driver: WorkloadDriver,
    total_turns: usize,
    num_workers: usize,
    mode: LiveReplayMode,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .map_err(|e| anyhow!("failed to create online replay runtime: {e}"))?;

    runtime.block_on(async move {
        LiveRuntime::new(
            args,
            router_config,
75
            prefill_load_estimator,
76
77
78
79
80
81
82
83
84
85
86
87
88
            VecDeque::new(),
            num_workers,
            mode,
            router_mode,
        )?
        .run_workload(driver, total_turns)
        .await
    })
}

pub(crate) fn simulate_trace_requests(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
89
    prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
90
91
92
93
94
95
96
97
98
99
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let args = args.normalized()?;
    let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    let (report, _) = run_live_runtime(
        args,
        router_config,
100
        prefill_load_estimator,
101
102
103
104
105
106
107
108
109
110
111
        pending,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )?;
    Ok(report)
}

pub(crate) fn simulate_concurrency_requests(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
112
    prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
113
114
115
116
117
118
119
120
121
122
123
124
125
126
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let args = args.normalized()?;
    if requests.is_empty() {
        bail!("online concurrency replay requires at least one request");
    }

    let pending = VecDeque::from(requests);
    let (report, _) = run_live_runtime(
        args,
        router_config,
127
        prefill_load_estimator,
128
129
130
131
132
133
134
135
136
137
138
        pending,
        num_workers,
        LiveReplayMode::Concurrency { max_in_flight },
        router_mode,
    )?;
    Ok(report)
}

pub(crate) fn simulate_trace_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
139
    prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
140
141
142
143
144
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let args = args.normalized()?;
145
    let engine_block_size = args.block_size;
146
147
148
149
    let total_turns = total_turns(&trace);
    let (report, _) = run_live_workload_runtime(
        args,
        router_config,
150
151
        prefill_load_estimator,
        trace.into_trace_driver_with_block_size(engine_block_size)?,
152
153
154
155
156
157
158
159
160
161
162
        total_turns,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )?;
    Ok(report)
}

pub(crate) fn simulate_concurrency_workload(
    args: MockEngineArgs,
    router_config: Option<KvRouterConfig>,
163
    prefill_load_estimator: Option<ReplayPrefillLoadEstimator>,
164
165
166
167
168
169
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<TraceSimulationReport> {
    let args = args.normalized()?;
170
    let engine_block_size = args.block_size;
171
172
173
174
    let total_turns = total_turns(&trace);
    let (report, _) = run_live_workload_runtime(
        args,
        router_config,
175
176
        prefill_load_estimator,
        trace.into_concurrency_driver_with_block_size(engine_block_size)?,
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
        total_turns,
        num_workers,
        LiveReplayMode::Concurrency { max_in_flight },
        router_mode,
    )?;
    Ok(report)
}

/// Test-only counterpart of `simulate_trace_requests` that also returns the `LiveRuntimeStats`.
///
/// It runs with no router config and no prefill load estimator (both `None`).
///
/// # Errors
///
/// Propagates failures from `MockEngineArgs::normalized`, `normalize_trace_requests`, and from
/// building or running the live runtime.
#[cfg(test)]
pub(super) fn simulate_trace_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    num_workers: usize,
    arrival_speedup_ratio: f64,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let args = args.normalized()?;
    let pending = normalize_trace_requests(requests, arrival_speedup_ratio)?;
    run_live_runtime(
        args,
        None,
        None,
        pending,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
}

/// Test-only counterpart of `simulate_concurrency_requests` that also returns the
/// `LiveRuntimeStats`.
///
/// It runs with no router config and no prefill load estimator (both `None`). Unlike the
/// production entrypoint, it does not reject an empty `requests` list; that case is left to the
/// live runtime.
///
/// # Errors
///
/// Propagates failures from `MockEngineArgs::normalized` and from building or running the live
/// runtime.
#[cfg(test)]
pub(super) fn simulate_concurrency_requests_with_stats(
    args: MockEngineArgs,
    requests: Vec<DirectRequest>,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let args = args.normalized()?;
    let pending = VecDeque::from(requests);
    run_live_runtime(
        args,
        None,
        None,
        pending,
        num_workers,
        LiveReplayMode::Concurrency { max_in_flight },
        router_mode,
    )
}

/// Test-only counterpart of `simulate_trace_workload` that also returns the `LiveRuntimeStats`.
///
/// It runs with no router config and no prefill load estimator (both `None`).
///
/// # Errors
///
/// Propagates failures from `MockEngineArgs::normalized`, from building the trace driver, and
/// from building or running the live runtime.
#[cfg(test)]
pub(super) fn simulate_trace_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let args = args.normalized()?;
    // Count turns before `trace` is consumed by the driver conversion below.
    let turn_count = total_turns(&trace);
    let driver = trace.into_trace_driver_with_block_size(args.block_size)?;
    run_live_workload_runtime(
        args,
        None,
        None,
        driver,
        turn_count,
        num_workers,
        LiveReplayMode::Trace,
        router_mode,
    )
}

/// Test-only counterpart of `simulate_concurrency_workload` that also returns the
/// `LiveRuntimeStats`.
///
/// It runs with no router config and no prefill load estimator (both `None`).
///
/// # Errors
///
/// Propagates failures from `MockEngineArgs::normalized`, from building the concurrency driver,
/// and from building or running the live runtime.
#[cfg(test)]
pub(super) fn simulate_concurrency_workload_with_stats(
    args: MockEngineArgs,
    trace: Trace,
    max_in_flight: usize,
    num_workers: usize,
    router_mode: ReplayRouterMode,
) -> Result<(TraceSimulationReport, LiveRuntimeStats)> {
    let args = args.normalized()?;
    // Count turns before `trace` is consumed by the driver conversion below.
    let turn_count = total_turns(&trace);
    let driver = trace.into_concurrency_driver_with_block_size(args.block_size)?;
    run_live_workload_runtime(
        args,
        None,
        None,
        driver,
        turn_count,
        num_workers,
        LiveReplayMode::Concurrency { max_in_flight },
        router_mode,
    )
}