kv_router.rs 38.9 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
6
7
8
use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
    time::Instant,
};
9

10
use anyhow::Result;
11
use dynamo_kv_router::{
12
    PrefillLoadEstimator, SharedKvCache,
13
    config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
14
    indexer::KvRouterError,
15
16
    protocols::KV_EVENT_SUBJECT,
    protocols::{
17
18
19
        BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, PrefillLoadHint, RouterEvent,
        RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
        compute_block_hash_for_seq,
20
    },
21
    scheduling::TierOverlapBlocks,
22
};
23
use dynamo_runtime::{
24
    component::{Client, Endpoint},
25
    discovery::DiscoveryQuery,
26
    pipeline::{
27
28
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
29
    },
30
    protocols::EndpointId,
31
    protocols::annotated::Annotated,
32
    traits::DistributedRuntimeProvider,
33
};
34
use futures::stream;
35
use tracing::Instrument;
36
use validator::Validate;
37

38
39
40
41
42
43
44
// Re-export from dynamo-kv-router crate
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::protocols;
pub use dynamo_kv_router::scheduling;
pub use dynamo_kv_router::selector;

pub mod agent_controller;
45
pub mod indexer;
46
pub mod metrics;
47
pub mod prefill_router;
48
pub mod publisher;
49
pub mod push_router;
50
pub mod scheduler;
51
pub mod sequence;
52
pub mod shared_cache;
53
pub mod sticky_sessions;
54

55
pub use agent_controller::AgentController;
56
pub use indexer::{Indexer, ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
57
pub use prefill_router::PrefillRouter;
58
pub use push_router::{DirectRoutingRouter, KvPushRouter};
59
pub use sticky_sessions::StickySessionRouter;
60

61
use crate::{
62
    discovery::RuntimeConfigWatch,
63
    kv_router::{
64
        scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
65
        sequence::{SequenceError, SequenceRequest},
66
    },
67
    local_model::runtime_config::ModelRuntimeConfig,
68
69
};

70
71
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
72
73
74
75
76
77
78
79
80
81

// Endpoint name for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

// Subject name for metric publishing (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// Subject names for inter-router communication.
pub const PREFILL_SUBJECT: &str = "prefill_events";
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
82

83
84
85
86
// Bucket and file names for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";

87
88
89
// Buffer size for the worker-local KV indexer query service.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer

90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
/// Cache-hit estimate for a single worker.
///
/// `effective_overlap_blocks` is a fractional block count and `cached_tokens`
/// is the matching token count.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct WorkerCacheHitEstimate {
    pub effective_overlap_blocks: f64,
    pub cached_tokens: usize,
}

impl WorkerCacheHitEstimate {
    /// Returns the effective overlap rounded to the nearest whole block.
    ///
    /// The float-to-integer `as` cast saturates, so negative and NaN inputs
    /// yield `0` and oversized inputs yield `u32::MAX`.
    pub fn rounded_overlap_blocks(self) -> u32 {
        let Self {
            effective_overlap_blocks,
            ..
        } = self;
        effective_overlap_blocks.round() as u32
    }
}

/// Per-worker cache-hit estimates, keyed by worker and DP rank.
///
/// Both maps are filled together by `cache_hit_estimates_from_tiered_matches`,
/// where `cached_tokens` is derived from `effective_overlap_blocks`.
#[derive(Debug, Clone, Default)]
struct CacheHitEstimates {
    // Weighted overlap in blocks (may be fractional).
    effective_overlap_blocks: HashMap<WorkerWithDpRank, f64>,
    // Token count corresponding to the weighted overlap.
    cached_tokens: HashMap<WorkerWithDpRank, usize>,
}

/// Result of `KvRouter::find_best_match_details`: the selected worker
/// together with the cache-hit estimate reported for it.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BestMatchDetails {
    pub worker: WorkerWithDpRank,
    pub cache_hit: WorkerCacheHitEstimate,
}

/// Returns the weight applied to a cache hit found in `storage_tier`.
///
/// Device hits always weigh `1.0`. Host-pinned hits use
/// `host_cache_hit_weight`. Disk and External hits both use
/// `disk_cache_hit_weight`.
fn cache_hit_weight_for_tier(
    kv_router_config: &KvRouterConfig,
    storage_tier: dynamo_kv_router::protocols::StorageTier,
) -> f64 {
    use dynamo_kv_router::protocols::StorageTier;

    match storage_tier {
        StorageTier::Device => 1.0,
        StorageTier::HostPinned => kv_router_config.host_cache_hit_weight,
        StorageTier::Disk | StorageTier::External => kv_router_config.disk_cache_hit_weight,
    }
}

