protocols.rs 9.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use derive_builder::Builder;
17
use serde::{Deserialize, Serialize};
18
19
use std::collections::{HashMap, HashSet};
use std::path::Path;
20
21
use uuid::Uuid;

22
23
24
25
use crate::kv_router::protocols::{
    ExternalSequenceBlockHash, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
    KvCacheStoredBlockData, LocalBlockHash,
};
26
use crate::tokens::blocks::UniqueBlock;
27
use crate::tokens::{BlockHash, SequenceHash, Token};
28

29
30
31
/// Alias for a count of blocks, expressed as a plain `usize`.
pub type NumBlocks = usize;

/// Represents different block movement operations in the cache
32
/// For Use and Promote variants, parent hash is the second field
33
34
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
35
    Use(Vec<UniqueBlock>),
36
37
    Destroy(Vec<UniqueBlock>),
    Deref(Vec<UniqueBlock>),
38
    Promote(Uuid, SequenceHash, Option<u64>),
39
40
41
42
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
43
44
    Store(Vec<SequenceHash>, Option<u64>),
    Remove(Vec<SequenceHash>),
45
46
47
48
49
50
51
}

/// A request that carries its input tokens directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
    /// Input tokens of the request.
    pub tokens: Vec<Token>,
    /// Maximum number of output tokens for the request.
    pub max_output_tokens: usize,
    /// Optional identifier of the request.
    pub uuid: Option<Uuid>,
    /// Optional rank for the request; presumably the data-parallel rank that
    /// should serve it — confirm against the consumer.
    pub dp_rank: Option<u32>,
}

/// Represents the cost of prefilling content in the cache
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
    /// Number of new blocks involved in the prefill.
    pub new_blocks: usize,
    /// Number of new tokens involved in the prefill; used as the default input
    /// of `predict_prefill_compute`.
    pub new_tokens: usize,
}

impl PrefillCost {
    pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 {
        let tokens = new_tokens.unwrap_or(self.new_tokens);
        1.25e-6 * (tokens as f64).powi(2) + 7.41e-2 * (tokens as f64) + 2.62e1
    }
67
68
}

69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/// Signal emitted for output token generation, tagged with whether the
/// request has completed.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct OutputSignal {
    /// Identifier the signal refers to.
    pub uuid: Uuid,
    /// Completion status carried by the signal.
    pub completed: bool,
}

/// Configuration arguments for MockVllmEngine
///
/// Instances are normally created through the derived `MockEngineArgsBuilder`
/// (owned pattern); every field has a builder default, listed on the field.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
    /// Number of GPU blocks. Builder default: `16384`.
    #[builder(default = "16384")]
    pub num_gpu_blocks: usize,

    /// Block size. Builder default: `64`.
    #[builder(default = "64")]
    pub block_size: usize,

    /// Optional limit on the number of sequences. Builder default: `Some(256)`.
    // This was 1024 in the past but reverted back to 256
    #[builder(default = Some(256))]
    pub max_num_seqs: Option<usize>,

    /// Optional limit on the number of batched tokens. Builder default: `Some(8192)`.
    // default for open api server, for llm class it's 16384
    #[builder(default = Some(8192))]
    pub max_num_batched_tokens: Option<usize>,

    /// Whether prefix caching is enabled. Builder default: `true`.
    #[builder(default = true)]
    pub enable_prefix_caching: bool,

    /// Whether chunked prefill is enabled. Builder default: `true`.
    #[builder(default = true)]
    pub enable_chunked_prefill: bool,

    /// Watermark value. Builder default: `0.01`.
    /// NOTE(review): presumably a fraction of blocks kept in reserve — confirm.
    #[builder(default = "0.01")]
    pub watermark: f64,

    /// Speedup ratio. Builder default: `1.0`.
    /// NOTE(review): presumably scales simulated timings — confirm.
    #[builder(default = "1.0")]
    pub speedup_ratio: f64,

    /// Data-parallel size. Builder default: `1`.
    #[builder(default = "1")]
    pub dp_size: u32,
}

110
111
112
113
114
115
116
117
impl Default for MockEngineArgs {
    /// Builds the arguments from an untouched builder, so every field takes
    /// the default declared on the struct.
    fn default() -> Self {
        let untouched_builder = MockEngineArgsBuilder::default();
        untouched_builder
            .build()
            .expect("Failed to build default MockEngineArgs")
    }
}

118
119
120
121
impl MockEngineArgs {
    pub fn builder() -> MockEngineArgsBuilder {
        MockEngineArgsBuilder::default()
    }
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137

    /// Create MockEngineArgs from a JSON file containing extra engine arguments
    pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
        let mut builder = Self::builder();

        // Load and parse the JSON file
        let file_content = std::fs::read_to_string(path)?;
        let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;

