// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::env::{self, VarError};
use std::fmt;
use std::str::FromStr;

use derive_builder::Builder;
use rand::Rng;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};

use crate::protocols::{
    BlockHashOptions, LocalBlockHash, compute_block_hash_for_seq, compute_seq_hash_for_block,
};

/// Serde default for `KvRouterConfig::router_track_prefill_tokens`.
///
/// Always returns `true`, so prefill-token tracking stays enabled unless it is
/// explicitly switched off.
const fn default_track_prefill_tokens() -> bool {
    true
}

/// Name of the environment variable read by `min_initial_workers_from_env`.
pub const DYN_ROUTER_MIN_INITIAL_WORKERS: &str = "DYN_ROUTER_MIN_INITIAL_WORKERS";

pub fn min_initial_workers_from_env() -> anyhow::Result<usize> {
    match env::var(DYN_ROUTER_MIN_INITIAL_WORKERS) {
        Ok(value) => value.parse::<usize>().map_err(|error| {
            anyhow::anyhow!(
                "{DYN_ROUTER_MIN_INITIAL_WORKERS} must be a non-negative integer, got {value:?}: {error}"
            )
        }),
        Err(VarError::NotPresent) => Ok(0),
        Err(VarError::NotUnicode(_)) => {
            anyhow::bail!("{DYN_ROUTER_MIN_INITIAL_WORKERS} must be valid unicode")
        }
    }
}

37
38
39
40
41
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterQueuePolicy {
    #[default]
    Fcfs,
42
    Lcfs,
43
44
45
46
47
48
49
    Wspt,
}

impl fmt::Display for RouterQueuePolicy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Fcfs => f.write_str("fcfs"),
50
            Self::Lcfs => f.write_str("lcfs"),
51
52
53
54
55
56
57
58
59
60
61
            Self::Wspt => f.write_str("wspt"),
        }
    }
}

impl FromStr for RouterQueuePolicy {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "fcfs" => Ok(Self::Fcfs),
62
            "lcfs" => Ok(Self::Lcfs),
63
64
            "wspt" => Ok(Self::Wspt),
            _ => Err(format!(
65
                "unknown queue policy: {s:?}, expected 'fcfs', 'lcfs', or 'wspt'"
66
67
68
69
70
            )),
        }
    }
}

71
72
73
74
75
76
77
78
79
80
81
82
/// Override configuration for router settings that can be specified per-request.
///
/// Every field is optional and defaults to `None` (both in the generated builder and
/// via `Default`).
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize, Validate)]
pub struct RouterConfigOverride {
    /// Per-request override for the overlap score weight.
    ///
    /// NOTE(review): unlike `KvRouterConfig::overlap_score_weight`, this field carries
    /// no `range(min = 0.0)` validation — confirm whether negative overrides are intended.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request override for the router temperature. Must be >= 0 when set.
    #[builder(default)]
    #[validate(range(min = 0.0))]
    pub router_temperature: Option<f64>,

    /// Per-request override for `KvRouterConfig::router_assume_kv_reuse`.
    /// `None` falls back to the base configuration.
    #[builder(default)]
    pub assume_kv_reuse: Option<bool>,

    /// Per-request override for `KvRouterConfig::router_track_prefill_tokens`.
    /// `None` falls back to the base configuration.
    #[builder(default)]
    pub track_prefill_tokens: Option<bool>,
}

/// KV Router configuration parameters
89
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
90
#[serde(default)]
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#[validate(schema(function = "validate_kv_router_config"))]
pub struct KvRouterConfig {
    #[validate(range(min = 0.0))]
    pub overlap_score_weight: f64,

    #[validate(range(min = 0.0))]
    pub router_temperature: f64,

    pub use_kv_events: bool,

    /// **Deprecated:** Enable durable KV events using NATS JetStream instead of the default event plane.
    /// This option will be removed in a future release. The event-plane subscriber
    /// (local_indexer mode) is now the recommended path.
    pub durable_kv_events: bool,

    pub router_replica_sync: bool,

    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

    /// Whether to track output blocks during generation (default: false)
    /// When enabled, the router adds placeholder blocks as tokens are generated
    /// and applies fractional decay based on progress toward agent_hints.osl.
    pub router_track_output_blocks: bool,

    /// Whether to assume KV cache reuse when tracking active blocks (default: true).
    /// When true, computes actual block hashes for sequence tracking.
    /// When false, generates random hashes (assuming no KV cache reuse).
    pub router_assume_kv_reuse: bool,

121
122
123
124
125
126
    /// Whether to include prompt-side prefill tokens in active load accounting (default: true).
    /// When false, prompt tokens are excluded from active prefill token tracking, queue pressure,
    /// and potential prefill-token load calculations.
    #[serde(default = "default_track_prefill_tokens")]
    pub router_track_prefill_tokens: bool,

127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    #[validate(range(min = 1))]
    pub router_snapshot_threshold: Option<u32>,