/// Converts a (possibly fractional) overlap in blocks into a token count.
///
/// The product `effective_overlap_blocks * block_size` is rounded to the
/// nearest integer and clamped at zero, so negative and NaN inputs yield `0`.
fn cached_tokens_from_effective_overlap(block_size: u32, effective_overlap_blocks: f64) -> usize {
    let raw_tokens = effective_overlap_blocks * f64::from(block_size);
    let non_negative = raw_tokens.round().max(0.0);
    non_negative as usize
}

fn cache_hit_estimates_from_tiered_matches(
    kv_router_config: &KvRouterConfig,
    block_size: u32,
    tiered_matches: &indexer::TieredMatchDetails,
) -> CacheHitEstimates {
    let mut effective_overlap_blocks = HashMap::new();

    for (worker, overlap) in &tiered_matches.device.overlap_scores.scores {
        effective_overlap_blocks.insert(*worker, *overlap as f64);
    }

    for (storage_tier, tier_matches) in &tiered_matches.lower_tier {
        let weight = cache_hit_weight_for_tier(kv_router_config, *storage_tier);
        if weight == 0.0 {
            continue;
        }

        for (worker, hits) in &tier_matches.hits {
            if *hits == 0 {
                continue;
            }
            *effective_overlap_blocks.entry(*worker).or_insert(0.0) += *hits as f64 * weight;
        }
    }

    let cached_tokens = effective_overlap_blocks
        .iter()
        .map(|(worker, overlap)| {
            (
                *worker,
                cached_tokens_from_effective_overlap(block_size, *overlap),
            )
        })
        .collect();

    CacheHitEstimates {
        effective_overlap_blocks,
        cached_tokens,
    }
}

fn cache_hit_for_worker(
    cache_hit_estimates: &CacheHitEstimates,
    worker: WorkerWithDpRank,
) -> WorkerCacheHitEstimate {
    WorkerCacheHitEstimate {
        effective_overlap_blocks: cache_hit_estimates
            .effective_overlap_blocks
            .get(&worker)
            .copied()
            .unwrap_or(0.0),
        cached_tokens: cache_hit_estimates
            .cached_tokens
            .get(&worker)
            .copied()
            .unwrap_or(0),
    }
}

/// Collects the raw (unweighted) lower-tier hit counts per worker.
///
/// Host-pinned hits go into `host_pinned`. Disk and External hits are summed
/// into `disk`.
fn tier_overlap_blocks_from_tiered_matches(
    tiered_matches: &indexer::TieredMatchDetails,
) -> TierOverlapBlocks {
    use dynamo_kv_router::protocols::StorageTier;

    let lower_tier = &tiered_matches.lower_tier;
    let mut blocks = TierOverlapBlocks::default();

    if let Some(host_matches) = lower_tier.get(&StorageTier::HostPinned) {
        let host_hits = host_matches
            .hits
            .iter()
            .map(|(worker, hits)| (*worker, *hits));
        blocks.host_pinned.extend(host_hits);
    }

    // Disk and External share the same weighting (see `cache_hit_weight_for_tier`),
    // so accumulate both into the disk bucket.
    let disk_like_tiers = [StorageTier::Disk, StorageTier::External];
    for matches in disk_like_tiers
        .iter()
        .filter_map(|tier| lower_tier.get(tier))
    {
        for (worker, hits) in &matches.hits {
            *blocks.disk.entry(*worker).or_default() += *hits;
        }
    }

    blocks
}

228
229
230
231
232
233
/// Builds the endpoint name of the worker KV indexer query service for `dp_rank`.
///
/// Every dp_rank owns a separate LocalKvIndexer and query endpoint, which keeps
/// event ordering monotonic per dp_rank.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    const PREFIX: &str = "worker_kv_indexer_query_dp";
    format!("{PREFIX}{dp_rank}")
}

234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
/// Emits a DEBUG event listing the local block hashes computed for a request.
///
/// Nothing is collected or logged unless DEBUG is enabled, so the hash-id
/// vector is only allocated when it will be used. A missing `request_id` is
/// logged as an empty string.
fn log_routing_input_hashes(
    request_id: Option<&str>,
    block_size: u32,
    tokens: &[u32],
    local_hashes: &[LocalBlockHash],
) {
    if tracing::enabled!(tracing::Level::DEBUG) {
        let mut hash_ids: Vec<u64> = Vec::with_capacity(local_hashes.len());
        hash_ids.extend(local_hashes.iter().map(|hash| hash.0));

        tracing::debug!(
            request_id = request_id.unwrap_or(""),
            isl_tokens = tokens.len(),
            block_size,
            num_blocks = local_hashes.len(),
            local_hashes = ?hash_ids,
            "[ROUTING_INPUT] request local hashes"
        );
    }
}

256
/// Endpoint name used for router discovery registration.
pub const KV_ROUTER_ENDPOINT: &str = "router-discovery";
258
259

/// Creates an EndpointId for the KV router in the given namespace.
260
pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId {
261
262
    EndpointId {
        namespace,
263
        component,
264
265
266
267
268
        name: KV_ROUTER_ENDPOINT.to_string(),
    }
}

