kv_router.rs 30.4 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::sync::Arc;
5
use std::time::Instant;
6

7
use anyhow::Result;
8
use dynamo_kv_router::{
9
    PrefillLoadEstimator, SharedKvCache,
10
    config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
11
    indexer::KvRouterError,
12
13
    protocols::KV_EVENT_SUBJECT,
    protocols::{
14
15
16
        BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, PrefillLoadHint, RouterEvent,
        RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
        compute_block_hash_for_seq,
17
18
    },
};
19
use dynamo_runtime::{
20
    component::{Client, Endpoint},
21
    discovery::DiscoveryQuery,
22
    pipeline::{
23
24
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
25
    },
26
    protocols::EndpointId,
27
    protocols::annotated::Annotated,
28
    traits::DistributedRuntimeProvider,
29
};
30
use futures::stream;
31
use tracing::Instrument;
32
use validator::Validate;
33

34
35
36
37
38
39
40
// Re-export from dynamo-kv-router crate
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::protocols;
pub use dynamo_kv_router::scheduling;
pub use dynamo_kv_router::selector;

pub mod agent_controller;
41
pub mod indexer;
42
pub mod metrics;
43
pub mod prefill_router;
44
pub mod publisher;
45
pub mod push_router;
46
pub mod scheduler;
47
pub mod sequence;
48
pub mod shared_cache;
49
pub mod sticky_sessions;
50

51
pub use agent_controller::AgentController;
52
pub use indexer::{Indexer, ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
53
pub use prefill_router::PrefillRouter;
54
pub use push_router::{DirectRoutingRouter, KvPushRouter};
55
pub use sticky_sessions::StickySessionRouter;
56

57
use crate::{
58
    discovery::RuntimeConfigWatch,
59
    kv_router::{
60
        scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
61
        sequence::{SequenceError, SequenceRequest},
62
    },
63
    local_model::runtime_config::ModelRuntimeConfig,
64
65
};

66
67
use std::collections::HashSet;

68
69
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
70
71
72
73
74
75
76
77
78
79

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject name used for metric publishing (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject name for prefill events used in inter-router communication.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject name for active-sequence events used in inter-router communication.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
80

81
82
83
84
/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name used for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

85
86
87
/// Buffer size for the worker-local kvindexer query path: the worker buffer
/// stores the 1024 most recent events.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

88
89
90
91
92
93
/// Builds the endpoint name of the worker KV indexer query service for one dp_rank.
///
/// Every dp_rank has its own LocalKvIndexer and its own query endpoint, which
/// keeps monotonicity per dp_rank.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    const PREFIX: &str = "worker_kv_indexer_query_dp";
    format!("{PREFIX}{dp_rank}")
}

94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/// Emits a DEBUG-level `[ROUTING_INPUT]` event describing a request's local block hashes.
///
/// The hash ids are only collected when DEBUG is enabled, so the call is a
/// no-op (no allocation) otherwise. A missing `request_id` is logged as `""`.
fn log_routing_input_hashes(
    request_id: Option<&str>,
    block_size: u32,
    tokens: &[u32],
    local_hashes: &[LocalBlockHash],
) {
    if tracing::enabled!(tracing::Level::DEBUG) {
        let mut local_hash_ids: Vec<u64> = Vec::with_capacity(local_hashes.len());
        local_hash_ids.extend(local_hashes.iter().map(|hash| hash.0));

        tracing::debug!(
            request_id = request_id.unwrap_or(""),
            isl_tokens = tokens.len(),
            block_size,
            num_blocks = local_hashes.len(),
            local_hashes = ?local_hash_ids,
            "[ROUTING_INPUT] request local hashes"
        );
    }
}

116
// for router discovery registration
117
/// Endpoint name used for router discovery registration.
pub const KV_ROUTER_ENDPOINT: &str = "router-discovery";
118
119

/// Creates an EndpointId for the KV router in the given namespace.
120
pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId {
121
122
    EndpointId {
        namespace,
123
        component,
124
125
126
127
128
        name: KV_ROUTER_ENDPOINT.to_string(),
    }
}

/// Creates a DiscoveryQuery for the KV router in the given namespace.
129
pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery {
130
131
    DiscoveryQuery::Endpoint {
        namespace,
132
        component,
133
134
135
136
        endpoint: KV_ROUTER_ENDPOINT.to_string(),
    }
}

137
138
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
139
140
141
142
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
143
    indexer: Indexer,
144
    scheduler: KvScheduler<Sel>,
145
    workers_with_configs: RuntimeConfigWatch,
146
    block_size: u32,
147
    kv_router_config: KvRouterConfig,
148
    prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
149
    cancellation_token: tokio_util::sync::CancellationToken,
150
    client: Client,
151
    is_eagle: bool,
152
    _served_indexer_handle: Option<ServedIndexerHandle>,
153
154
155
    /// Optional external shared KV cache pool. When present, `find_best_match`
    /// queries it in parallel with the indexer and factors shared hits into scoring.
    shared_cache: Option<Box<dyn SharedKvCache>>,
156
157
}

