protocols.rs 8.84 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use derive_builder::Builder;
5
use serde::{Deserialize, Serialize};
6
7
use std::collections::{HashMap, HashSet};
use std::path::Path;
8
9
use uuid::Uuid;

10
use crate::tokens::blocks::UniqueBlock;
11
use crate::tokens::{BlockHash, SequenceHash, Token};
12

13
14
15
/// A count of blocks, represented as a plain `usize`.
pub type NumBlocks = usize;

/// Represents different block movement operations in the cache
Yan Ru Pei's avatar
Yan Ru Pei committed
16
/// For Use and Promote variants, block hashes are included for KV event publishing
17
18
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
Yan Ru Pei's avatar
Yan Ru Pei committed
19
    Use(Vec<UniqueBlock>, Vec<BlockHash>),
20
21
    Destroy(Vec<UniqueBlock>),
    Deref(Vec<UniqueBlock>),
Yan Ru Pei's avatar
Yan Ru Pei committed
22
    Promote(Uuid, SequenceHash, Option<u64>, BlockHash),
23
24
25
26
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
27
28
    Store(Vec<SequenceHash>, Option<u64>),
    Remove(Vec<SequenceHash>),
29
30
31
32
33
34
35
}

/// A request submitted directly to the mock engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
    /// Prompt tokens.
    pub tokens: Vec<Token>,
    /// Upper bound on the number of tokens to generate.
    pub max_output_tokens: usize,
    /// Optional caller-supplied request identifier.
    pub uuid: Option<Uuid>,
    /// Data-parallel rank this request targets (presumably in `0..dp_size`; confirm).
    pub dp_rank: u32,
}

/// Represents the cost of prefilling content in the cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
    /// Number of blocks that must be newly filled (presumably those not
    /// already cached; confirm against the producer).
    pub new_blocks: usize,
    /// Number of tokens that must be newly prefilled; this is the default input
    /// to `predict_prefill_compute`.
    pub new_tokens: usize,
}

impl PrefillCost {
    pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 {
        let tokens = new_tokens.unwrap_or(self.new_tokens);
Yan Ru Pei's avatar
Yan Ru Pei committed
49
        4.209989e-07 * (tokens as f64).powi(2) + 1.518344e-02 * (tokens as f64) + 1.650142e+01
50
    }
51
52
}

53
54
55
56
57
58
59
/// Signal for output token generation, including completion status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
    /// Identifier of the request this signal belongs to.
    pub uuid: Uuid,
    /// Whether the request has finished generating.
    pub completed: bool,
}

60
61
62
63
64
65
66
67
68
69
70
71
/// Role a worker plays in (possibly disaggregated) serving.
///
/// Defaults to [`WorkerType::Aggregated`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum WorkerType {
    /// Standard aggregated worker that handles both prefill and decode.
    #[default]
    Aggregated,
    /// Dedicated prefill worker in disaggregated mode.
    Prefill,
    /// Dedicated decode worker in disaggregated mode.
    Decode,
}

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/// Configuration arguments for MockVllmEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
    #[builder(default = "16384")]
    pub num_gpu_blocks: usize,

    #[builder(default = "64")]
    pub block_size: usize,

    // This was 1024 in the past but reverted back to 256
    #[builder(default = Some(256))]
    pub max_num_seqs: Option<usize>,

    // default for open api server, for llm class it's 16384
    #[builder(default = Some(8192))]
    pub max_num_batched_tokens: Option<usize>,

    #[builder(default = true)]
    pub enable_prefix_caching: bool,

93
94
95
    #[builder(default = true)]
    pub enable_chunked_prefill: bool,

96
97
98
99
100
101
102
103
    #[builder(default = "0.01")]
    pub watermark: f64,

    #[builder(default = "1.0")]
    pub speedup_ratio: f64,

    #[builder(default = "1")]
    pub dp_size: u32,
104
105
106
107

    /// Optional startup time in seconds to simulate engine initialization delay
    #[builder(default = "None")]
    pub startup_time: Option<f64>,
108
109
110
111

    /// Worker type for disaggregated serving (Aggregated, Prefill, or Decode)
    #[builder(default = "WorkerType::Aggregated")]
    pub worker_type: WorkerType,
112
113
}

114
115
116
117
118
119
120
121
impl Default for MockEngineArgs {
    /// Builds `MockEngineArgs` with every field at its builder default.
    ///
    /// # Panics
    /// Panics if the builder fails. Every field of `MockEngineArgs` declares a
    /// builder default, so this indicates a programming error.
    fn default() -> Self {
        let all_defaults = MockEngineArgsBuilder::default();
        all_defaults
            .build()
            .expect("Failed to build default MockEngineArgs")
    }
}