/// Creates a DiscoveryQuery for the KV router in the given namespace.
269
pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery {
270
271
    DiscoveryQuery::Endpoint {
        namespace,
272
        component,
273
274
275
276
        endpoint: KV_ROUTER_ENDPOINT.to_string(),
    }
}

277
278
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
279
280
281
282
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
283
    indexer: Indexer,
284
    scheduler: KvScheduler<Sel>,
285
    workers_with_configs: RuntimeConfigWatch,
286
    block_size: u32,
287
    kv_router_config: KvRouterConfig,
288
    prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
289
    cancellation_token: tokio_util::sync::CancellationToken,
290
    client: Client,
291
    is_eagle: bool,
292
    _served_indexer_handle: Option<ServedIndexerHandle>,
293
294
295
    /// Optional external shared KV cache pool. When present, `find_best_match`
    /// queries it in parallel with the indexer and factors shared hits into scoring.
    shared_cache: Option<Box<dyn SharedKvCache>>,
296
297
}

298
299
300
301
impl<Sel> KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
302
    #[allow(clippy::too_many_arguments)]
303
    pub async fn new(
304
305
        endpoint: Endpoint,
        client: Client,
306
        workers_with_configs: RuntimeConfigWatch,
307
        block_size: u32,
308
        selector: Sel,
309
        kv_router_config: Option<KvRouterConfig>,
310
        prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
311
        worker_type: &'static str,
312
        model_name: Option<String>,
313
        is_eagle: bool,
314
        shared_cache: Option<Box<dyn SharedKvCache>>,
315
    ) -> Result<Self> {
316
        let kv_router_config = kv_router_config.unwrap_or_default();
317
        kv_router_config.validate()?;
318
        let component = endpoint.component();
319
        let cancellation_token = component.drt().primary_token();
320
        let min_initial_workers = min_initial_workers_from_env()?;
321

322
323
324
325
326
327
328
        let indexer = Indexer::new(
            component,
            &kv_router_config,
            block_size,
            model_name.as_deref(),
        )
        .await?;
329

330
331
332
333
        if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
            let mut startup_watch = workers_with_configs.clone();
            let _ = startup_watch
                .wait_for(|m| m.len() >= min_initial_workers)
334
335
                .await
                .map_err(|_| {
336
337
                    anyhow::anyhow!(
                        "runtime config watch closed before {} workers appeared",
338
                        min_initial_workers
339
                    )
340
341
                })?;
        }
342

343
        let scheduler = KvScheduler::start(
344
            component.clone(),
345
            block_size,
346
            workers_with_configs.clone(),
347
            selector,
348
            &kv_router_config,
349
            prefill_load_estimator.clone(),
350
            worker_type,
351
352
        )
        .await?;
353

354
355
        // Start KV event subscription if needed — skip when using a remote indexer.
        if kv_router_config.use_remote_indexer {
356
357
            tracing::info!("Skipping KV event subscription (using remote indexer)");
        } else if kv_router_config.should_subscribe_to_kv_events() {
358
            indexer::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
359
                .await?;
360
        } else {
361
            tracing::info!(
362
363
364
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
365
            );
366
        }
367

368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
        let served_indexer_handle = if kv_router_config.serve_indexer {
            let model_name = model_name.clone().ok_or_else(|| {
                anyhow::anyhow!("model_name is required when serve_indexer is configured")
            })?;
            Some(
                ensure_served_indexer_service(
                    component.clone(),
                    ServedIndexerMode::from_use_kv_events(kv_router_config.use_kv_events),
                    model_name,
                    indexer.clone(),
                )
                .await?,
            )
        } else {
            None
        };

385
        tracing::info!("KV Routing initialized");
386
        Ok(Self {
387
            indexer,
388
            scheduler,
389
            workers_with_configs,
390
            block_size,
391
            kv_router_config,
392
            prefill_load_estimator,
Yan Ru Pei's avatar
Yan Ru Pei committed
393
            cancellation_token,
394
            client,
395
            is_eagle,
396
            _served_indexer_handle: served_indexer_handle,
397
            shared_cache,
398
        })
399
400
    }

401
402
403
404
405
    /// Returns a reference to the [`Client`] this router was constructed with.
    pub fn client(&self) -> &Client {
        &self.client
    }

406
407
408
409
410
411
412
413
    /// Returns a reference to the [`Indexer`] owned by this router.
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    /// Returns the router configuration in effect for this router.
    pub fn kv_router_config(&self) -> &KvRouterConfig {
        &self.kv_router_config
    }

414
415
416
417
    /// Returns the `is_eagle` flag this router was constructed with.
    pub fn is_eagle(&self) -> bool {
        self.is_eagle
    }

