protocols.rs 11.2 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use derive_builder::Builder;
5
use serde::{Deserialize, Serialize};
6
use std::collections::{HashMap, HashSet};
7
8
use std::path::{Path, PathBuf};
use std::sync::Arc;
9
10
use uuid::Uuid;

11
12
use crate::perf_model::PerfModel;
use dynamo_kv_router::protocols::KvCacheEvent;
13
14
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash, Token};
15

16
17
18
19
20
21
/// Destination for KV cache events produced by mocker components.
///
/// Publishing goes through this trait so that mocker components do not depend on a
/// concrete runtime and can stay generic over it. Implementors must be both `Send`
/// and `Sync`.
pub trait KvCacheEventSink
where
    Self: Send + Sync,
{
    /// Hands one [`KvCacheEvent`] to the sink.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when the event cannot be published.
    fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()>;
}

22
23
24
/// Alias for a count of blocks; a plain `usize`.
pub type NumBlocks = usize;

/// Represents different block movement operations in the cache
Yan Ru Pei's avatar
Yan Ru Pei committed
25
/// For Use and Promote variants, block hashes are included for KV event publishing
26
27
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
Yan Ru Pei's avatar
Yan Ru Pei committed
28
    Use(Vec<UniqueBlock>, Vec<BlockHash>),
29
30
    Destroy(Vec<UniqueBlock>),
    Deref(Vec<UniqueBlock>),
Yan Ru Pei's avatar
Yan Ru Pei committed
31
    Promote(Uuid, SequenceHash, Option<u64>, BlockHash),
32
33
34
35
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
36
37
    Store(Vec<SequenceHash>, Option<u64>),
    Remove(Vec<SequenceHash>),
38
39
40
41
42
43
44
}

/// A request submitted directly to the mock engine.
///
/// NOTE(review): field descriptions are inferred from their names — confirm with callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
    /// Input (prompt) tokens.
    pub tokens: Vec<Token>,
    /// Upper bound on the number of tokens to generate.
    pub max_output_tokens: usize,
    /// Optional request identifier.
    pub uuid: Option<Uuid>,
    /// Data-parallel rank the request is routed to.
    pub dp_rank: u32,
}

/// Represents the cost of prefilling content in the cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
    /// Number of new blocks the prefill requires (presumably those not already cached).
    pub new_blocks: usize,
    /// Number of new tokens to prefill. `PrefillCost::predict_prefill_compute` uses this
    /// value when no override is supplied.
    pub new_tokens: usize,
}

impl PrefillCost {
56
57
58
59
60
    pub fn predict_prefill_compute(
        &self,
        new_tokens: Option<usize>,
        perf_model: &PerfModel,
    ) -> f64 {
61
        let tokens = new_tokens.unwrap_or(self.new_tokens);
62
        perf_model.predict_prefill_time(tokens)
63
    }
64
65
}

66
67
68
69
70
71
72
/// Notification that an output token was generated for a request.
///
/// It also reports whether that request has finished.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct OutputSignal {
    /// Identifier of the request this signal refers to.
    pub uuid: Uuid,
    /// `true` once the request has completed.
    pub completed: bool,
}

73
74
75
76
77
78
79
80
81
82
83
84
/// Role a mock worker plays in an aggregated or disaggregated serving deployment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
pub enum WorkerType {
    /// Runs both prefill and decode; this is the default role.
    #[default]
    Aggregated,
    /// Disaggregated mode: this worker only performs prefill.
    Prefill,
    /// Disaggregated mode: this worker only performs decode.
    Decode,
}

85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
/// Configuration arguments for MockVllmEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
    #[builder(default = "16384")]
    pub num_gpu_blocks: usize,

    #[builder(default = "64")]
    pub block_size: usize,

    // This was 1024 in the past but reverted back to 256
    #[builder(default = Some(256))]
    pub max_num_seqs: Option<usize>,

    // default for open api server, for llm class it's 16384
    #[builder(default = Some(8192))]
    pub max_num_batched_tokens: Option<usize>,

    #[builder(default = true)]
    pub enable_prefix_caching: bool,

106
107
108
    #[builder(default = true)]
    pub enable_chunked_prefill: bool,

109
110
111
112
113
114
115
116
    #[builder(default = "0.01")]
    pub watermark: f64,

    #[builder(default = "1.0")]
    pub speedup_ratio: f64,

    #[builder(default = "1")]
    pub dp_size: u32,
117
118
119
120

    /// Optional startup time in seconds to simulate engine initialization delay
    #[builder(default = "None")]
    pub startup_time: Option<f64>,
121
122
123
124

    /// Worker type for disaggregated serving (Aggregated, Prefill, or Decode)
    #[builder(default = "WorkerType::Aggregated")]
    pub worker_type: WorkerType,
125
126
127
128
129

    /// Performance model for timing predictions (not serialized, loaded from planner_profile_data)
    #[serde(skip)]
    #[builder(default = "Arc::new(PerfModel::default())")]
    pub perf_model: Arc<PerfModel>,
130
131
132
133

    /// Enable worker-local KV indexer for tracking this worker's own KV cache state
    #[builder(default = "false")]
    pub enable_local_indexer: bool,
134
135
136
137
138
139

    /// Bootstrap port for disaggregated serving rendezvous.
    /// Prefill workers listen on this port; decode workers connect to it.
    /// If None, bootstrap rendezvous is disabled.
    #[builder(default = "None")]
    pub bootstrap_port: Option<u16>,
140
141
}

142
143
144
145
146
147
148
149
impl Default for MockEngineArgs {
    fn default() -> MockEngineArgs {
        MockEngineArgsBuilder::default()
            .build()
            .expect("Failed to build default MockEngineArgs")
    }
}