122
123
124
125
impl MockEngineArgs {
    pub fn builder() -> MockEngineArgsBuilder {
        MockEngineArgsBuilder::default()
    }
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141

    /// Create MockEngineArgs from a JSON file containing extra engine arguments
    pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
        let mut builder = Self::builder();

        // Load and parse the JSON file
        let file_content = std::fs::read_to_string(path)?;
        let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;

        // Define valid field names
        let valid_fields: HashSet<&str> = [
            "num_gpu_blocks",
            "block_size",
            "max_num_seqs",
            "max_num_batched_tokens",
            "enable_prefix_caching",
142
            "enable_chunked_prefill",
143
144
145
            "watermark",
            "speedup_ratio",
            "dp_size",
146
            "startup_time",
147
148
            "is_prefill",
            "is_decode",
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
        ]
        .iter()
        .cloned()
        .collect();

        // Check for invalid arguments
        let invalid_args: Vec<String> = extra_args
            .keys()
            .filter(|key| !valid_fields.contains(key.as_str()))
            .cloned()
            .collect();

        if !invalid_args.is_empty() {
            return Err(anyhow::anyhow!(
                "Invalid arguments found in JSON file: {}. Valid arguments are: {:?}",
                invalid_args.join(", "),
                valid_fields
            ));
        }

        // Apply each extra argument to the builder
170
171
172
173
        if let Some(value) = extra_args.get("num_gpu_blocks")
            && let Some(num) = value.as_u64()
        {
            builder = builder.num_gpu_blocks(num as usize);
174
175
        }

176
177
178
179
        if let Some(value) = extra_args.get("block_size")
            && let Some(num) = value.as_u64()
        {
            builder = builder.block_size(num as usize);
180
181
        }

182
183
184
185
        if let Some(value) = extra_args.get("max_num_seqs")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_seqs(Some(num as usize));
186
187
        }

188
189
190
191
        if let Some(value) = extra_args.get("max_num_batched_tokens")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_batched_tokens(Some(num as usize));
192
193
        }

194
195
196
197
        if let Some(value) = extra_args.get("enable_prefix_caching")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_prefix_caching(enabled);
198
199
        }

200
201
202
203
        if let Some(value) = extra_args.get("enable_chunked_prefill")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_chunked_prefill(enabled);
204
205
        }

206
207
208
209
        if let Some(value) = extra_args.get("watermark")
            && let Some(num) = value.as_f64()
        {
            builder = builder.watermark(num);
210
211
        }

212
213
214
215
        if let Some(value) = extra_args.get("speedup_ratio")
            && let Some(num) = value.as_f64()
        {
            builder = builder.speedup_ratio(num);
216
217
        }

218
219
220
221
        if let Some(value) = extra_args.get("dp_size")
            && let Some(num) = value.as_u64()
        {
            builder = builder.dp_size(num as u32);
222
223
        }

224
225
226
227
228
229
        if let Some(value) = extra_args.get("startup_time")
            && let Some(num) = value.as_f64()
        {
            builder = builder.startup_time(Some(num));
        }

230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
        // Parse worker type from is_prefill and is_decode flags
        let is_prefill = extra_args
            .get("is_prefill")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);
        let is_decode = extra_args
            .get("is_decode")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);

        // Determine worker type based on flags
        let worker_type = match (is_prefill, is_decode) {
            (false, false) => WorkerType::Aggregated,
            (true, false) => WorkerType::Prefill,
            (false, true) => WorkerType::Decode,
            (true, true) => panic!(
                "Invalid worker configuration: is_prefill and is_decode cannot both be true. \
                 Worker must be either Aggregated (both false), Prefill (is_prefill=true), or Decode (is_decode=true)."
            ),
        };
        builder = builder.worker_type(worker_type);

252
253
254
255
256
        // Build the MockEngineArgs with either defaults or overridden values
        builder
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
    }
257
258
}

259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
#[cfg(test)]
mod tests {
    use super::*;

    /// `UniqueBlock::default()` must yield a `PartialBlock` with a fresh UUID
    /// every time it is called.
    #[test]
    fn test_unique_block_default_uniqueness() {
        const COUNT: usize = 10;

        // Create COUNT default UniqueBlock instances and extract their UUIDs.
        let uuids: Vec<_> = (0..COUNT)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                // The old message named a nonexistent "UuidIdentifier" variant.
                other => panic!("Expected PartialBlock variant, got {:?}", other),
            })
            .collect();

        // Every pair must differ.
        for (i, a) in uuids.iter().enumerate() {
            for (j, b) in uuids.iter().enumerate().skip(i + 1) {
                assert_ne!(a, b, "UUID at index {} and {} are identical", i, j);
            }
        }
    }
}