158
159
160
161
impl<Sel> KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
162
    #[allow(clippy::too_many_arguments)]
163
    pub async fn new(
164
165
        endpoint: Endpoint,
        client: Client,
166
        workers_with_configs: RuntimeConfigWatch,
167
        block_size: u32,
168
        selector: Sel,
169
        kv_router_config: Option<KvRouterConfig>,
170
        prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
171
        worker_type: &'static str,
172
        model_name: Option<String>,
173
        is_eagle: bool,
174
        shared_cache: Option<Box<dyn SharedKvCache>>,
175
    ) -> Result<Self> {
176
        let kv_router_config = kv_router_config.unwrap_or_default();
177
        kv_router_config.validate()?;
178
        let component = endpoint.component();
179
        let cancellation_token = component.drt().primary_token();
180
        let min_initial_workers = min_initial_workers_from_env()?;
181

182
183
184
185
186
187
188
        let indexer = Indexer::new(
            component,
            &kv_router_config,
            block_size,
            model_name.as_deref(),
        )
        .await?;
189

190
191
192
193
        if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
            let mut startup_watch = workers_with_configs.clone();
            let _ = startup_watch
                .wait_for(|m| m.len() >= min_initial_workers)
194
195
                .await
                .map_err(|_| {
196
197
                    anyhow::anyhow!(
                        "runtime config watch closed before {} workers appeared",
198
                        min_initial_workers
199
                    )
200
201
                })?;
        }
202

203
        let scheduler = KvScheduler::start(
204
            component.clone(),
205
            block_size,
206
            workers_with_configs.clone(),
207
            selector,
208
            &kv_router_config,
209
            prefill_load_estimator.clone(),
210
            worker_type,
211
212
        )
        .await?;
213

214
215
        // Start KV event subscription if needed — skip when using a remote indexer.
        if kv_router_config.use_remote_indexer {
216
217
            tracing::info!("Skipping KV event subscription (using remote indexer)");
        } else if kv_router_config.should_subscribe_to_kv_events() {
218
            indexer::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
219
                .await?;
220
        } else {
221
            tracing::info!(
222
223
224
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
225
            );
226
        }
227

228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
        let served_indexer_handle = if kv_router_config.serve_indexer {
            let model_name = model_name.clone().ok_or_else(|| {
                anyhow::anyhow!("model_name is required when serve_indexer is configured")
            })?;
            Some(
                ensure_served_indexer_service(
                    component.clone(),
                    ServedIndexerMode::from_use_kv_events(kv_router_config.use_kv_events),
                    model_name,
                    indexer.clone(),
                )
                .await?,
            )
        } else {
            None
        };

245
        tracing::info!("KV Routing initialized");
246
        Ok(Self {
247
            indexer,
248
            scheduler,
249
            workers_with_configs,
250
            block_size,
251
            kv_router_config,
252
            prefill_load_estimator,
Yan Ru Pei's avatar
Yan Ru Pei committed
253
            cancellation_token,
254
            client,
255
            is_eagle,
256
            _served_indexer_handle: served_indexer_handle,
257
            shared_cache,
258
        })
259
260
    }

261
262
263
264
265
    /// Returns a reference to the client this `KvRouter` was constructed with.
    pub fn client(&self) -> &Client {
        &self.client
    }

266
267
268
269
270
271
272
273
    /// Returns a reference to the indexer owned by this router.
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    /// Returns the router configuration in effect (the default one if none was
    /// passed to `new`).
    pub fn kv_router_config(&self) -> &KvRouterConfig {
        &self.kv_router_config
    }

274
275
276
277
    /// Returns the `is_eagle` flag this router was constructed with; the same
    /// flag is forwarded to the block hash options when hashing tokens.
    pub fn is_eagle(&self) -> bool {
        self.is_eagle
    }

278
279
    pub async fn record_routing_decision(
        &self,
280
        mut tokens_with_hashes: TokensWithHashes,
281
282
283
284
285
286
287
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        self.indexer
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
            .await
    }

