config.rs 18 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::env::{self, VarError};
5
6
use std::fmt;
use std::str::FromStr;
7
use std::time::Duration;
8

9
10
11
12
13
use derive_builder::Builder;
use rand::Rng;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};

14
15
16
use crate::protocols::{
    BlockHashOptions, LocalBlockHash, compute_block_hash_for_seq, compute_seq_hash_for_block,
};
17

18
19
20
21
/// Default for `KvRouterConfig::router_track_prefill_tokens`: tracking is enabled.
///
/// Used both as the serde field default and by `KvRouterConfig::default()`.
const fn default_track_prefill_tokens() -> bool {
    true
}

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/// Name of the environment variable read by [`min_initial_workers_from_env`].
pub const DYN_ROUTER_MIN_INITIAL_WORKERS: &str = "DYN_ROUTER_MIN_INITIAL_WORKERS";

/// Reads the minimum initial worker count from the `DYN_ROUTER_MIN_INITIAL_WORKERS`
/// environment variable.
///
/// Returns `Ok(0)` when the variable is not set.
///
/// # Errors
///
/// Fails when the variable is set but is not valid unicode, or when its value
/// does not parse as a `usize`.
pub fn min_initial_workers_from_env() -> anyhow::Result<usize> {
    // Resolve the raw string first; the "unset" and "not unicode" cases exit early.
    let value = match env::var(DYN_ROUTER_MIN_INITIAL_WORKERS) {
        Ok(value) => value,
        Err(VarError::NotPresent) => return Ok(0),
        Err(VarError::NotUnicode(_)) => {
            anyhow::bail!("{DYN_ROUTER_MIN_INITIAL_WORKERS} must be valid unicode")
        }
    };

    value.parse::<usize>().map_err(|error| {
        anyhow::anyhow!(
            "{DYN_ROUTER_MIN_INITIAL_WORKERS} must be a non-negative integer, got {value:?}: {error}"
        )
    })
}

38
39
40
41
42
43
44
45
/// Default for `KvRouterConfig::host_cache_hit_weight` (0.75).
///
/// Used both as the serde field default and by `KvRouterConfig::default()`.
const fn default_host_cache_hit_weight() -> f64 {
    0.75
}

/// Default for `KvRouterConfig::disk_cache_hit_weight` (0.25).
///
/// Used both as the serde field default and by `KvRouterConfig::default()`.
const fn default_disk_cache_hit_weight() -> f64 {
    0.25
}

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
/// Type of external shared KV cache to query during routing.
///
/// Serialized by serde in lowercase (`"none"`, `"hicache"`), the same spellings
/// produced by its `Display` impl and accepted by its `FromStr` impl.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SharedCacheType {
    /// No shared cache (default).
    #[default]
    None,
    /// HiCache L3 shared cache — queries sglang workers via the request plane.
    Hicache,
}

impl fmt::Display for SharedCacheType {
    /// Writes the lowercase name of the variant (`none` or `hicache`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::None => "none",
            Self::Hicache => "hicache",
        };
        f.write_str(label)
    }
}

impl FromStr for SharedCacheType {
    type Err = String;

    /// Parses the exact lowercase names `none` and `hicache`.
    ///
    /// Any other input (including different casing) yields a descriptive error string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let cache_type = match s {
            "none" => Self::None,
            "hicache" => Self::Hicache,
            _ => {
                return Err(format!(
                    "unknown shared_cache_type: {s:?}, expected 'none' or 'hicache'"
                ));
            }
        };
        Ok(cache_type)
    }
}

80
81
82
83
84
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterQueuePolicy {
    #[default]
    Fcfs,
85
    Lcfs,
86
87
88
89
90
91
92
    Wspt,
}

impl fmt::Display for RouterQueuePolicy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Fcfs => f.write_str("fcfs"),
93
            Self::Lcfs => f.write_str("lcfs"),
94
95
96
97
98
            Self::Wspt => f.write_str("wspt"),
        }
    }
}

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
/// Optional model for estimating prompt-side prefill load in the router.
///
/// Serialized by serde in lowercase (`"none"`, `"aic"`), the same spellings
/// produced by its `Display` impl and accepted by its `FromStr` impl.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterPrefillLoadModel {
    /// No prefill load model (default); `is_enabled()` returns `false`.
    #[default]
    None,
    /// The "aic" model.
    /// NOTE(review): what "aic" stands for is not visible in this file — confirm.
    Aic,
}

