// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use derive_builder::Builder;
5
use serde::{Deserialize, Serialize};
6
use std::collections::{HashMap, HashSet};
7
8
use std::path::{Path, PathBuf};
use std::sync::Arc;
9
10
use uuid::Uuid;

11
use crate::mocker::perf_model::PerfModel;
12
use crate::tokens::blocks::UniqueBlock;
13
use crate::tokens::{BlockHash, SequenceHash, Token};
14

15
16
17
/// Alias of `usize`; going by its name, intended for expressing a number of blocks.
pub type NumBlocks = usize;

/// Represents different block movement operations in the cache
Yan Ru Pei's avatar
Yan Ru Pei committed
18
/// For Use and Promote variants, block hashes are included for KV event publishing
19
20
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
Yan Ru Pei's avatar
Yan Ru Pei committed
21
    Use(Vec<UniqueBlock>, Vec<BlockHash>),
22
23
    Destroy(Vec<UniqueBlock>),
    Deref(Vec<UniqueBlock>),
Yan Ru Pei's avatar
Yan Ru Pei committed
24
    Promote(Uuid, SequenceHash, Option<u64>, BlockHash),
25
26
27
28
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
29
30
    Store(Vec<SequenceHash>, Option<u64>),
    Remove(Vec<SequenceHash>),
31
32
33
34
35
36
37
}

/// A request described directly by its token sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
    /// Tokens carried by the request.
    // NOTE(review): presumably the prompt (input) tokens — confirm.
    pub tokens: Vec<Token>,
    /// Maximum number of output tokens for this request.
    pub max_output_tokens: usize,
    /// Optional identifier for the request.
    // NOTE(review): presumably a fresh UUID is assigned downstream when this is `None` — confirm.
    pub uuid: Option<Uuid>,
    /// Data-parallel rank associated with the request.
    pub dp_rank: u32,
}

/// Represents the cost of prefilling content in the cache
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
    /// Number of new blocks attributed to the prefill.
    pub new_blocks: usize,
    /// Number of new tokens attributed to the prefill; the default input to
    /// `predict_prefill_compute` when no override is supplied.
    pub new_tokens: usize,
}

impl PrefillCost {
49
50
51
52
53
    pub fn predict_prefill_compute(
        &self,
        new_tokens: Option<usize>,
        perf_model: &PerfModel,
    ) -> f64 {
54
        let tokens = new_tokens.unwrap_or(self.new_tokens);
55
        perf_model.predict_prefill_time(tokens)
56
    }
57
58
}

59
60
61
62
63
64
65
/// Signal for output token generation with completion status
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
    /// Identifier the signal refers to.
    // NOTE(review): presumably the UUID of the originating request — confirm against the emitter.
    pub uuid: Uuid,
    /// Completion status carried with the signal.
    // NOTE(review): presumably `true` only for the final output token — confirm.
    pub completed: bool,
}

66
67
68
69
70
71
72
73
74
75
76
77
/// Worker type for disaggregated serving configurations
///
/// `Aggregated` is the `Default` value (marked with `#[default]`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum WorkerType {
    /// Standard aggregated worker handling both prefill and decode
    #[default]
    Aggregated,
    /// Dedicated prefill worker in disaggregated mode
    Prefill,
    /// Dedicated decode worker in disaggregated mode
    Decode,
}

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/// Configuration arguments for MockVllmEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
    #[builder(default = "16384")]
    pub num_gpu_blocks: usize,

    #[builder(default = "64")]
    pub block_size: usize,

    // This was 1024 in the past but reverted back to 256
    #[builder(default = Some(256))]
    pub max_num_seqs: Option<usize>,

    // default for open api server, for llm class it's 16384
    #[builder(default = Some(8192))]
    pub max_num_batched_tokens: Option<usize>,

    #[builder(default = true)]
    pub enable_prefix_caching: bool,

99
100
101
    #[builder(default = true)]
    pub enable_chunked_prefill: bool,

102
103
104
105
106
107
108
109
    #[builder(default = "0.01")]
    pub watermark: f64,

    #[builder(default = "1.0")]
    pub speedup_ratio: f64,

    #[builder(default = "1")]
    pub dp_size: u32,
110
111
112
113

    /// Optional startup time in seconds to simulate engine initialization delay
    #[builder(default = "None")]
    pub startup_time: Option<f64>,
114
115
116
117

    /// Worker type for disaggregated serving (Aggregated, Prefill, or Decode)
    #[builder(default = "WorkerType::Aggregated")]
    pub worker_type: WorkerType,
118
119
120
121
122

    /// Performance model for timing predictions (not serialized, loaded from planner_profile_data)
    #[serde(skip)]
    #[builder(default = "Arc::new(PerfModel::default())")]
    pub perf_model: Arc<PerfModel>,
123
124
125
126

    /// Enable worker-local KV indexer for tracking this worker's own KV cache state
    #[builder(default = "false")]
    pub enable_local_indexer: bool,
127
128
}

129
130
131
132
133
134
135
136
impl Default for MockEngineArgs {
    /// Builds the arguments from an untouched builder, so every field takes its
    /// `#[builder(default = ...)]` value.
    ///
    /// # Panics
    /// Panics if the builder fails to build.
    fn default() -> Self {
        let built = MockEngineArgsBuilder::default().build();
        built.expect("Failed to build default MockEngineArgs")
    }
}

137
138
139
140
impl MockEngineArgs {
    pub fn builder() -> MockEngineArgsBuilder {
        MockEngineArgsBuilder::default()
    }
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156