288
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
289
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
290
291
    /// Now also takes optional context_id for request tracking.
    ///
292
293
294
    /// When `pinned_worker` is Some, scheduling and queueing are constrained to
    /// that exact worker/rank.
    ///
295
    /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
296
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
297
    pub async fn find_best_match(
298
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
299
        context_id: Option<&str>,
300
        tokens: &[u32],
301
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
302
        router_config_override: Option<&RouterConfigOverride>,
303
        update_states: bool,
304
        lora_name: Option<String>,
305
        priority_jump: f64,
306
        expected_output_tokens: Option<u32>,
307
        pinned_worker: Option<WorkerWithDpRank>,
308
        allowed_worker_ids: Option<HashSet<WorkerId>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
309
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
310
311
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
312
        if update_states && context_id.is_none() {
313
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
314
315
        }

316
        let isl_tokens = tokens.len();
317
318
319
320
321
322
323
324
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };

        let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
            .in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options));
325
        log_routing_input_hashes(context_id, self.block_size, tokens, &block_hashes);
326
327
328
329
        let hash_elapsed = start.elapsed();
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
            self.kv_router_config.compute_seq_hashes_for_tracking(
330
331
                tokens,
                self.block_size,
332
333
334
                router_config_override,
                hash_options,
                Some(&block_hashes),
335
336
            )
        });
337
        let seq_hash_elapsed = start.elapsed();
338

339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
        // Query indexer and shared cache in parallel when shared cache is configured.
        // Time each independently so metrics can separate indexer vs shared cache latency.
        let (overlap_scores, shared_cache_hits, indexer_duration, shared_cache_duration) =
            if let Some(ref shared_cache) = self.shared_cache {
                let indexer_fut = self
                    .indexer
                    .find_matches(block_hashes.clone())
                    .instrument(tracing::info_span!("kv_router.find_matches"));
                let shared_fut = shared_cache
                    .check_blocks(tokens, self.block_size)
                    .instrument(tracing::info_span!("kv_router.shared_cache_check"));

                let indexer_timed = async {
                    let t = Instant::now();
                    let r = indexer_fut.await;
                    (r, t.elapsed())
                };
                let shared_timed = async {
                    let t = Instant::now();
                    let r = shared_fut.await;
                    (r, t.elapsed())
                };

                let ((indexer_result, idx_dur), (shared_result, sc_dur)) =
                    tokio::join!(indexer_timed, shared_timed);
                let overlaps = indexer_result?;
                // Shared cache failure is non-fatal: log warning and fall back to empty hits.
                let hits = match shared_result {
                    Ok(hits) => Some(hits),
                    Err(e) => {
                        tracing::warn!(error = %e, "Shared cache query failed, ignoring");
                        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
                            m.inc_shared_cache_errors();
                        }
                        None
                    }
                };
                (overlaps, hits, idx_dur, Some(sc_dur))
            } else {
                let t = Instant::now();
                let overlaps = self
                    .indexer
                    .find_matches(block_hashes)
                    .instrument(tracing::info_span!("kv_router.find_matches"))
                    .await?;
                (overlaps, None, t.elapsed(), None)
            };
386
        let find_matches_elapsed = start.elapsed();
387

388
389
390
391
392
393
        // Capture shared cache info for metrics before moving into schedule().
        // Clone the hits so we can compute `hits_beyond(overlap_blocks)` after
        // scheduling returns, since `overlap_blocks` isn't known until then.
        let num_blocks = isl_tokens / self.block_size as usize;
        let sc_hits_for_metrics = shared_cache_hits.clone();

394
        let response = self
395
            .scheduler
396
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
397
                context_id.map(|s| s.to_string()),
398
                isl_tokens,
399
                maybe_seq_hashes,
400
                overlap_scores,
401
                router_config_override,
402
                update_states,
403
                lora_name,
404
                priority_jump,
405
                expected_output_tokens,
406
                pinned_worker,
407
                allowed_worker_ids,
408
                shared_cache_hits,
409
            )
410
            .instrument(tracing::info_span!("kv_router.schedule"))
411
            .await?;
412
413
        let total_elapsed = start.elapsed();

414
415
416
417
        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
            m.observe(
                hash_elapsed,
                seq_hash_elapsed,
418
419
                indexer_duration,
                shared_cache_duration,
420
                find_matches_elapsed,
421
422
423
                total_elapsed,
            );
        }
424