impl fmt::Display for RouterPrefillLoadModel {
    /// Writes the lowercase name of the model (`none` or `aic`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::None => "none",
            Self::Aic => "aic",
        };
        f.write_str(label)
    }
}

impl FromStr for RouterPrefillLoadModel {
    type Err = String;

    /// Parses the exact lowercase names `none` and `aic`.
    ///
    /// Any other input (including different casing) yields a descriptive error string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let model = match s {
            "none" => Self::None,
            "aic" => Self::Aic,
            _ => {
                return Err(format!(
                    "unknown prefill load model: {s:?}, expected 'none' or 'aic'"
                ));
            }
        };
        Ok(model)
    }
}

impl RouterPrefillLoadModel {
    /// Returns `true` for every variant except `None`.
    pub fn is_enabled(self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly whether it is enabled.
        match self {
            Self::None => false,
            Self::Aic => true,
        }
    }
}

136
137
138
139
140
141
impl FromStr for RouterQueuePolicy {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "fcfs" => Ok(Self::Fcfs),
142
            "lcfs" => Ok(Self::Lcfs),
143
144
            "wspt" => Ok(Self::Wspt),
            _ => Err(format!(
145
                "unknown queue policy: {s:?}, expected 'fcfs', 'lcfs', or 'wspt'"
146
147
148
149
150
            )),
        }
    }
}

151
152
153
154
155
156
157
158
159
160
161
162
/// Override configuration for router settings that can be specified per-request.
///
/// Every field is optional. For the fields resolved in this file
/// (`assume_kv_reuse`, `track_prefill_tokens`), `None` falls back to the
/// corresponding `KvRouterConfig` value.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize, Validate)]
pub struct RouterConfigOverride {
    /// Per-request override of `overlap_score_weight`.
    /// Note: unlike the router-level field, no range validation is applied here.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request override of `router_temperature`. Must be >= 0 when set.
    #[builder(default)]
    #[validate(range(min = 0.0))]
    pub router_temperature: Option<f64>,

    /// Per-request override of `router_assume_kv_reuse`.
    #[builder(default)]
    pub assume_kv_reuse: Option<bool>,

    /// Per-request override of `router_track_prefill_tokens`.
    #[builder(default)]
    pub track_prefill_tokens: Option<bool>,

    /// Per-request override of `shared_cache_multiplier`.
    #[builder(default)]
    #[validate(range(min = 0.0, max = 1.0))]
    pub shared_cache_multiplier: Option<f64>,
}

/// KV Router configuration parameters
174
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
175
#[serde(default)]
176
177
178
179
180
#[validate(schema(function = "validate_kv_router_config"))]
pub struct KvRouterConfig {
    #[validate(range(min = 0.0))]
    pub overlap_score_weight: f64,

181
182
183
184
185
186
187
188
    #[serde(default = "default_host_cache_hit_weight")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub host_cache_hit_weight: f64,

    #[serde(default = "default_disk_cache_hit_weight")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub disk_cache_hit_weight: f64,

189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
    #[validate(range(min = 0.0))]
    pub router_temperature: f64,

    pub use_kv_events: bool,

    /// **Deprecated:** Enable durable KV events using NATS JetStream instead of the default event plane.
    /// This option will be removed in a future release. The event-plane subscriber
    /// (local_indexer mode) is now the recommended path.
    pub durable_kv_events: bool,

    pub router_replica_sync: bool,

    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

    /// Whether to track output blocks during generation (default: false)
    /// When enabled, the router adds placeholder blocks as tokens are generated
    /// and applies fractional decay based on progress toward agent_hints.osl.
    pub router_track_output_blocks: bool,

    /// Whether to assume KV cache reuse when tracking active blocks (default: true).
    /// When true, computes actual block hashes for sequence tracking.
    /// When false, generates random hashes (assuming no KV cache reuse).
    pub router_assume_kv_reuse: bool,

214
215
216
217
218
219
    /// Whether to include prompt-side prefill tokens in active load accounting (default: true).
    /// When false, prompt tokens are excluded from active prefill token tracking, queue pressure,
    /// and potential prefill-token load calculations.
    #[serde(default = "default_track_prefill_tokens")]
    pub router_track_prefill_tokens: bool,

220
221
222
    /// Optional model for estimating effective prompt-side prefill load over time.
    pub router_prefill_load_model: RouterPrefillLoadModel,

223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    #[validate(range(min = 1))]
    pub router_snapshot_threshold: Option<u32>,