418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
    /// Computes per-worker cache-hit estimates for `tiered_matches`.
    ///
    /// Delegates to the free function of the same name, supplying this
    /// router's configuration and block size.
    fn cache_hit_estimates_from_tiered_matches(
        &self,
        tiered_matches: &indexer::TieredMatchDetails,
    ) -> CacheHitEstimates {
        cache_hit_estimates_from_tiered_matches(
            &self.kv_router_config,
            self.block_size,
            tiered_matches,
        )
    }

    /// Extracts the cache-hit estimate for `worker` from `cache_hit_estimates`.
    ///
    /// Delegates to the free function of the same name; no router state is read.
    fn cache_hit_for_worker(
        &self,
        cache_hit_estimates: &CacheHitEstimates,
        worker: WorkerWithDpRank,
    ) -> WorkerCacheHitEstimate {
        cache_hit_for_worker(cache_hit_estimates, worker)
    }

437
438
    pub async fn record_routing_decision(
        &self,
439
        mut tokens_with_hashes: TokensWithHashes,
440
441
442
443
444
445
446
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        self.indexer
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
            .await
    }

447
448
    /// Give these tokens, find the worker with the best weighted cache hit.
    /// Returns the full match details for the selected worker.
449
    ///
450
451
452
    /// When `pinned_worker` is Some, scheduling and queueing are constrained to
    /// that exact worker/rank.
    ///
453
    /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
454
    #[allow(clippy::too_many_arguments)]
455
    pub(crate) async fn find_best_match_details(
456
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
457
        context_id: Option<&str>,
458
        tokens: &[u32],
459
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
460
        router_config_override: Option<&RouterConfigOverride>,
461
        update_states: bool,
462
        lora_name: Option<String>,
463
        priority_jump: f64,
464
        expected_output_tokens: Option<u32>,
465
        pinned_worker: Option<WorkerWithDpRank>,
466
        allowed_worker_ids: Option<HashSet<WorkerId>>,
467
    ) -> anyhow::Result<BestMatchDetails> {
468
469
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
470
        if update_states && context_id.is_none() {
471
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
472
473
        }

474
        let isl_tokens = tokens.len();
475
476
477
478
479
480
481
482
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };

        let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
            .in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options));
483
        log_routing_input_hashes(context_id, self.block_size, tokens, &block_hashes);
484
485
486
487
        let hash_elapsed = start.elapsed();
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
            self.kv_router_config.compute_seq_hashes_for_tracking(
488
489
                tokens,
                self.block_size,
490
491
492
                router_config_override,
                hash_options,
                Some(&block_hashes),
493
494
            )
        });
495
        let seq_hash_elapsed = start.elapsed();
496

497
        // Query indexer (tiered) and shared cache in parallel when shared cache is configured.
498
        // Time each independently so metrics can separate indexer vs shared cache latency.
499
        let (tiered_matches, shared_cache_hits, indexer_duration, shared_cache_duration) =
500
501
502
            if let Some(ref shared_cache) = self.shared_cache {
                let indexer_fut = self
                    .indexer
503
                    .find_matches_by_tier(block_hashes)
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
                    .instrument(tracing::info_span!("kv_router.find_matches"));
                let shared_fut = shared_cache
                    .check_blocks(tokens, self.block_size)
                    .instrument(tracing::info_span!("kv_router.shared_cache_check"));

                let indexer_timed = async {
                    let t = Instant::now();
                    let r = indexer_fut.await;
                    (r, t.elapsed())
                };
                let shared_timed = async {
                    let t = Instant::now();
                    let r = shared_fut.await;
                    (r, t.elapsed())
                };

                let ((indexer_result, idx_dur), (shared_result, sc_dur)) =
                    tokio::join!(indexer_timed, shared_timed);
522
                let tiered = indexer_result?;
523
524
525
526
527
528
529
530
531
532
533
                // Shared cache failure is non-fatal: log warning and fall back to empty hits.
                let hits = match shared_result {
                    Ok(hits) => Some(hits),
                    Err(e) => {
                        tracing::warn!(error = %e, "Shared cache query failed, ignoring");
                        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
                            m.inc_shared_cache_errors();
                        }
                        None
                    }
                };
534
                (tiered, hits, idx_dur, Some(sc_dur))
535
536
            } else {
                let t = Instant::now();
537
                let tiered = self
538
                    .indexer
539
                    .find_matches_by_tier(block_hashes)
540
541
                    .instrument(tracing::info_span!("kv_router.find_matches"))
                    .await?;
542
                (tiered, None, t.elapsed(), None)
543
            };
544
545
546
547
548
549
550
551
552
553

        let tier_overlap_blocks = tier_overlap_blocks_from_tiered_matches(&tiered_matches);
        let cache_hit_estimates = self.cache_hit_estimates_from_tiered_matches(&tiered_matches);
        let tree_sizes: HashMap<_, _> = tiered_matches
            .device
            .overlap_scores
            .tree_sizes
            .iter()
            .map(|(k, v)| (*k, *v))
            .collect();