425
426
427
428
429
430
431
432
433
434
435
436
        // Observe per-request shared cache metrics.
        if let Some(hits) = sc_hits_for_metrics
            && let Some(m) = metrics::RouterRequestMetrics::get()
        {
            if num_blocks > 0 {
                m.shared_cache_hit_rate
                    .observe(hits.total_hits as f64 / num_blocks as f64);
            }
            let beyond = hits.hits_beyond(response.overlap_blocks);
            m.shared_cache_beyond_blocks.observe(beyond as f64);
        }

437
        #[cfg(feature = "bench")]
438
439
440
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
441
442
443
            seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64,
            find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
444
445
446
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
447

448
        Ok((response.best_worker, response.overlap_blocks))
449
450
    }

451
452
453
454
455
    /// Register externally-provided workers in the slot tracker.
    ///
    /// Forwards `worker_ids` unchanged to the scheduler's `register_workers`.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        self.scheduler.register_workers(worker_ids);
    }

456
    #[allow(clippy::too_many_arguments)]
457
458
459
460
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
461
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
462
        overlap_blocks: u32,
463
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
464
        worker: WorkerWithDpRank,
465
        lora_name: Option<String>,
466
        router_config_override: Option<&RouterConfigOverride>,
467
468
    ) {
        let isl_tokens = tokens.len();
469
470
471
472
473
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };
474

475
476
477
478
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
479
480
            hash_options,
            None,
481
        );
482
483
484
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
485
486
        let prefill_load_hint =
            self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
487

488
489
        if let Err(e) = self
            .scheduler
490
491
492
            .add_request(SequenceRequest {
                request_id: request_id.clone(),
                token_sequence: maybe_seq_hashes,
493
                track_prefill_tokens,
494
                expected_output_tokens,
495
                prefill_load_hint,
Yan Ru Pei's avatar
Yan Ru Pei committed
496
                worker,
497
                lora_name,
498
            })
499
500
501
502
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
503
504
    }

505
    /// Marks the prefill phase of `request_id` as completed in the scheduler.
    ///
    /// # Errors
    ///
    /// Returns the `SequenceError` reported by the scheduler.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.mark_prefill_completed(request_id).await
    }

509
    /// Frees the scheduler state tracked for `request_id`.
    ///
    /// # Errors
    ///
    /// Returns the `SequenceError` reported by the scheduler.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.free(request_id).await
    }
512

513
514
515
516
517
    /// Number of requests currently parked in the scheduler queue.
    ///
    /// Forwards to the scheduler's `pending_count`.
    pub fn pending_count(&self) -> usize {
        self.scheduler.pending_count()
    }

518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
    fn prefill_load_hint_for(
        &self,
        isl_tokens: usize,
        overlap_blocks: u32,
        track_prefill_tokens: bool,
    ) -> Option<PrefillLoadHint> {
        if !track_prefill_tokens {
            return None;
        }

        let prefix = (overlap_blocks as usize) * (self.block_size as usize);
        let effective_isl = isl_tokens.saturating_sub(prefix);
        if effective_isl == 0 {
            return None;
        }

534
535
536
537
538
539
540
541
542
543
544
545
546
        let expected_prefill_duration = match &self.prefill_load_estimator {
            Some(estimator) => match estimator.predict_prefill_duration(1, effective_isl, prefix) {
                Ok(expected_prefill_duration) => Some(expected_prefill_duration),
                Err(error) => {
                    tracing::warn!(
                        effective_isl,
                        prefix,
                        "failed to predict prefill duration for direct add_request path: {error}"
                    );
                    None
                }
            },
            None => None,
547
548
        };

549
550
551
552
        Some(PrefillLoadHint {
            initial_effective_prefill_tokens: effective_isl,
            expected_prefill_duration,
        })
553
554
    }

555
556
557
558
559
560
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    ///
    /// The value is read back from the scheduler.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

561
562
563
564
565
566
567
    /// Looks up `worker_id` and, if that worker owns exactly one DP rank
    /// (`data_parallel_size == 1`), returns its `data_parallel_start_rank`.
    ///
    /// Returns `None` for unknown workers and for workers spanning several ranks.
    pub fn unique_dp_rank_for_worker(&self, worker_id: WorkerId) -> Option<u32> {
        let known_workers = self.workers_with_configs.borrow();
        match known_workers.get(&worker_id) {
            Some(runtime_config) if runtime_config.data_parallel_size == 1 => {
                Some(runtime_config.data_parallel_start_rank)
            }
            _ => None,
        }
    }