    /// Whether to reset the router state on startup (default: false)
    pub router_reset_states: bool,

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    #[validate(range(min = 0.0))]
    pub router_ttl_secs: f64,

    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
    #[validate(range(min = 1))]
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    #[validate(range(min = 0.0, max = 1.0))]
    pub router_prune_target_ratio: f64,

    /// Queue threshold fraction for prefill token capacity.
    /// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens.
148
149
    /// If None, queueing is disabled and all requests go directly to ready.
    /// Default: 2.0. Must be > 0.
150
151
152
153
154
155
156
157
158
    #[validate(range(min = 0.0))]
    pub router_queue_threshold: Option<f64>,

    /// Number of event processing threads for the KV indexer.
    /// When > 1, uses ConcurrentRadixTree with a thread pool instead of the
    /// single-threaded RadixTree. Default: 4.
    #[validate(range(min = 1))]
    pub router_event_threads: u32,

159
    pub skip_initial_worker_wait: bool,
160

161
162
163
164
    /// Scheduling policy for the router queue.
    /// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
    /// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
    pub router_queue_policy: RouterQueuePolicy,
165
166
167
168
169
170
171

    /// Component name of a standalone KV indexer to use for overlap scoring.
    /// When set, the router creates a `Remote` indexer that queries the standalone
    /// indexer via the request plane instead of maintaining a local radix tree.
    /// The standalone indexer handles its own event subscription and discovery.
    #[serde(default)]
    pub remote_indexer_component: Option<String>,
172
173
174
175
176
177
178
179
180
181
182
183
184
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
            overlap_score_weight: 1.0,
            router_temperature: 0.0,
            use_kv_events: true,
            durable_kv_events: false, // default to NATS Core (local indexer mode)
            router_replica_sync: false,
            router_track_active_blocks: true,
            router_track_output_blocks: false,
            router_assume_kv_reuse: true,
185
            router_track_prefill_tokens: default_track_prefill_tokens(),
186
187
188
189
190
            router_snapshot_threshold: Some(1000000),
            router_reset_states: false,
            router_ttl_secs: 120.0,
            router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
            router_prune_target_ratio: 0.8,
191
            router_queue_threshold: Some(4.0),
192
            router_event_threads: 4,
193
            skip_initial_worker_wait: false,
194
            router_queue_policy: RouterQueuePolicy::default(),
195
            remote_indexer_component: None,
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
        }
    }
}

/// Schema-level validation for [`KvRouterConfig`], covering constraints that span
/// more than one field.
///
/// Logs a deprecation warning whenever `durable_kv_events` is set.
///
/// # Errors
///
/// - `durable_kv_events` is set while `use_kv_events` is not.
/// - `router_track_output_blocks` is set while `router_track_active_blocks` is not.
fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationError> {
    if config.durable_kv_events {
        tracing::warn!(
            "--durable-kv-events is deprecated and will be removed in a future release. \
             The event-plane subscriber (local_indexer mode) is now the recommended path."
        );

        // Durable events are a flavour of KV events, so the latter must be on.
        if !config.use_kv_events {
            return Err(ValidationError::new(
                "durable_kv_events requires use_kv_events=true",
            ));
        }
    }

    let output_tracking_without_active =
        config.router_track_output_blocks && !config.router_track_active_blocks;
    if output_tracking_without_active {
        return Err(ValidationError::new(
            "router_track_output_blocks requires router_track_active_blocks=true",
        ));
    }

    Ok(())
}

impl KvRouterConfig {
221
222
223
224
225
226
227
228
229
230
231
232
    pub fn assume_kv_reuse(&self, config_override: Option<&RouterConfigOverride>) -> bool {
        config_override
            .and_then(|cfg| cfg.assume_kv_reuse)
            .unwrap_or(self.router_assume_kv_reuse)
    }

