kv_router.rs 24.6 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

Yan Ru Pei's avatar
Yan Ru Pei committed
4
use std::sync::Arc;
5
use std::time::{Duration, Instant};
6

7
use anyhow::Result;
8
9
10
11
12
13
14
use dynamo_kv_router::{
    ConcurrentRadixTree, ThreadPoolIndexer,
    approx::PruneConfig,
    config::{KvRouterConfig, RouterConfigOverride},
    indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError},
    protocols::KV_EVENT_SUBJECT,
    protocols::{
15
16
17
        BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, OverlapScores, RouterEvent,
        RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
        compute_block_hash_for_seq,
18
19
    },
};
20
use dynamo_runtime::{
21
    component::{Client, Endpoint},
22
    discovery::DiscoveryQuery,
23
    pipeline::{
24
25
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
26
    },
27
    protocols::EndpointId,
28
    protocols::annotated::Annotated,
29
    traits::DistributedRuntimeProvider,
30
};
31
use futures::stream;
Yan Ru Pei's avatar
Yan Ru Pei committed
32
use tokio::sync::oneshot;
33
use tracing::Instrument;
34
use validator::Validate;
35

36
pub mod cache_control;
37
mod jetstream;
38
pub mod metrics;
39
pub mod prefill_router;
40
pub mod publisher;
41
pub mod push_router;
42
pub mod remote_indexer;
43
pub mod scheduler;
44
pub mod sequence;
45
pub mod subscriber;
46
pub mod worker_query;
47

48
pub use cache_control::{CacheControlClient, spawn_pin_prefix};
49
pub use prefill_router::PrefillRouter;
50
pub use push_router::{DirectRoutingRouter, KvPushRouter};
51

52
use crate::{
53
    discovery::RuntimeConfigWatch,
54
    kv_router::{
55
        remote_indexer::RemoteIndexer,
56
        scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
57
        sequence::{SequenceError, SequenceRequest},
58
    },
59
    local_model::runtime_config::ModelRuntimeConfig,
60
61
};

62
63
use std::collections::HashSet;

64
65
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
66
67
68
69
70
71
72
73
74
75

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject used for metric publishing (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for inter-router communication about prefill events.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for inter-router communication about active-sequence events.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Buffer size for the worker-local KV indexer query: the worker buffer stores
/// the 1024 most recent events.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

84
85
86
87
88
89
/// Builds the endpoint name of the worker KV indexer query service for one dp_rank.
///
/// Every dp_rank owns a separate LocalKvIndexer and query endpoint so that
/// monotonicity is preserved per dp_rank; the rank is appended to a fixed prefix.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    let rank_suffix = dp_rank.to_string();
    ["worker_kv_indexer_query_dp", rank_suffix.as_str()].concat()
}

90
/// Endpoint name used for router discovery registration.
pub const KV_ROUTER_ENDPOINT: &str = "router-discovery";
92
93

/// Creates an EndpointId for the KV router in the given namespace.
94
pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId {
95
96
    EndpointId {
        namespace,
97
        component,
98
99
100
101
102
        name: KV_ROUTER_ENDPOINT.to_string(),
    }
}

/// Creates a DiscoveryQuery for the KV router in the given namespace.
103
pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery {
104
105
    DiscoveryQuery::Endpoint {
        namespace,
106
        component,
107
108
109
110
        endpoint: KV_ROUTER_ENDPOINT.to_string(),
    }
}

Yan Ru Pei's avatar
Yan Ru Pei committed
111
#[derive(Clone)]
112
pub enum Indexer {
Yan Ru Pei's avatar
Yan Ru Pei committed
113
    /// Single-threaded radix tree with channel-based event processing.
114
    /// Supports TTL-based expiration and size-based pruning.
115
    /// Has the ability to persist and snapshot states.
116
    KvIndexer(KvIndexer),
117

Yan Ru Pei's avatar
Yan Ru Pei committed
118
119
120
121
122
    /// Concurrent radix tree with a thread pool for event processing.
    /// Uses sticky worker routing for per-worker event serialization.
    /// Does not support TTL/pruning.
    Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),