568
    pub fn add_output_block(
569
570
571
572
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
573
        self.scheduler.add_output_block(request_id, decay_fraction)
574
575
    }

576
    pub fn block_size(&self) -> u32 {
577
578
        self.block_size
    }
579

580
581
582
583
584
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
585
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
586
        worker: WorkerWithDpRank,
587
        lora_name: Option<&str>,
588
    ) -> Result<u32, KvRouterError> {
589
590
591
592
593
594
595
596
597
        let block_hashes = compute_block_hash_for_seq(
            tokens,
            self.block_size,
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
        );
598
        log_routing_input_hashes(None, self.block_size, tokens, &block_hashes);
599
600
601
602
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

603
    /// Get potential prefill and decode loads for all workers
604
605
606
607
    pub async fn get_potential_loads(
        &self,
        tokens: &[u32],
        router_config_override: Option<&RouterConfigOverride>,
608
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
609
        lora_name: Option<&str>,
610
    ) -> Result<Vec<PotentialLoad>> {
611
        let isl_tokens = tokens.len();
612
613
614
615
616
617
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name,
            is_eagle: Some(self.is_eagle),
        };
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options);
618

619
620
621
622
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
623
624
            hash_options,
            Some(&block_hashes),
625
        );
626
627
628
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
629
630
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

631
632
633
634
635
636
        Ok(self.scheduler.get_potential_loads(
            maybe_seq_hashes,
            isl_tokens,
            overlap_scores,
            track_prefill_tokens,
        ))
637
638
    }

639
640
641
642
    /// Dump all events from the indexer.
    ///
    /// Forwards to the indexer's `dump_events` and returns its result unchanged.
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
643
644
}

Michael Feil's avatar
Michael Feil committed
645
646
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
647
#[async_trait]
648
649
650
651
652
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
    for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
653
654
655
656
657
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
658
659
660
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
661
662
663
664
            RouterRequest::New {
                tokens,
                block_mm_infos,
            } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
665
                let (best_worker, overlap_blocks) = self
666
667
668
669
670
671
672
673
                    .find_best_match(
                        Some(&context_id),
                        &tokens,
                        block_mm_infos.as_deref(),
                        None,
                        true,
                        None,
                        0.0,
674
                        None,
675
                        None,
676
                        None,
677
                    )
Michael Feil's avatar
Michael Feil committed
678
679
680
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
681
682
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
683
684
685
                    overlap_blocks,
                }
            }
686
687
688
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
689
690
691
692
693
694
695
696
697
            RouterRequest::MarkFree { request_id } => {
                let request_id = match request_id.as_deref() {
                    Some(request_id) if !request_id.trim().is_empty() => request_id,
                    _ => &context_id,
                };
                RouterResponse::FreeMarked {
                    success: self.free(request_id).await.is_ok(),
                }
            }
Michael Feil's avatar
Michael Feil committed
698
        };
699
700
701
702
703
704

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
705