554
        let find_matches_elapsed = start.elapsed();
555

556
557
558
559
560
561
        // Capture shared cache info for metrics before moving into schedule().
        // Clone the hits so we can compute `hits_beyond(overlap_blocks)` after
        // scheduling returns, since `overlap_blocks` isn't known until then.
        let num_blocks = isl_tokens / self.block_size as usize;
        let sc_hits_for_metrics = shared_cache_hits.clone();

562
        let response = self
563
            .scheduler
564
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
565
                context_id.map(|s| s.to_string()),
566
                isl_tokens,
567
                maybe_seq_hashes,
568
569
570
571
                tier_overlap_blocks,
                cache_hit_estimates.effective_overlap_blocks,
                cache_hit_estimates.cached_tokens,
                tree_sizes,
572
                router_config_override,
573
                update_states,
574
                lora_name,
575
                priority_jump,
576
                expected_output_tokens,
577
                pinned_worker,
578
                allowed_worker_ids,
579
                shared_cache_hits,
580
            )
581
            .instrument(tracing::info_span!("kv_router.schedule"))
582
            .await?;
583
584
        let total_elapsed = start.elapsed();

585
586
587
588
        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
            m.observe(
                hash_elapsed,
                seq_hash_elapsed,
589
590
                indexer_duration,
                shared_cache_duration,
591
                find_matches_elapsed,
592
593
594
                total_elapsed,
            );
        }
595

596
597
598
599
600
601
602
603
        // Observe per-request shared cache metrics.
        if let Some(hits) = sc_hits_for_metrics
            && let Some(m) = metrics::RouterRequestMetrics::get()
        {
            if num_blocks > 0 {
                m.shared_cache_hit_rate
                    .observe(hits.total_hits as f64 / num_blocks as f64);
            }
604
            let beyond = hits.hits_beyond(response.effective_overlap_blocks.round() as u32);
605
606
607
            m.shared_cache_beyond_blocks.observe(beyond as f64);
        }

608
        #[cfg(feature = "bench")]
609
610
611
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
612
613
614
            seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64,
            find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
615
616
617
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
618

619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
        Ok(BestMatchDetails {
            worker: response.best_worker,
            cache_hit: WorkerCacheHitEstimate {
                effective_overlap_blocks: response.effective_overlap_blocks,
                cached_tokens: response.cached_tokens,
            },
        })
    }

    /// Given these tokens, find the worker with the best match in its KV cache.
    /// Returns the best worker (with dp_rank) and approximate effective overlap in blocks.
    ///
    /// This is `find_best_match_details` without a pinned worker, with the
    /// effective overlap rounded to whole blocks.
    ///
    /// # Errors
    ///
    /// Propagates any error from `find_best_match_details`.
    // Mirrors the parameter list of `find_best_match_details`.
    #[allow(clippy::too_many_arguments)]
    pub async fn find_best_match(
        &self,
        context_id: Option<&str>,
        tokens: &[u32],
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
        router_config_override: Option<&RouterConfigOverride>,
        update_states: bool,
        lora_name: Option<String>,
        priority_jump: f64,
        expected_output_tokens: Option<u32>,
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
        let result = self
            .find_best_match_details(
                context_id,
                tokens,
                block_mm_infos,
                router_config_override,
                update_states,
                lora_name,
                priority_jump,
                expected_output_tokens,
                None,
                allowed_worker_ids,
            )
            .await?;
        Ok((result.worker, result.cache_hit.rounded_overlap_blocks()))
    }

660
661
662
663
664
    /// Register externally-provided workers in the slot tracker.
    ///
    /// Forwards `worker_ids` to the scheduler unchanged.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        self.scheduler.register_workers(worker_ids);
    }

665
    #[allow(clippy::too_many_arguments)]
666
667
668
669
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
670
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
671
        cached_tokens: usize,
672
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
673
        worker: WorkerWithDpRank,
674
        lora_name: Option<String>,
675
        router_config_override: Option<&RouterConfigOverride>,
676
677
    ) {
        let isl_tokens = tokens.len();
678
679
680
681
682
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };
683

684
685
686
687
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
688
689
            hash_options,
            None,
690
        );
691
692
693
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
694
        let prefill_load_hint =
695
            self.prefill_load_hint_for(isl_tokens, cached_tokens, track_prefill_tokens);
696

697
698
        if let Err(e) = self
            .scheduler
699
700
701
            .add_request(SequenceRequest {
                request_id: request_id.clone(),
                token_sequence: maybe_seq_hashes,
702
                track_prefill_tokens,
703
                expected_output_tokens,
704
                prefill_load_hint,
Yan Ru Pei's avatar
Yan Ru Pei committed
705
                worker,
706
                lora_name,
707
            })
708
709
710
711
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
712
713
    }