    pub fn track_prefill_tokens(&self, config_override: Option<&RouterConfigOverride>) -> bool {
        config_override
            .and_then(|cfg| cfg.track_prefill_tokens)
            .unwrap_or(self.router_track_prefill_tokens)
    }

233
234
235
236
237
238
239
240
241
242
243
    /// Compute sequence hashes for active block tracking based on configuration.
    ///
    /// Returns:
    /// - `None` if `router_track_active_blocks` is false
    /// - Random hashes if `router_track_active_blocks` is true but `router_assume_kv_reuse` is false
    /// - Actual sequence hashes if both are true
    pub fn compute_seq_hashes_for_tracking(
        &self,
        tokens: &[u32],
        block_size: u32,
        config_override: Option<&RouterConfigOverride>,
244
        hash_options: BlockHashOptions<'_>,
245
        precomputed_block_hashes: Option<&[LocalBlockHash]>,
246
247
248
249
250
251
252
253
254
255
    ) -> Option<Vec<u64>> {
        if !self.router_track_active_blocks {
            return None;
        }

        let num_blocks = tokens.len() / block_size as usize;
        if num_blocks == 0 {
            return Some(Vec::new());
        }

256
        let assume_kv_reuse = self.assume_kv_reuse(config_override);
257
258

        if assume_kv_reuse {
259
260
261
262
263
264
265
266
            let block_hashes = match precomputed_block_hashes {
                Some(block_hashes) => block_hashes,
                None => {
                    let computed = compute_block_hash_for_seq(tokens, block_size, hash_options);
                    return Some(compute_seq_hash_for_block(&computed));
                }
            };
            Some(compute_seq_hash_for_block(block_hashes))
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
        } else {
            let mut rng = rand::rng();
            Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
        }
    }

    /// Check if KV event subscription should be started.
    ///
    /// Returns false if:
    /// - KV events are disabled (`use_kv_events=false`)
    /// - Overlap scoring is disabled (`overlap_score_weight=0`)
    ///
    /// When false, the router skips starting the KV event subscription entirely,
    /// avoiding the need to query workers for their local indexer state.
    pub fn should_subscribe_to_kv_events(&self) -> bool {
        self.use_kv_events && self.overlap_score_weight > 0.0
    }
}
285
286
287
288

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::{BlockExtraInfo, BlockMmObjectInfo};

    /// `Display` and `FromStr` agree on the `lcfs` spelling.
    #[test]
    fn router_queue_policy_display_and_parse_support_lcfs() {
        assert_eq!(RouterQueuePolicy::Lcfs.to_string(), "lcfs");
        assert_eq!(
            "lcfs".parse::<RouterQueuePolicy>().unwrap(),
            RouterQueuePolicy::Lcfs
        );
    }

    /// Parsing is exact: unknown or differently-cased names are rejected.
    #[test]
    fn router_queue_policy_parse_rejects_unknown_values() {
        assert!("LCFS".parse::<RouterQueuePolicy>().is_err());
        assert!("".parse::<RouterQueuePolicy>().is_err());
    }

    /// Serde uses the same lowercase spelling as `Display`.
    #[test]
    fn router_queue_policy_serde_round_trip_supports_lcfs() {
        let serialized = serde_json::to_string(&RouterQueuePolicy::Lcfs).unwrap();
        assert_eq!(serialized, "\"lcfs\"");
        let deserialized: RouterQueuePolicy = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, RouterQueuePolicy::Lcfs);
    }

    #[test]
    fn kv_router_config_defaults_to_tracking_prefill_tokens() {
        assert!(KvRouterConfig::default().router_track_prefill_tokens);
    }

    /// Multimodal block info must influence the resulting sequence hashes.
    #[test]
    fn compute_seq_hashes_for_tracking_uses_mm_hashes() {
        let cfg = KvRouterConfig::default();
        let tokens = vec![1, 2, 3, 4];
        let mm_infos = vec![
            Some(BlockExtraInfo {
                mm_objects: vec![BlockMmObjectInfo {
                    mm_hash: 42,
                    offsets: vec![],
                }],
            }),
            None,
        ];

        let without_mm = cfg
            .compute_seq_hashes_for_tracking(&tokens, 2, None, BlockHashOptions::default(), None)
            .unwrap();
        let with_mm = cfg
            .compute_seq_hashes_for_tracking(
                &tokens,
                2,
                None,
                BlockHashOptions {
                    block_mm_infos: Some(&mm_infos),
                    ..Default::default()
                },
                None,
            )
            .unwrap();

        assert_ne!(without_mm, with_mm);
    }

    #[test]
    fn router_config_override_serde_round_trip_preserves_track_prefill_tokens() {
        let serialized = serde_json::to_string(&RouterConfigOverride {
            track_prefill_tokens: Some(false),
            ..Default::default()
        })
        .unwrap();
        let deserialized: RouterConfigOverride = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.track_prefill_tokens, Some(false));
    }

    /// Supplied block hashes are used as-is instead of being recomputed from tokens.
    #[test]
    fn compute_seq_hashes_for_tracking_uses_precomputed_block_hashes() {
        let config = KvRouterConfig::default();
        let tokens: Vec<u32> = (0..8).collect();
        let precomputed = vec![LocalBlockHash(11), LocalBlockHash(29)];

        let seq_hashes = config.compute_seq_hashes_for_tracking(
            &tokens,
            4,
            None,
            BlockHashOptions::default(),
            Some(&precomputed),
        );

        assert_eq!(seq_hashes, Some(compute_seq_hash_for_block(&precomputed)));
    }
}