123
124
125
126
    /// Forwards queries to a standalone KV indexer service via the request plane.
    /// The standalone indexer manages its own radix tree and event subscription.
    Remote(Arc<RemoteIndexer>),

127
128
129
    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
130
131
132
}

impl Indexer {
133
    pub async fn new(
134
135
136
        component: &dynamo_runtime::component::Component,
        kv_router_config: &KvRouterConfig,
        block_size: u32,
137
138
        model_name: Option<String>,
    ) -> Result<Self> {
139
        if kv_router_config.overlap_score_weight == 0.0 {
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
            return Ok(Indexer::None);
        }

        // Remote indexer: forward queries to a standalone KV indexer service.
        if let Some(ref indexer_component_name) = kv_router_config.remote_indexer_component {
            let model_name = model_name.ok_or_else(|| {
                anyhow::anyhow!(
                    "model_name is required when remote_indexer_component is configured"
                )
            })?;
            tracing::info!(
                remote_indexer_component = %indexer_component_name,
                model_name,
                "Using remote KV indexer"
            );
            let remote = RemoteIndexer::new(component, indexer_component_name, model_name).await?;
            return Ok(Indexer::Remote(Arc::new(remote)));
Yan Ru Pei's avatar
Yan Ru Pei committed
157
158
        }

159
160
161
162
        // Approximate mode (--no-kv-events): always use single-threaded KvIndexer
        // with TTL/pruning regardless of event_threads, since updates come from
        // routing decisions only, not live KV events from workers.
        if !kv_router_config.use_kv_events {
163
            let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
164
165
166
167
168
169
            let cancellation_token = component.drt().primary_token();
            let prune_config = Some(PruneConfig {
                ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                max_tree_size: kv_router_config.router_max_tree_size,
                prune_target_ratio: kv_router_config.router_prune_target_ratio,
            });
170
            return Ok(Indexer::KvIndexer(KvIndexer::new_with_frequency(
171
172
173
174
175
                cancellation_token,
                None,
                block_size,
                kv_indexer_metrics,
                prune_config,
176
            )));
177
178
        }

Yan Ru Pei's avatar
Yan Ru Pei committed
179
        if kv_router_config.router_event_threads > 1 {
180
            return Ok(Indexer::Concurrent(Arc::new(ThreadPoolIndexer::new(
Yan Ru Pei's avatar
Yan Ru Pei committed
181
182
                ConcurrentRadixTree::new(),
                kv_router_config.router_event_threads as usize,
183
                block_size,
184
            ))));
185
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
186

187
        let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
188
        let cancellation_token = component.drt().primary_token();
Yan Ru Pei's avatar
Yan Ru Pei committed
189

190
        Ok(Indexer::KvIndexer(KvIndexer::new_with_frequency(
Yan Ru Pei's avatar
Yan Ru Pei committed
191
192
193
194
            cancellation_token,
            None, // expiration_duration for frequency tracking
            block_size,
            kv_indexer_metrics,
195
            None,
196
        )))
197
198
199
    }

    pub(crate) async fn find_matches(
200
201
202
203
204
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Yan Ru Pei's avatar
Yan Ru Pei committed
205
            Indexer::Concurrent(tpi) => tpi.find_matches(sequence).await,
206
207
208
209
            Indexer::Remote(remote) => remote.find_matches(sequence).await.map_err(|e| {
                tracing::warn!(error = %e, "Remote indexer query failed");
                KvRouterError::IndexerOffline
            }),
210
            Indexer::None => Ok(OverlapScores::new()),
211
212
        }
    }
213

214
    /// Dumps the events held by the active backend.
    ///
    /// The remote variant holds no local events and yields an empty vector.
    ///
    /// # Panics
    ///
    /// Panics when called on [`Indexer::None`], because no indexer exists to
    /// dump from (this happens when `overlap_score_weight` is 0).
    pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
            Indexer::Concurrent(tpi) => tpi.dump_events().await,
            Indexer::Remote(_) => Ok(Vec::new()),
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
        }
    }
226

227
    pub(crate) async fn process_routing_decision_for_request(
228
        &self,
229
        tokens_with_hashes: &mut TokensWithHashes,
230
231
232
233
234
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
235
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
236
237
                    .await
            }