714
    /// Tells the scheduler that prefill has completed for `request_id`.
    ///
    /// # Errors
    ///
    /// Returns the scheduler's [`SequenceError`] unchanged.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.mark_prefill_completed(request_id).await
    }

718
    /// Asks the scheduler to free `request_id`.
    ///
    /// # Errors
    ///
    /// Returns the scheduler's [`SequenceError`] unchanged.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.free(request_id).await
    }
721

722
723
724
725
726
    /// Number of requests currently parked in the scheduler queue.
    ///
    /// The value is whatever the scheduler reports at the time of the call.
    pub fn pending_count(&self) -> usize {
        self.scheduler.pending_count()
    }

727
728
729
    fn prefill_load_hint_for(
        &self,
        isl_tokens: usize,
730
        cached_tokens: usize,
731
732
733
734
735
736
        track_prefill_tokens: bool,
    ) -> Option<PrefillLoadHint> {
        if !track_prefill_tokens {
            return None;
        }

737
        let prefix = cached_tokens.min(isl_tokens);
738
739
740
741
742
        let effective_isl = isl_tokens.saturating_sub(prefix);
        if effective_isl == 0 {
            return None;
        }

743
744
745
746
747
748
749
750
751
752
753
754
755
        let expected_prefill_duration = match &self.prefill_load_estimator {
            Some(estimator) => match estimator.predict_prefill_duration(1, effective_isl, prefix) {
                Ok(expected_prefill_duration) => Some(expected_prefill_duration),
                Err(error) => {
                    tracing::warn!(
                        effective_isl,
                        prefix,
                        "failed to predict prefill duration for direct add_request path: {error}"
                    );
                    None
                }
            },
            None => None,
756
757
        };

758
759
760
761
        Some(PrefillLoadHint {
            initial_effective_prefill_tokens: effective_isl,
            expected_prefill_duration,
        })
762
763
    }

764
765
766
767
768
769
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    ///
    /// The value is read back from the scheduler.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

770
771
772
773
774
775
776
    /// Returns the global DP rank of `worker_id` when that worker owns exactly one rank.
    ///
    /// Yields `None` for an unknown worker or when its `data_parallel_size` is not `1`.
    pub fn unique_dp_rank_for_worker(&self, worker_id: WorkerId) -> Option<u32> {
        let known_workers = self.workers_with_configs.borrow();
        match known_workers.get(&worker_id) {
            Some(runtime_config) if runtime_config.data_parallel_size == 1 => {
                Some(runtime_config.data_parallel_start_rank)
            }
            _ => None,
        }
    }

777
    pub fn add_output_block(
778
779
780
781
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
782
        self.scheduler.add_output_block(request_id, decay_fraction)
783
784
    }

785
    pub fn block_size(&self) -> u32 {
786
787
        self.block_size
    }
788

789
    /// Compute the overlap blocks for a given token sequence and worker.
790
    /// This queries the indexer to find the effective weighted cache hit.
791
792
793
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
794
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
795
        worker: WorkerWithDpRank,
796
        lora_name: Option<&str>,
797
    ) -> Result<u32, KvRouterError> {
798
799
800
801
802
803
804
805
806
807
808
809
810
        Ok(self
            .get_cache_hit_estimate(tokens, block_mm_infos, worker, lora_name)
            .await?
            .rounded_overlap_blocks())
    }

    /// Computes the cache-hit estimate of `worker` for the given token sequence.
    ///
    /// The tokens are hashed into blocks, the indexer is queried for tiered
    /// matches, and the estimate for `worker` is taken from the result. A
    /// worker with no match gets a zero estimate.
    ///
    /// # Errors
    ///
    /// Returns the indexer's [`KvRouterError`] when the query fails.
    pub(crate) async fn get_cache_hit_estimate(
        &self,
        tokens: &[u32],
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
        worker: WorkerWithDpRank,
        lora_name: Option<&str>,
    ) -> Result<WorkerCacheHitEstimate, KvRouterError> {
        let block_hashes = compute_block_hash_for_seq(
            tokens,
            self.block_size,
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
        );
        let tiered_matches = self.indexer.find_matches_by_tier(block_hashes).await?;
        let cache_hit_estimates = self.cache_hit_estimates_from_tiered_matches(&tiered_matches);
        Ok(self.cache_hit_for_worker(&cache_hit_estimates, worker))
    }

825
    /// Get potential prefill and decode loads for all workers
826
827
828
829
    pub async fn get_potential_loads(
        &self,
        tokens: &[u32],
        router_config_override: Option<&RouterConfigOverride>,
830
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
831
        lora_name: Option<&str>,
832
    ) -> Result<Vec<PotentialLoad>> {
833
        let isl_tokens = tokens.len();
834
835
836
837
838
839
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name,
            is_eagle: Some(self.is_eagle),
        };
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options);
840

841
842
843
844
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
845
846
            hash_options,
            Some(&block_hashes),