706
707
708
709
impl<Sel> Drop for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
Yan Ru Pei's avatar
Yan Ru Pei committed
710
711
712
713
714
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use async_trait::async_trait;
    use dynamo_runtime::{DistributedRuntime, Runtime, distributed::DistributedConfig};
    use tokio::sync::watch;

    use crate::kv_router::scheduler::KvSchedulerError;
    use crate::local_model::runtime_config::ModelRuntimeConfig;

    /// Shared cache test double: returns either a fixed set of hits or an error.
    struct FakeSharedCache {
        // Hits to return on success; `None` yields the default (empty) hits.
        hits: Option<dynamo_kv_router::protocols::SharedCacheHits>,
        // When true, every `check_blocks` call fails.
        should_error: bool,
    }

    #[async_trait]
    impl SharedKvCache for FakeSharedCache {
        /// Ignores its inputs; fails with `IndexerOffline` when `should_error`
        /// is set, otherwise returns the configured hits.
        async fn check_blocks(
            &self,
            _tokens: &[u32],
            _block_size: u32,
        ) -> Result<dynamo_kv_router::protocols::SharedCacheHits, KvRouterError> {
            if self.should_error {
                Err(KvRouterError::IndexerOffline)
            } else {
                Ok(self.hits.clone().unwrap_or_default())
            }
        }
    }

    /// Selector test double that asserts on the shared cache hits it receives
    /// and always picks a fixed worker.
    struct InspectingSelector {
        // Expected `total_hits` of the request's shared cache hits (`None` = no hits attached).
        expected_hits: Option<u32>,
        // Worker returned from every selection.
        selected_worker: WorkerWithDpRank,
    }

    impl dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> for InspectingSelector {
        /// Asserts that the request carries the expected shared cache hit count,
        /// then returns `selected_worker` with zero overlap blocks.
        fn select_worker(
            &self,
            _workers: &HashMap<WorkerId, ModelRuntimeConfig>,
            request: &dynamo_kv_router::scheduling::SchedulingRequest,
            block_size: u32,
        ) -> Result<dynamo_kv_router::protocols::WorkerSelectionResult, KvSchedulerError> {
            let observed_hits = request
                .shared_cache_hits
                .as_ref()
                .map(|hits| hits.total_hits);
            assert_eq!(observed_hits, self.expected_hits);

            Ok(dynamo_kv_router::protocols::WorkerSelectionResult {
                worker: self.selected_worker,
                required_blocks: request.isl_tokens.div_ceil(block_size as usize) as u64,
                overlap_blocks: 0,
            })
        }
    }

    /// Creates a component on a process-local distributed runtime, with the
    /// namespace and component names derived from `name`.
    async fn make_test_component(name: &str) -> dynamo_runtime::component::Component {
        let runtime = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(runtime, DistributedConfig::process_local())
            .await
            .unwrap();
        let namespace = drt.namespace(format!("test-ns-{name}")).unwrap();
        namespace
            .component(format!("test-component-{name}"))
            .unwrap()
    }

    /// Builds a router over two default-configured workers (ids 0 and 1) with
    /// block size 2, KV events disabled, and the initial worker wait skipped.
    async fn make_test_router(
        selector: impl dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>
        + Send
        + Sync
        + 'static,
        shared_cache: Option<Box<dyn SharedKvCache>>,
    ) -> KvRouter<
        impl dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
    > {
        let component = make_test_component("shared-cache-router").await;
        let endpoint = component.endpoint("backend");
        let client = endpoint.client().await.unwrap();

        let mut workers = HashMap::new();
        workers.insert(0, ModelRuntimeConfig::default());
        workers.insert(1, ModelRuntimeConfig::default());
        let (_tx, rx) = watch::channel(workers);

        let config = KvRouterConfig {
            overlap_score_weight: 0.0,
            router_temperature: 0.0,
            use_kv_events: false,
            router_track_active_blocks: false,
            shared_cache_multiplier: 0.5,
            skip_initial_worker_wait: true,
            ..Default::default()
        };

        KvRouter::new(
            endpoint,
            client,
            rx,
            2,
            selector,
            Some(config),
            None,
            "decode",
            None,
            false,
            shared_cache,
        )
        .await
        .unwrap()
    }

    /// A successful shared cache query must reach the selector: the fake cache
    /// reports the range `0..2`, and the selector expects `total_hits == 2`.
    #[tokio::test]
    async fn test_find_best_match_passes_shared_cache_hits_to_scheduler() {
        let router = make_test_router(
            InspectingSelector {
                expected_hits: Some(2),
                selected_worker: WorkerWithDpRank::from_worker_id(1),
            },
            Some(Box::new(FakeSharedCache {
                #[allow(clippy::single_range_in_vec_init)]
                hits: Some(dynamo_kv_router::protocols::SharedCacheHits::from_ranges(
                    vec![0..2],
                )),
                should_error: false,
            })),
        )
        .await;

        let (worker, overlap) = router
            .find_best_match(
                None,
                &[11, 12, 21, 22],
                None,
                None,
                false,
                None,
                0.0,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert_eq!(worker, WorkerWithDpRank::from_worker_id(1));
        assert_eq!(overlap, 0);
    }

    /// A failing shared cache query must not fail routing: the selector sees no
    /// shared cache hits and `find_best_match` still returns the selected worker.
    #[tokio::test]
    async fn test_find_best_match_ignores_shared_cache_errors() {
        let router = make_test_router(
            InspectingSelector {
                expected_hits: None,
                selected_worker: WorkerWithDpRank::from_worker_id(0),
            },
            Some(Box::new(FakeSharedCache {
                hits: None,
                should_error: true,
            })),
        )
        .await;

        let (worker, overlap) = router
            .find_best_match(
                None,
                &[11, 12, 21, 22],
                None,
                None,
                false,
                None,
                0.0,
                None,
                None,
                None,
            )
            .await
            .unwrap();

        assert_eq!(worker, WorkerWithDpRank::from_worker_id(0));
        assert_eq!(overlap, 0);
    }
}