Yan Ru Pei's avatar
Yan Ru Pei committed
238
239
240
241
            Indexer::Concurrent(tpi) => {
                tpi.process_routing_decision_for_request(tokens_with_hashes, worker)
                    .await
            }
242
            Indexer::Remote(_) => Ok(()),
243
244
245
            Indexer::None => Ok(()),
        }
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
246
247
248
249
250
251
252
253
254

    /// Feeds a single router event into the active backend.
    ///
    /// Best-effort: a failed send to the `KvIndexer` event channel is logged as
    /// a warning and otherwise ignored. `Remote` and `None` drop the event.
    pub(crate) async fn apply_event(&self, event: RouterEvent) {
        match self {
            Indexer::KvIndexer(indexer) => {
                if let Err(e) = indexer.event_sender().send(event).await {
                    tracing::warn!("Failed to send event to indexer: {e}");
                }
            }
            Indexer::Concurrent(tpi) => tpi.apply_event(event).await,
            Indexer::Remote(_) => {} // standalone indexer gets events directly
            Indexer::None => {}
        }
    }

    /// Asks the active backend to forget a worker.
    ///
    /// Best-effort: a failed send to the `KvIndexer` removal channel is logged
    /// as a warning and otherwise ignored. `Remote` and `None` do nothing.
    pub(crate) async fn remove_worker(&self, worker_id: WorkerId) {
        match self {
            Indexer::KvIndexer(indexer) => {
                if let Err(e) = indexer.remove_worker_sender().send(worker_id).await {
                    tracing::warn!("Failed to send worker removal for {worker_id}: {e}");
                }
            }
            Indexer::Concurrent(tpi) => {
                // Fully qualified so the trait method is called on the inner
                // indexer rather than on the `Arc` wrapper.
                KvIndexerInterface::remove_worker(tpi.as_ref(), worker_id).await;
            }
            Indexer::Remote(_) => {} // standalone indexer manages its own workers
            Indexer::None => {}
        }
    }

    /// Returns the worker ids currently known to the active backend.
    ///
    /// For `KvIndexer` the list is requested over a channel and answered via a
    /// oneshot. If the request cannot be sent (logged as a warning) or the
    /// reply channel is dropped, an empty vector is returned. `Remote` and
    /// `None` always return an empty vector.
    pub(crate) async fn get_workers(&self) -> Vec<WorkerId> {
        match self {
            Indexer::KvIndexer(indexer) => {
                let (resp_tx, resp_rx) = oneshot::channel();
                let req = GetWorkersRequest { resp: resp_tx };
                if let Err(e) = indexer.get_workers_sender().send(req).await {
                    tracing::warn!("Failed to send get_workers request: {e}");
                    return Vec::new();
                }
                resp_rx.await.unwrap_or_default()
            }
            Indexer::Concurrent(tpi) => tpi.backend().get_workers(),
            Indexer::Remote(_) => Vec::new(),
            Indexer::None => Vec::new(),
        }
    }
291
292
}

293
294
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
295
296
297
298
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
299
    indexer: Indexer,
300
    scheduler: KvScheduler<Sel>,
301
    block_size: u32,
302
    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
303
    cancellation_token: tokio_util::sync::CancellationToken,
304
    client: Client,
305
    is_eagle: bool,
306
307
}

308
309
310
311
impl<Sel> KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
312
    #[allow(clippy::too_many_arguments)]