        // Define valid field names
        let valid_fields: HashSet<&str> = [
            "num_gpu_blocks",
            "block_size",
            "max_num_seqs",
            "max_num_batched_tokens",
            "enable_prefix_caching",
138
            "enable_chunked_prefill",
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
            "watermark",
            "speedup_ratio",
            "dp_size",
        ]
        .iter()
        .cloned()
        .collect();

        // Check for invalid arguments
        let invalid_args: Vec<String> = extra_args
            .keys()
            .filter(|key| !valid_fields.contains(key.as_str()))
            .cloned()
            .collect();

        if !invalid_args.is_empty() {
            return Err(anyhow::anyhow!(
                "Invalid arguments found in JSON file: {}. Valid arguments are: {:?}",
                invalid_args.join(", "),
                valid_fields
            ));
        }

        // Apply each extra argument to the builder
163
164
165
166
        if let Some(value) = extra_args.get("num_gpu_blocks")
            && let Some(num) = value.as_u64()
        {
            builder = builder.num_gpu_blocks(num as usize);
167
168
        }

169
170
171
172
        if let Some(value) = extra_args.get("block_size")
            && let Some(num) = value.as_u64()
        {
            builder = builder.block_size(num as usize);
173
174
        }

175
176
177
178
        if let Some(value) = extra_args.get("max_num_seqs")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_seqs(Some(num as usize));
179
180
        }

181
182
183
184
        if let Some(value) = extra_args.get("max_num_batched_tokens")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_batched_tokens(Some(num as usize));
185
186
        }

187
188
189
190
        if let Some(value) = extra_args.get("enable_prefix_caching")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_prefix_caching(enabled);
191
192
        }

193
194
195
196
        if let Some(value) = extra_args.get("enable_chunked_prefill")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_chunked_prefill(enabled);
197
198
        }

199
200
201
202
        if let Some(value) = extra_args.get("watermark")
            && let Some(num) = value.as_f64()
        {
            builder = builder.watermark(num);
203
204
        }

205
206
207
208
        if let Some(value) = extra_args.get("speedup_ratio")
            && let Some(num) = value.as_f64()
        {
            builder = builder.speedup_ratio(num);
209
210
        }

211
212
213
214
        if let Some(value) = extra_args.get("dp_size")
            && let Some(num) = value.as_u64()
        {
            builder = builder.dp_size(num as u32);
215
216
217
218
219
220
221
        }

        // Build the MockEngineArgs with either defaults or overridden values
        builder
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
    }
222
223
}

224
225
226
227
228
229
230
231
232
233
234
235
236
237
/// Converts a MoveBlockResponse from the mocker backend into a KvCacheEventData.
///
/// This function assumes that the stored sequence hashes in the response always
/// correspond to the tail part of the local hashes array. This is the expected
/// behavior of KV block storage, where blocks are stored sequentially and the
/// response contains the most recent blocks that were stored.
///
/// # Panics
/// Panics if the number of blocks in the Store response exceeds the length
/// of local_hashes.
pub fn block_response_to_kv_event(
    response: MoveBlockResponse,
    local_hashes: &[BlockHash],
) -> KvCacheEventData {
238
239
    match response {
        MoveBlockResponse::Store(full_blocks, parent_hash) => {
240
241
242
243
244
245
            let num_blocks = full_blocks.len();
            let local_hashes_slice = &local_hashes[local_hashes
                .len()
                .checked_sub(num_blocks)
                .expect("local hashes fewer than block response signal")..];

246
247
248
249
            KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash: parent_hash.map(ExternalSequenceBlockHash),
                blocks: full_blocks
                    .into_iter()
250
251
252
253
                    .zip(local_hashes_slice.iter())
                    .map(|(global_hash, local_hash)| KvCacheStoredBlockData {
                        block_hash: ExternalSequenceBlockHash(global_hash),
                        tokens_hash: LocalBlockHash(*local_hash),
254
255
256
257
258
259
260
261
262
263
264
265
266
                    })
                    .collect(),
            })
        }
        MoveBlockResponse::Remove(full_blocks) => KvCacheEventData::Removed(KvCacheRemoveData {
            block_hashes: full_blocks
                .into_iter()
                .map(ExternalSequenceBlockHash)
                .collect(),
        }),
    }
}

267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Default-constructed `UniqueBlock`s must be `PartialBlock`s that each
    /// carry a distinct UUID.
    #[test]
    fn test_unique_block_default_uniqueness() {
        // Create 10 default UniqueBlock instances and extract the UUID of each
        let uuids: Vec<_> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                // Any other variant means `Default` no longer yields a partial block.
                _ => panic!("Expected PartialBlock variant"),
            })
            .collect();

        // Check that all UUIDs are unique by comparing each with every later one
        for (i, first) in uuids.iter().enumerate() {
            for (j, second) in uuids.iter().enumerate().skip(i + 1) {
                assert_ne!(
                    first, second,
                    "UUID at index {} and {} are identical",
                    i, j
                );
            }
        }
    }
}