    /// Create MockEngineArgs from a JSON file containing extra engine arguments
    pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
        let mut builder = Self::builder();

        // Load and parse the JSON file
        let file_content = std::fs::read_to_string(path)?;
        let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;

        // Define valid field names
        let valid_fields: HashSet<&str> = [
            "num_gpu_blocks",
            "block_size",
            "max_num_seqs",
            "max_num_batched_tokens",
            "enable_prefix_caching",
157
            "enable_chunked_prefill",
158
159
160
            "watermark",
            "speedup_ratio",
            "dp_size",
161
            "startup_time",
162
163
            "is_prefill",
            "is_decode",
164
            "planner_profile_data",
165
            "enable_local_indexer",
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
        ]
        .iter()
        .cloned()
        .collect();

        // Check for invalid arguments
        let invalid_args: Vec<String> = extra_args
            .keys()
            .filter(|key| !valid_fields.contains(key.as_str()))
            .cloned()
            .collect();

        if !invalid_args.is_empty() {
            return Err(anyhow::anyhow!(
                "Invalid arguments found in JSON file: {}. Valid arguments are: {:?}",
                invalid_args.join(", "),
                valid_fields
            ));
        }

        // Apply each extra argument to the builder
187
188
189
190
        if let Some(value) = extra_args.get("num_gpu_blocks")
            && let Some(num) = value.as_u64()
        {
            builder = builder.num_gpu_blocks(num as usize);
191
192
        }

193
194
195
196
        if let Some(value) = extra_args.get("block_size")
            && let Some(num) = value.as_u64()
        {
            builder = builder.block_size(num as usize);
197
198
        }

199
200
201
202
        if let Some(value) = extra_args.get("max_num_seqs")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_seqs(Some(num as usize));
203
204
        }

205
206
207
208
        if let Some(value) = extra_args.get("max_num_batched_tokens")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_batched_tokens(Some(num as usize));
209
210
        }

211
212
213
214
        if let Some(value) = extra_args.get("enable_prefix_caching")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_prefix_caching(enabled);
215
216
        }

217
218
219
220
        if let Some(value) = extra_args.get("enable_chunked_prefill")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_chunked_prefill(enabled);
221
222
        }

223
224
225
226
        if let Some(value) = extra_args.get("watermark")
            && let Some(num) = value.as_f64()
        {
            builder = builder.watermark(num);
227
228
        }

229
230
231
232
        if let Some(value) = extra_args.get("speedup_ratio")
            && let Some(num) = value.as_f64()
        {
            builder = builder.speedup_ratio(num);
233
234
        }

235
236
237
238
        if let Some(value) = extra_args.get("dp_size")
            && let Some(num) = value.as_u64()
        {
            builder = builder.dp_size(num as u32);
239
240
        }

241
242
243
244
245
246
        if let Some(value) = extra_args.get("startup_time")
            && let Some(num) = value.as_f64()
        {
            builder = builder.startup_time(Some(num));
        }

247
248
249
250
251
252
        if let Some(value) = extra_args.get("enable_local_indexer")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_local_indexer(enabled);
        }

253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
        // Parse worker type from is_prefill and is_decode flags
        let is_prefill = extra_args
            .get("is_prefill")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);
        let is_decode = extra_args
            .get("is_decode")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);

        // Determine worker type based on flags
        let worker_type = match (is_prefill, is_decode) {
            (false, false) => WorkerType::Aggregated,
            (true, false) => WorkerType::Prefill,
            (false, true) => WorkerType::Decode,
            (true, true) => panic!(
                "Invalid worker configuration: is_prefill and is_decode cannot both be true. \
                 Worker must be either Aggregated (both false), Prefill (is_prefill=true), or Decode (is_decode=true)."
            ),
        };
        builder = builder.worker_type(worker_type);

275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
        // Load performance model from NPZ file if provided
        let perf_model = if let Some(path_str) = extra_args.get("planner_profile_data")
            && let Some(path_str) = path_str.as_str()
        {
            let npz_path = PathBuf::from(path_str);
            match PerfModel::from_npz(&npz_path) {
                Ok(model) => {
                    tracing::info!("Successfully loaded performance model from: {:?}", npz_path);
                    Arc::new(model)
                }
                Err(e) => {
                    tracing::error!(
                        "Failed to load performance model from {:?}: {}. Falling back to polynomial model.",
                        npz_path,
                        e
                    );
                    Arc::new(PerfModel::default())
                }
            }
        } else {
            Arc::new(PerfModel::default())
        };
        builder = builder.perf_model(perf_model);

299
300
301
302
303
        // Build the MockEngineArgs with either defaults or overridden values
        builder
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
    }
304
305
}

306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `UniqueBlock::default()` must be a `PartialBlock` holding a UUID that differs
    /// from the UUID of every other default-constructed block.
    #[test]
    fn test_unique_block_default_uniqueness() {
        // Create 10 default UniqueBlock instances and extract the UUID from each
        let uuids: Vec<_> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                _ => panic!("Expected PartialBlock variant"),
            })
            .collect();

        // Check that all UUIDs are unique by comparing each with every later one
        for (i, first) in uuids.iter().enumerate() {
            for (offset, second) in uuids[i + 1..].iter().enumerate() {
                assert_ne!(
                    first,
                    second,
                    "UUID at index {} and {} are identical",
                    i,
                    i + 1 + offset
                );
            }
        }
    }
}