313
    pub async fn new(
314
315
        endpoint: Endpoint,
        client: Client,
316
        mut workers_with_configs: RuntimeConfigWatch,
317
        block_size: u32,
318
        selector: Sel,
319
        kv_router_config: Option<KvRouterConfig>,
320
        worker_type: &'static str,
321
        model_name: Option<String>,
322
        is_eagle: bool,
323
    ) -> Result<Self> {
324
        let kv_router_config = kv_router_config.unwrap_or_default();
325
        kv_router_config.validate()?;
326
        let component = endpoint.component();
327
        let cancellation_token = component.drt().primary_token();
328

329
        let indexer = Indexer::new(component, &kv_router_config, block_size, model_name).await?;
330

331
332
        if !kv_router_config.skip_initial_worker_wait {
            let _ = workers_with_configs
333
                .wait_for(|m| m.len() >= kv_router_config.min_initial_workers)
334
335
                .await
                .map_err(|_| {
336
337
338
339
                    anyhow::anyhow!(
                        "runtime config watch closed before {} workers appeared",
                        kv_router_config.min_initial_workers
                    )
340
341
                })?;
        }
342

343
        let scheduler = KvScheduler::start(
344
            component.clone(),
345
            block_size,
346
            workers_with_configs.clone(),
347
            selector,
348
            &kv_router_config,
349
            worker_type,
350
351
        )
        .await?;
352

353
354
355
356
357
        // Start KV event subscription if needed — skip when using a remote indexer
        // (the standalone indexer handles its own event subscription).
        if kv_router_config.remote_indexer_component.is_some() {
            tracing::info!("Skipping KV event subscription (using remote indexer)");
        } else if kv_router_config.should_subscribe_to_kv_events() {
358
359
            subscriber::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
                .await?;
360
        } else {
361
            tracing::info!(
362
363
364
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
365
            );
366
        }
367

368
        tracing::info!("KV Routing initialized");
369
        Ok(Self {
370
            indexer,
371
            scheduler,
372
            block_size,
373
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
374
            cancellation_token,
375
            client,
376
            is_eagle,
377
        })
378
379
    }

380
381
382
383
384
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

385
386
387
388
389
390
391
392
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    pub fn kv_router_config(&self) -> &KvRouterConfig {
        &self.kv_router_config
    }

393
394
395
396
    pub fn is_eagle(&self) -> bool {
        self.is_eagle
    }

397
398
    pub async fn record_routing_decision(
        &self,
399
        mut tokens_with_hashes: TokensWithHashes,
400
401
402
403
404
405
406
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        self.indexer
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
            .await
    }

407
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
408
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
409
410
411
    /// Now also takes optional context_id for request tracking.
    ///
    /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
412
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
413
    pub async fn find_best_match(
414
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
415
        context_id: Option<&str>,
416
        tokens: &[u32],
417
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
418
        router_config_override: Option<&RouterConfigOverride>,
419
        update_states: bool,
420
        lora_name: Option<String>,
421
        priority_jump: f64,
422
        expected_output_tokens: Option<u32>,
423
        allowed_worker_ids: Option<HashSet<WorkerId>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
424
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
425
426
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
427
        if update_states && context_id.is_none() {
428
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
429
430
        }

431
        let isl_tokens = tokens.len();
432
433
434
435
436
437
438
439
440
441
442
443
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };

        let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
            .in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options));
        let hash_elapsed = start.elapsed();
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
            self.kv_router_config.compute_seq_hashes_for_tracking(
444
445
                tokens,
                self.block_size,
446
447
448
                router_config_override,
                hash_options,
                Some(&block_hashes),
449
450
            )
        });
451
        let seq_hash_elapsed = start.elapsed();
452

453
        let overlap_scores = self
454
455
456
457
            .indexer
            .find_matches(block_hashes)
            .instrument(tracing::info_span!("kv_router.find_matches"))
            .await?;
458
        let find_matches_elapsed = start.elapsed();
459

460
        let response = self
461
            .scheduler
462
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
463
                context_id.map(|s| s.to_string()),
464
                isl_tokens,
465
                maybe_seq_hashes,
466
                overlap_scores,
467
                router_config_override,
468
                update_states,
469
                lora_name,
470
                priority_jump,
471
                expected_output_tokens,
472
                allowed_worker_ids,
473
            )
474
            .instrument(tracing::info_span!("kv_router.schedule"))
475
            .await?;
476
477
        let total_elapsed = start.elapsed();

478
479
480
481
        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
            m.observe(
                hash_elapsed,
                seq_hash_elapsed,
482
                find_matches_elapsed,
483
484
485
                total_elapsed,
            );
        }
486

487
        #[cfg(feature = "bench")]
488
489
490
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
491
492
493
            seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64,
            find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
494
495
496
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
497

498
        Ok((response.best_worker, response.overlap_blocks))
499
500
    }

501
502
503
504
505
    /// Register externally-provided workers in the slot tracker.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        self.scheduler.register_workers(worker_ids);
    }

506
    #[allow(clippy::too_many_arguments)]