847
        );
848
849
850
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
851
852
        let tiered_matches = self.indexer.find_matches_by_tier(block_hashes).await?;
        let cache_hit_estimates = self.cache_hit_estimates_from_tiered_matches(&tiered_matches);
853

854
855
856
        Ok(self.scheduler.get_potential_loads(
            maybe_seq_hashes,
            isl_tokens,
857
            cache_hit_estimates.cached_tokens,
858
859
            track_prefill_tokens,
        ))
860
861
    }

862
863
864
865
    /// Dump all events from the indexer
    ///
    /// Returns the indexer's result unchanged, including any `KvRouterError`.
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
866
867
}

Michael Feil's avatar
Michael Feil committed
868
869
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
870
#[async_trait]
871
872
873
874
875
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
    for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
876
877
878
879
880
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
881
882
883
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
884
885
886
887
            RouterRequest::New {
                tokens,
                block_mm_infos,
            } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
888
                let (best_worker, overlap_blocks) = self
889
890
891
892
893
894
895
896
                    .find_best_match(
                        Some(&context_id),
                        &tokens,
                        block_mm_infos.as_deref(),
                        None,
                        true,
                        None,
                        0.0,
897
                        None,
898
                        None,
899
                    )
Michael Feil's avatar
Michael Feil committed
900
901
902
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
903
904
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
905
906
907
                    overlap_blocks,
                }
            }
908
909
910
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
911
912
913
914
915
916
917
918
919
            RouterRequest::MarkFree { request_id } => {
                let request_id = match request_id.as_deref() {
                    Some(request_id) if !request_id.trim().is_empty() => request_id,
                    _ => &context_id,
                };
                RouterResponse::FreeMarked {
                    success: self.free(request_id).await.is_ok(),
                }
            }
Michael Feil's avatar
Michael Feil committed
920
        };