150
151
152
153
impl MockEngineArgs {
    pub fn builder() -> MockEngineArgsBuilder {
        MockEngineArgsBuilder::default()
    }
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169

    /// Create MockEngineArgs from a JSON file containing extra engine arguments
    pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
        let mut builder = Self::builder();

        // Load and parse the JSON file
        let file_content = std::fs::read_to_string(path)?;
        let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;

        // Define valid field names
        let valid_fields: HashSet<&str> = [
            "num_gpu_blocks",
            "block_size",
            "max_num_seqs",
            "max_num_batched_tokens",
            "enable_prefix_caching",
170
            "enable_chunked_prefill",
171
172
173
            "watermark",
            "speedup_ratio",
            "dp_size",
174
            "startup_time",
175
176
            "is_prefill",
            "is_decode",
177
            "planner_profile_data",
178
            "enable_local_indexer",
179
            "bootstrap_port",
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
        ]
        .iter()
        .cloned()
        .collect();

        // Check for invalid arguments
        let invalid_args: Vec<String> = extra_args
            .keys()
            .filter(|key| !valid_fields.contains(key.as_str()))
            .cloned()
            .collect();

        if !invalid_args.is_empty() {
            return Err(anyhow::anyhow!(
                "Invalid arguments found in JSON file: {}. Valid arguments are: {:?}",
                invalid_args.join(", "),
                valid_fields
            ));
        }

        // Apply each extra argument to the builder
201
202
203
204
        if let Some(value) = extra_args.get("num_gpu_blocks")
            && let Some(num) = value.as_u64()
        {
            builder = builder.num_gpu_blocks(num as usize);
205
206
        }

207
208
209
210
        if let Some(value) = extra_args.get("block_size")
            && let Some(num) = value.as_u64()
        {
            builder = builder.block_size(num as usize);
211
212
        }

213
214
215
216
        if let Some(value) = extra_args.get("max_num_seqs")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_seqs(Some(num as usize));
217
218
        }

219
220
221
222
        if let Some(value) = extra_args.get("max_num_batched_tokens")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_batched_tokens(Some(num as usize));
223
224
        }

225
226
227
228
        if let Some(value) = extra_args.get("enable_prefix_caching")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_prefix_caching(enabled);
229
230
        }

231
232
233
234
        if let Some(value) = extra_args.get("enable_chunked_prefill")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_chunked_prefill(enabled);
235
236
        }

237
238
239
240
        if let Some(value) = extra_args.get("watermark")
            && let Some(num) = value.as_f64()
        {
            builder = builder.watermark(num);
241
242
        }

243
244
245
246
        if let Some(value) = extra_args.get("speedup_ratio")
            && let Some(num) = value.as_f64()
        {
            builder = builder.speedup_ratio(num);
247
248
        }

249
250
251
252
        if let Some(value) = extra_args.get("dp_size")
            && let Some(num) = value.as_u64()
        {
            builder = builder.dp_size(num as u32);
253
254
        }

255
256
257
258
259
260
        if let Some(value) = extra_args.get("startup_time")
            && let Some(num) = value.as_f64()
        {
            builder = builder.startup_time(Some(num));
        }

261
262
263
264
265
266
        if let Some(value) = extra_args.get("enable_local_indexer")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_local_indexer(enabled);
        }

267
268
269
270
271
272
        if let Some(value) = extra_args.get("bootstrap_port")
            && let Some(port) = value.as_u64()
        {
            builder = builder.bootstrap_port(Some(port as u16));
        }

273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
        // Parse worker type from is_prefill and is_decode flags
        let is_prefill = extra_args
            .get("is_prefill")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);
        let is_decode = extra_args
            .get("is_decode")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);

        // Determine worker type based on flags
        let worker_type = match (is_prefill, is_decode) {
            (false, false) => WorkerType::Aggregated,
            (true, false) => WorkerType::Prefill,
            (false, true) => WorkerType::Decode,
            (true, true) => panic!(
                "Invalid worker configuration: is_prefill and is_decode cannot both be true. \
                 Worker must be either Aggregated (both false), Prefill (is_prefill=true), or Decode (is_decode=true)."
            ),
        };
        builder = builder.worker_type(worker_type);

295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
        // Load performance model from NPZ file if provided
        let perf_model = if let Some(path_str) = extra_args.get("planner_profile_data")
            && let Some(path_str) = path_str.as_str()
        {
            let npz_path = PathBuf::from(path_str);
            match PerfModel::from_npz(&npz_path) {
                Ok(model) => {
                    tracing::info!("Successfully loaded performance model from: {:?}", npz_path);
                    Arc::new(model)
                }
                Err(e) => {
                    tracing::error!(
                        "Failed to load performance model from {:?}: {}. Falling back to polynomial model.",
                        npz_path,
                        e
                    );
                    Arc::new(PerfModel::default())
                }
            }
        } else {
            Arc::new(PerfModel::default())
        };
        builder = builder.perf_model(perf_model);

319
320
321
322
323
        // Build the MockEngineArgs with either defaults or overridden values
        builder
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
    }
324
325
}

326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unique_block_default_uniqueness() {
        // Ten default blocks must all be partial blocks, each carrying its own UUID.
        let uuids: Vec<_> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                other => panic!("Expected PartialBlock variant, got {:?}", other),
            })
            .collect();

        // Remember the first index of each UUID. `insert` returning a previous index means
        // two blocks share a UUID.
        let mut first_seen = HashMap::new();
        for (j, uuid) in uuids.into_iter().enumerate() {
            if let Some(i) = first_seen.insert(uuid, j) {
                panic!("UUID at index {} and {} are identical", i, j);
            }
        }
    }
}