507
508
509
510
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
511
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
512
        overlap_blocks: u32,
513
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
514
        worker: WorkerWithDpRank,
515
        lora_name: Option<String>,
516
        router_config_override: Option<&RouterConfigOverride>,
517
518
    ) {
        let isl_tokens = tokens.len();
519
520
521
522
523
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };
524

525
526
527
528
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
529
530
            hash_options,
            None,
531
        );
532

533
534
        if let Err(e) = self
            .scheduler
535
536
537
538
539
            .add_request(SequenceRequest {
                request_id: request_id.clone(),
                token_sequence: maybe_seq_hashes,
                isl: isl_tokens,
                overlap: overlap_blocks,
540
                expected_output_tokens,
Yan Ru Pei's avatar
Yan Ru Pei committed
541
                worker,
542
                lora_name,
543
            })
544
545
546
547
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
548
549
    }

550
    /// Tells the scheduler that prefill has completed for `request_id`.
    ///
    /// # Errors
    ///
    /// Propagates the scheduler's [`SequenceError`].
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.mark_prefill_completed(request_id).await
    }

554
    /// Tells the scheduler to free the state tracked for `request_id`.
    ///
    /// # Errors
    ///
    /// Propagates the scheduler's [`SequenceError`].
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.free(request_id).await
    }
557

558
559
560
561
562
    /// Number of requests currently parked in the scheduler queue.
    pub fn pending_count(&self) -> usize {
        self.scheduler.pending_count()
    }

563
564
565
566
567
568
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

569
    pub fn add_output_block(
570
571
572
573
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
574
        self.scheduler.add_output_block(request_id, decay_fraction)
575
576
    }

577
    pub fn block_size(&self) -> u32 {
578
579
        self.block_size
    }
580

581
582
583
584
585
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
586
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
587
        worker: WorkerWithDpRank,
588
        lora_name: Option<&str>,
589
    ) -> Result<u32, KvRouterError> {
590
591
592
593
594
595
596
597
598
        let block_hashes = compute_block_hash_for_seq(
            tokens,
            self.block_size,
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
        );
599
600
601
602
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

603
    /// Get potential prefill and decode loads for all workers
604
605
606
607
    pub async fn get_potential_loads(
        &self,
        tokens: &[u32],
        router_config_override: Option<&RouterConfigOverride>,
608
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
609
        lora_name: Option<&str>,
610
    ) -> Result<Vec<PotentialLoad>> {
611
        let isl_tokens = tokens.len();
612
613
614
615
616
617
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name,
            is_eagle: Some(self.is_eagle),
        };
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options);
618

619
620
621
622
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
623
624
            hash_options,
            Some(&block_hashes),
625
        );
626

627
628
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

629
630
        Ok(self
            .scheduler
631
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores))
632
633
    }

634
635
636
637
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
638
639
}

Michael Feil's avatar
Michael Feil committed
640
641
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
642
#[async_trait]
643
644
645
646
647
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
    for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
648
649
650
651
652
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
653
654
655
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
656
657
658
659
            RouterRequest::New {
                tokens,
                block_mm_infos,
            } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
660
                let (best_worker, overlap_blocks) = self
661
662
663
664
665
666
667
668
                    .find_best_match(
                        Some(&context_id),
                        &tokens,
                        block_mm_infos.as_deref(),
                        None,
                        true,
                        None,
                        0.0,
669
                        None,
670
                        None,
671
                    )
Michael Feil's avatar
Michael Feil committed
672
673
674
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
675
676
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
677
678
679
                    overlap_blocks,
                }
            }
680
681
682
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
683
684
685
686
687
688
689
690
691
            RouterRequest::MarkFree { request_id } => {
                let request_id = match request_id.as_deref() {
                    Some(request_id) if !request_id.trim().is_empty() => request_id,
                    _ => &context_id,
                };
                RouterResponse::FreeMarked {
                    success: self.free(request_id).await.is_ok(),
                }
            }
Michael Feil's avatar
Michael Feil committed
692
        };
693
694
695
696
697
698

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
699

700
701
702
703
impl<Sel> Drop for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
Yan Ru Pei's avatar
Yan Ru Pei committed
704
705
706
707
708
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}