    /// Whether to reset the router state on startup (default: false)
    pub router_reset_states: bool,

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    #[validate(range(min = 0.0))]
    pub router_ttl_secs: f64,

    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
    #[validate(range(min = 1))]
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    #[validate(range(min = 0.0, max = 1.0))]
    pub router_prune_target_ratio: f64,

    /// Queue threshold fraction for prefill token capacity.
    /// When set, requests are queued if all workers exceed this fraction of max_num_batched_tokens.
244
    /// If None, queueing is disabled and all requests go directly to ready.
245
    /// Default: 4.0. Must be >= 0. Use 0.0 for maximum queueing sensitivity.
246
247
248
249
250
251
252
253
254
    #[validate(range(min = 0.0))]
    pub router_queue_threshold: Option<f64>,

    /// Number of event processing threads for the KV indexer.
    /// When > 1, uses ConcurrentRadixTree with a thread pool instead of the
    /// single-threaded RadixTree. Default: 4.
    #[validate(range(min = 1))]
    pub router_event_threads: u32,

255
    pub skip_initial_worker_wait: bool,
256

257
258
259
260
    /// Scheduling policy for the router queue.
    /// "fcfs" (default): first-come first-served with priority bumps — optimizes tail TTFT.
    /// "wspt": weighted shortest processing time (Smith's rule) — optimizes average TTFT.
    pub router_queue_policy: RouterQueuePolicy,
261

262
263
    /// Whether to query a remote KV indexer served from the worker component
    /// instead of maintaining a local radix tree for overlap scoring.
264
    #[serde(default)]
265
266
267
268
269
    pub use_remote_indexer: bool,

    /// Whether this router should serve its local indexer from the worker component.
    #[serde(default)]
    pub serve_indexer: bool,
270
271
272
273
274
275
276
277
278
279
280
281

    /// Multiplier for shared cache hits when scoring workers (0.0 to 1.0).
    /// Blocks available in the shared cache are less valuable than device-local blocks
    /// because they need to be fetched. A value of 0.5 means each shared cache hit
    /// counts as half a device-local hit. Default: 0.0 (shared cache scoring disabled);
    /// the CLI sets this to 0.5 when shared cache is enabled.
    #[validate(range(min = 0.0, max = 1.0))]
    pub shared_cache_multiplier: f64,

    /// Type of external shared KV cache to query during routing.
    /// "none" (default): disabled. "hicache": query sglang workers for L3 cache state.
    pub shared_cache_type: SharedCacheType,
282
283
284
285
286
287
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
            overlap_score_weight: 1.0,
288
289
            host_cache_hit_weight: default_host_cache_hit_weight(),
            disk_cache_hit_weight: default_disk_cache_hit_weight(),
290
291
292
293
294
295
296
            router_temperature: 0.0,
            use_kv_events: true,
            durable_kv_events: false, // default to NATS Core (local indexer mode)
            router_replica_sync: false,
            router_track_active_blocks: true,
            router_track_output_blocks: false,
            router_assume_kv_reuse: true,
297
            router_track_prefill_tokens: default_track_prefill_tokens(),
298
            router_prefill_load_model: RouterPrefillLoadModel::default(),
299
300
301
302
303
            router_snapshot_threshold: Some(1000000),
            router_reset_states: false,
            router_ttl_secs: 120.0,
            router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
            router_prune_target_ratio: 0.8,
304
            router_queue_threshold: Some(4.0),
305
            router_event_threads: 4,
306
            skip_initial_worker_wait: false,
307
            router_queue_policy: RouterQueuePolicy::default(),
308
309
            use_remote_indexer: false,
            serve_indexer: false,
310
311
            shared_cache_multiplier: 0.0,
            shared_cache_type: SharedCacheType::default(),
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
        }
    }
}