921
922
923
924
925
926

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
impl<Sel> Drop for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
Yan Ru Pei's avatar
Yan Ru Pei committed
932
933
934
935
936
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use async_trait::async_trait;
    use dynamo_kv_router::{
        indexer::{LowerTierMatchDetails, MatchDetails},
        protocols::{OverlapScores, StorageTier},
    };
    use dynamo_runtime::{DistributedRuntime, Runtime, distributed::DistributedConfig};
    use tokio::sync::watch;

    use crate::kv_router::scheduler::KvSchedulerError;
    use crate::local_model::runtime_config::ModelRuntimeConfig;

    /// Lower-tier (host-pinned and disk) hits must contribute to the per-worker estimates
    /// alongside device overlap.
    #[test]
    fn weighted_cache_hit_estimates_include_lower_tiers() {
        let worker_1 = WorkerWithDpRank::new(1, 0);
        let worker_2 = WorkerWithDpRank::new(2, 0);
        let mut device_overlap_scores = OverlapScores::new();
        device_overlap_scores.scores.insert(worker_1, 2);
        let mut host_match_details = LowerTierMatchDetails::default();
        host_match_details.hits.insert(worker_1, 1);
        host_match_details.hits.insert(worker_2, 1);
        let mut disk_match_details = LowerTierMatchDetails::default();
        disk_match_details.hits.insert(worker_1, 2);

        let tiered_matches = indexer::TieredMatchDetails {
            device: MatchDetails {
                overlap_scores: device_overlap_scores,
                ..Default::default()
            },
            lower_tier: HashMap::from([
                (StorageTier::HostPinned, host_match_details),
                (StorageTier::Disk, disk_match_details),
            ]),
        };

        // NOTE(review): `16` is presumably the block size (3.25 * 16 = 52, 0.75 * 16 = 12);
        // the expected values imply default tier weights of 0.75 (host) and 0.25 (disk) --
        // confirm against `KvRouterConfig::default()`.
        let estimates = cache_hit_estimates_from_tiered_matches(
            &KvRouterConfig::default(),
            16,
            &tiered_matches,
        );

        assert_eq!(
            estimates.effective_overlap_blocks.get(&worker_1),
            Some(&3.25)
        );
        assert_eq!(estimates.cached_tokens.get(&worker_1), Some(&52));
        assert_eq!(
            estimates.effective_overlap_blocks.get(&worker_2),
            Some(&0.75)
        );
        assert_eq!(estimates.cached_tokens.get(&worker_2), Some(&12));
    }

    /// Test double for [`SharedKvCache`]: returns the configured `hits` (or the default value
    /// when `hits` is `None`), or fails with `IndexerOffline` when `should_error` is set.
    struct FakeSharedCache {
        hits: Option<dynamo_kv_router::protocols::SharedCacheHits>,
        should_error: bool,
    }

    #[async_trait]
    impl SharedKvCache for FakeSharedCache {
        async fn check_blocks(
            &self,
            _tokens: &[u32],
            _block_size: u32,
        ) -> Result<dynamo_kv_router::protocols::SharedCacheHits, KvRouterError> {
            if self.should_error {
                Err(KvRouterError::IndexerOffline)
            } else {
                Ok(self.hits.clone().unwrap_or_default())
            }
        }
    }

    /// Selector that asserts on the shared-cache hits carried by the scheduling request and
    /// then always returns `selected_worker`.
    struct InspectingSelector {
        expected_hits: Option<u32>,
        selected_worker: WorkerWithDpRank,
    }

    impl dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> for InspectingSelector {
        fn select_worker(
            &self,
            _workers: &HashMap<WorkerId, ModelRuntimeConfig>,
            request: &dynamo_kv_router::scheduling::SchedulingRequest,
            block_size: u32,
        ) -> Result<dynamo_kv_router::protocols::WorkerSelectionResult, KvSchedulerError> {
            let observed_hits = request
                .shared_cache_hits
                .as_ref()
                .map(|hits| hits.total_hits);
            assert_eq!(observed_hits, self.expected_hits);

            Ok(dynamo_kv_router::protocols::WorkerSelectionResult {
                worker: self.selected_worker,
                required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64,
                effective_overlap_blocks: 0.0,
                cached_tokens: 0,
            })
        }
    }

    /// Build a component on a process-local distributed runtime, with namespace and
    /// component names derived from `name`.
    async fn make_test_component(name: &str) -> dynamo_runtime::component::Component {
        let runtime = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(runtime, DistributedConfig::process_local())
            .await
            .unwrap();
        let namespace = drt.namespace(format!("test-ns-{name}")).unwrap();
        namespace
            .component(format!("test-component-{name}"))
            .unwrap()
    }

    /// Build a `KvRouter` over two default-configured workers (ids 0 and 1) using the given
    /// selector and optional shared cache.
    async fn make_test_router(
        selector: impl dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>
        + Send
        + Sync
        + 'static,
        shared_cache: Option<Box<dyn SharedKvCache>>,
    ) -> KvRouter<
        impl dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
    > {
        let component = make_test_component("shared-cache-router").await;
        let endpoint = component.endpoint("backend");
        let client = endpoint.client().await.unwrap();

        let mut workers = HashMap::new();
        workers.insert(0, ModelRuntimeConfig::default());
        workers.insert(1, ModelRuntimeConfig::default());
        let (_tx, rx) = watch::channel(workers);

        let config = KvRouterConfig {
            overlap_score_weight: 0.0,
            router_temperature: 0.0,
            use_kv_events: false,
            router_track_active_blocks: false,
            shared_cache_multiplier: 0.5,
            skip_initial_worker_wait: true,
            ..Default::default()
        };

        // NOTE(review): the positional `2` is presumably the block size (the tests route four
        // tokens and expect two shared-cache hits) -- confirm against `KvRouter::new`.
        KvRouter::new(
            endpoint,
            client,
            rx,
            2,
            selector,
            Some(config),
            None,
            "decode",
            None,
            false,
            shared_cache,
        )
        .await
        .unwrap()
    }

    /// Hits reported by the shared cache must reach the selector via the scheduling request.
    #[tokio::test]
    async fn test_find_best_match_passes_shared_cache_hits_to_scheduler() {
        let router = make_test_router(
            InspectingSelector {
                expected_hits: Some(2),
                selected_worker: WorkerWithDpRank::from_worker_id(1),
            },
            Some(Box::new(FakeSharedCache {
                #[allow(clippy::single_range_in_vec_init)]
                hits: Some(dynamo_kv_router::protocols::SharedCacheHits::from_ranges(
                    vec![0..2],
                )),
                should_error: false,
            })),
        )
        .await;

        let (worker, overlap) = router
            .find_best_match(
                None,
                &[11, 12, 21, 22],
                None,
                None,
                false,
                None,
                0.0,
                None,
                None,
            )
            .await
            .unwrap();

        assert_eq!(worker, WorkerWithDpRank::from_worker_id(1));
        assert_eq!(overlap, 0);
    }

    /// A failing shared cache must not fail routing: the selector sees no shared-cache hits
    /// and a worker is still chosen.
    #[tokio::test]
    async fn test_find_best_match_ignores_shared_cache_errors() {
        let router = make_test_router(
            InspectingSelector {
                expected_hits: None,
                selected_worker: WorkerWithDpRank::from_worker_id(0),
            },
            Some(Box::new(FakeSharedCache {
                hits: None,
                should_error: true,
            })),
        )
        .await;

        let (worker, overlap) = router
            .find_best_match(
                None,
                &[11, 12, 21, 22],
                None,
                None,
                false,
                None,
                0.0,
                None,
                None,
            )
            .await
            .unwrap();

        assert_eq!(worker, WorkerWithDpRank::from_worker_id(0));
        assert_eq!(overlap, 0);
    }
}