/// Schema-level validation for [`KvRouterConfig`]: enforces cross-field rules
/// that per-field range checks cannot express.
///
/// Logs a deprecation warning when `durable_kv_events` is set, then rejects:
/// - `durable_kv_events` without `use_kv_events`
/// - `router_track_output_blocks` without `router_track_active_blocks`
/// - an enabled `router_prefill_load_model` without `router_track_prefill_tokens`
/// - an enabled `router_prefill_load_model` with a queue policy other than `fcfs`
/// - `use_remote_indexer` together with `serve_indexer`
/// - `serve_indexer` with `overlap_score_weight == 0`
///
/// Only the first violated rule is reported.
fn validate_kv_router_config(config: &KvRouterConfig) -> Result<(), ValidationError> {
    if config.durable_kv_events {
        tracing::warn!(
            "--durable-kv-events is deprecated and will be removed in a future release. \
             The event-plane subscriber (local_indexer mode) is now the recommended path."
        );
    }
    if config.durable_kv_events && !config.use_kv_events {
        return Err(ValidationError::new(
            "durable_kv_events requires use_kv_events=true",
        ));
    }
    if config.router_track_output_blocks && !config.router_track_active_blocks {
        return Err(ValidationError::new(
            "router_track_output_blocks requires router_track_active_blocks=true",
        ));
    }
    if config.router_prefill_load_model.is_enabled() && !config.router_track_prefill_tokens {
        return Err(ValidationError::new(
            "router_prefill_load_model requires router_track_prefill_tokens=true",
        ));
    }
    if config.router_prefill_load_model.is_enabled()
        && !matches!(config.router_queue_policy, RouterQueuePolicy::Fcfs)
    {
        return Err(ValidationError::new(
            "router_prefill_load_model currently requires router_queue_policy='fcfs'",
        ));
    }
    if config.use_remote_indexer && config.serve_indexer {
        return Err(ValidationError::new(
            "use_remote_indexer and serve_indexer are mutually exclusive",
        ));
    }
    if config.serve_indexer && config.overlap_score_weight == 0.0 {
        return Err(ValidationError::new(
            "serve_indexer requires overlap_score_weight > 0",
        ));
    }
    Ok(())
}

impl KvRouterConfig {
359
360
361
362
363
364
365
366
367
368
369
    /// Returns the recheck interval for the router queue.
    ///
    /// 100 ms when a prefill load model is enabled and a queue threshold is set;
    /// 60 s in every other case.
    pub fn router_queue_recheck_interval(&self) -> Duration {
        let prefill_load_queueing =
            self.router_prefill_load_model.is_enabled() && self.router_queue_threshold.is_some();

        if prefill_load_queueing {
            Duration::from_millis(100)
        } else {
            Duration::from_secs(60)
        }
    }

370
371
372
373
374
375
376
377
378
379
380
381
    pub fn assume_kv_reuse(&self, config_override: Option<&RouterConfigOverride>) -> bool {
        config_override
            .and_then(|cfg| cfg.assume_kv_reuse)
            .unwrap_or(self.router_assume_kv_reuse)
    }

    pub fn track_prefill_tokens(&self, config_override: Option<&RouterConfigOverride>) -> bool {
        config_override
            .and_then(|cfg| cfg.track_prefill_tokens)
            .unwrap_or(self.router_track_prefill_tokens)
    }

382
383
384
385
386
387
388
389
390
391
392
    /// Compute sequence hashes for active block tracking based on configuration.
    ///
    /// Returns:
    /// - `None` if `router_track_active_blocks` is false
    /// - Random hashes if `router_track_active_blocks` is true but `router_assume_kv_reuse` is false
    /// - Actual sequence hashes if both are true
    pub fn compute_seq_hashes_for_tracking(
        &self,
        tokens: &[u32],
        block_size: u32,
        config_override: Option<&RouterConfigOverride>,
393
        hash_options: BlockHashOptions<'_>,
394
        precomputed_block_hashes: Option<&[LocalBlockHash]>,
395
396
397
398
399
400
401
402
403
404
    ) -> Option<Vec<u64>> {
        if !self.router_track_active_blocks {
            return None;
        }

        let num_blocks = tokens.len() / block_size as usize;
        if num_blocks == 0 {
            return Some(Vec::new());
        }

405
        let assume_kv_reuse = self.assume_kv_reuse(config_override);
406
407

        if assume_kv_reuse {
408
409
410
411
412
413
414
415
            let block_hashes = match precomputed_block_hashes {
                Some(block_hashes) => block_hashes,
                None => {
                    let computed = compute_block_hash_for_seq(tokens, block_size, hash_options);
                    return Some(compute_seq_hash_for_block(&computed));
                }
            };
            Some(compute_seq_hash_for_block(block_hashes))
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
        } else {
            let mut rng = rand::rng();
            Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
        }
    }

    /// Decide whether the KV event subscription should be started.
    ///
    /// Returns true only when both hold:
    /// - KV events are enabled (`use_kv_events=true`)
    /// - Overlap scoring is enabled (`overlap_score_weight > 0`)
    ///
    /// When false, the router skips starting the KV event subscription entirely,
    /// avoiding the need to query workers for their local indexer state.
    pub fn should_subscribe_to_kv_events(&self) -> bool {
        let overlap_scoring_enabled = self.overlap_score_weight > 0.0;
        self.use_kv_events && overlap_scoring_enabled
    }
}
434
435
436
437

#[cfg(test)]
mod tests {
    use super::*;
438
    use crate::protocols::{BlockExtraInfo, BlockMmObjectInfo};
439

440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
    /// Supplying multimodal block info must change the resulting sequence hashes.
    #[test]
    fn compute_seq_hashes_for_tracking_uses_mm_hashes() {
        let cfg = KvRouterConfig::default();
        let tokens = vec![1, 2, 3, 4];
        // Two blocks of two tokens: only the first block carries a multimodal object.
        let mm_infos = vec![
            Some(BlockExtraInfo {
                mm_objects: vec![BlockMmObjectInfo {
                    mm_hash: 42,
                    offsets: vec![],
                }],
            }),
            None,
        ];

        let without_mm = cfg
            .compute_seq_hashes_for_tracking(&tokens, 2, None, BlockHashOptions::default(), None)
            .unwrap();
        let with_mm = cfg
            .compute_seq_hashes_for_tracking(
                &tokens,
                2,
                None,
                BlockHashOptions {
                    block_mm_infos: Some(&mm_infos),
                    ..Default::default()
                },
                None,
            )
            .unwrap();

        assert_ne!(without_mm, with_mm);
    }
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488

    /// When precomputed block hashes are supplied, the sequence hashes are derived from them.
    #[test]
    fn compute_seq_hashes_for_tracking_uses_precomputed_block_hashes() {
        let precomputed = vec![LocalBlockHash(11), LocalBlockHash(29)];
        // Eight tokens at block size four: two whole blocks, matching the two hashes above.
        let tokens: Vec<u32> = (0..8).collect();

        let actual = KvRouterConfig::default().compute_seq_hashes_for_tracking(
            &tokens,
            4,
            None,
            BlockHashOptions::default(),
            Some(&precomputed),
        );
        let expected = Some(compute_seq_hash_for_block(&precomputed));

        assert_eq!(actual, expected);
    }
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529

    /// `shared_cache_multiplier` outside `[0.0, 1.0]` must fail validation.
    #[test]
    fn test_kv_router_config_rejects_out_of_range_shared_cache_multiplier() {
        // One value just below and one just above the accepted range.
        for shared_cache_multiplier in [-0.1, 1.1] {
            let config = KvRouterConfig {
                shared_cache_multiplier,
                ..Default::default()
            };

            assert!(config.validate().is_err());
        }
    }

    /// A per-request `shared_cache_multiplier` outside `[0.0, 1.0]` must fail validation.
    #[test]
    fn test_router_config_override_rejects_out_of_range_shared_cache_multiplier() {
        // One value just below and one just above the accepted range; every other
        // field stays `None` through the derived `Default`.
        for shared_cache_multiplier in [-0.1, 1.1] {
            let config_override = RouterConfigOverride {
                shared_cache_multiplier: Some(shared_cache_multiplier),
                ..Default::default()
            };

            assert!(config_override.validate().is_err());
        }
    }

    /// The default configuration keeps shared cache scoring off (multiplier of 0.0).
    #[test]
    fn test_kv_router_config_default_shared_cache_multiplier_is_disabled() {
        let config = KvRouterConfig::default();

        assert_eq!(config.shared_cache_multiplier, 0.0);
    }
530
}