kv_router.rs 24.8 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

Yan Ru Pei's avatar
Yan Ru Pei committed
4
use std::sync::Arc;
5
use std::time::{Duration, Instant};
6

7
use anyhow::Result;
8
9
10
11
12
13
14
use dynamo_kv_router::{
    ConcurrentRadixTree, ThreadPoolIndexer,
    approx::PruneConfig,
    config::{KvRouterConfig, RouterConfigOverride},
    indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError},
    protocols::KV_EVENT_SUBJECT,
    protocols::{
15
16
17
        BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, OverlapScores, RouterEvent,
        RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
        compute_block_hash_for_seq,
18
19
    },
};
20
use dynamo_runtime::{
21
    component::{Client, Endpoint},
22
    discovery::DiscoveryQuery,
23
    pipeline::{
24
25
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
26
    },
27
    protocols::EndpointId,
28
    protocols::annotated::Annotated,
29
    traits::DistributedRuntimeProvider,
30
};
31
use futures::stream;
Yan Ru Pei's avatar
Yan Ru Pei committed
32
use tokio::sync::oneshot;
33
use tracing::Instrument;
34
use validator::Validate;
35

36
pub mod cache_control;
37
mod jetstream;
38
pub mod metrics;
39
pub mod prefill_router;
40
pub mod publisher;
41
pub mod push_router;
42
pub mod remote_indexer;
43
pub mod scheduler;
44
pub mod sequence;
45
pub mod subscriber;
46
pub mod worker_query;
47

48
pub use cache_control::{CacheControlClient, spawn_pin_prefix};
49
pub use prefill_router::PrefillRouter;
50
pub use push_router::{DirectRoutingRouter, KvPushRouter};
51

52
use crate::{
53
    discovery::RuntimeConfigWatch,
54
    kv_router::{
55
        remote_indexer::RemoteIndexer,
56
        scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
57
        sequence::{SequenceError, SequenceRequest},
58
    },
59
    local_model::runtime_config::ModelRuntimeConfig,
60
61
};

62
63
use std::collections::HashSet;

64
65
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
66
67
68
69
70
71
72
73
74
75

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject name used for metric publishing (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// Subjects used for inter-router communication.
/// Subject name for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject name for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
76

77
78
79
80
// Names used for radix tree snapshot storage.
/// Bucket name under which the radix tree snapshot is stored.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File (object) name of the radix tree snapshot.
pub const RADIX_STATE_FILE: &str = "radix-state";

81
82
83
/// Buffer size for the worker-local KV indexer query path.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer

84
85
86
87
88
89
/// Builds the endpoint name of the worker KV indexer query service for one `dp_rank`.
///
/// The result is the fixed prefix `worker_kv_indexer_query_dp` followed by the rank.
/// Every dp_rank gets its own LocalKvIndexer and query endpoint so that monotonicity
/// is guaranteed per dp_rank.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    let mut endpoint_name = String::from("worker_kv_indexer_query_dp");
    endpoint_name.push_str(&dp_rank.to_string());
    endpoint_name
}

90
// for router discovery registration
91
/// Endpoint name used for KV router discovery registration.
pub const KV_ROUTER_ENDPOINT: &str = "router-discovery";
92
93

/// Creates an EndpointId for the KV router in the given namespace.
94
pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId {
95
96
    EndpointId {
        namespace,
97
        component,
98
99
100
101
102
        name: KV_ROUTER_ENDPOINT.to_string(),
    }
}

/// Creates a DiscoveryQuery for the KV router in the given namespace.
103
pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery {
104
105
    DiscoveryQuery::Endpoint {
        namespace,
106
        component,
107
108
109
110
        endpoint: KV_ROUTER_ENDPOINT.to_string(),
    }
}

Yan Ru Pei's avatar
Yan Ru Pei committed
111
#[derive(Clone)]
112
pub enum Indexer {
Yan Ru Pei's avatar
Yan Ru Pei committed
113
    /// Single-threaded radix tree with channel-based event processing.
114
    /// Supports TTL-based expiration and size-based pruning.
115
    /// Has the ability to persist and snapshot states.
116
    KvIndexer(KvIndexer),
117

Yan Ru Pei's avatar
Yan Ru Pei committed
118
119
120
121
122
    /// Concurrent radix tree with a thread pool for event processing.
    /// Uses sticky worker routing for per-worker event serialization.
    /// Does not support TTL/pruning.
    Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),

123
124
125
126
    /// Forwards queries to a standalone KV indexer service via the request plane.
    /// The standalone indexer manages its own radix tree and event subscription.
    Remote(Arc<RemoteIndexer>),

127
128
129
    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
130
131
132
}

impl Indexer {
133
    pub async fn new(
134
135
136
        component: &dynamo_runtime::component::Component,
        kv_router_config: &KvRouterConfig,
        block_size: u32,
137
138
        model_name: Option<String>,
    ) -> Result<Self> {
139
        if kv_router_config.overlap_score_weight == 0.0 {
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
            return Ok(Indexer::None);
        }

        // Remote indexer: forward queries to a standalone KV indexer service.
        if let Some(ref indexer_component_name) = kv_router_config.remote_indexer_component {
            let model_name = model_name.ok_or_else(|| {
                anyhow::anyhow!(
                    "model_name is required when remote_indexer_component is configured"
                )
            })?;
            tracing::info!(
                remote_indexer_component = %indexer_component_name,
                model_name,
                "Using remote KV indexer"
            );
            let remote = RemoteIndexer::new(component, indexer_component_name, model_name).await?;
            return Ok(Indexer::Remote(Arc::new(remote)));
Yan Ru Pei's avatar
Yan Ru Pei committed
157
158
        }

159
160
161
162
        // Approximate mode (--no-kv-events): always use single-threaded KvIndexer
        // with TTL/pruning regardless of event_threads, since updates come from
        // routing decisions only, not live KV events from workers.
        if !kv_router_config.use_kv_events {
163
            let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
164
165
166
167
168
169
            let cancellation_token = component.drt().primary_token();
            let prune_config = Some(PruneConfig {
                ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                max_tree_size: kv_router_config.router_max_tree_size,
                prune_target_ratio: kv_router_config.router_prune_target_ratio,
            });
170
            return Ok(Indexer::KvIndexer(KvIndexer::new_with_frequency(
171
172
173
174
175
                cancellation_token,
                None,
                block_size,
                kv_indexer_metrics,
                prune_config,
176
            )));
177
178
        }

Yan Ru Pei's avatar
Yan Ru Pei committed
179
        if kv_router_config.router_event_threads > 1 {
180
            return Ok(Indexer::Concurrent(Arc::new(ThreadPoolIndexer::new(
Yan Ru Pei's avatar
Yan Ru Pei committed
181
182
                ConcurrentRadixTree::new(),
                kv_router_config.router_event_threads as usize,
183
                block_size,
184
            ))));
185
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
186

187
        let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
188
        let cancellation_token = component.drt().primary_token();
Yan Ru Pei's avatar
Yan Ru Pei committed
189

190
        Ok(Indexer::KvIndexer(KvIndexer::new_with_frequency(
Yan Ru Pei's avatar
Yan Ru Pei committed
191
192
193
194
            cancellation_token,
            None, // expiration_duration for frequency tracking
            block_size,
            kv_indexer_metrics,
195
            None,
196
        )))
197
198
199
    }

    pub(crate) async fn find_matches(
200
201
202
203
204
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Yan Ru Pei's avatar
Yan Ru Pei committed
205
            Indexer::Concurrent(tpi) => tpi.find_matches(sequence).await,
206
207
208
209
            Indexer::Remote(remote) => remote.find_matches(sequence).await.map_err(|e| {
                tracing::warn!(error = %e, "Remote indexer query failed");
                KvRouterError::IndexerOffline
            }),
210
            Indexer::None => Ok(OverlapScores::new()),
211
212
        }
    }
213

214
    pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
215
216
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Yan Ru Pei's avatar
Yan Ru Pei committed
217
            Indexer::Concurrent(tpi) => tpi.dump_events().await,
218
            Indexer::Remote(_) => Ok(Vec::new()),
219
220
221
222
223
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
224
225
        }
    }
226

227
    pub(crate) async fn process_routing_decision_for_request(
228
        &self,
229
        tokens_with_hashes: &mut TokensWithHashes,
230
231
232
233
234
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
235
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
236
237
                    .await
            }
Yan Ru Pei's avatar
Yan Ru Pei committed
238
239
240
241
            Indexer::Concurrent(tpi) => {
                tpi.process_routing_decision_for_request(tokens_with_hashes, worker)
                    .await
            }
242
            Indexer::Remote(_) => Ok(()),
243
244
245
            Indexer::None => Ok(()),
        }
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
246
247
248
249
250
251
252
253
254

    pub(crate) async fn apply_event(&self, event: RouterEvent) {
        match self {
            Indexer::KvIndexer(indexer) => {
                if let Err(e) = indexer.event_sender().send(event).await {
                    tracing::warn!("Failed to send event to indexer: {e}");
                }
            }
            Indexer::Concurrent(tpi) => tpi.apply_event(event).await,
255
            Indexer::Remote(_) => {} // standalone indexer gets events directly
Yan Ru Pei's avatar
Yan Ru Pei committed
256
257
258
259
260
261
262
263
264
265
266
267
268
269
            Indexer::None => {}
        }
    }

    pub(crate) async fn remove_worker(&self, worker_id: WorkerId) {
        match self {
            Indexer::KvIndexer(indexer) => {
                if let Err(e) = indexer.remove_worker_sender().send(worker_id).await {
                    tracing::warn!("Failed to send worker removal for {worker_id}: {e}");
                }
            }
            Indexer::Concurrent(tpi) => {
                KvIndexerInterface::remove_worker(tpi.as_ref(), worker_id).await;
            }
270
            Indexer::Remote(_) => {} // standalone indexer manages its own workers
Yan Ru Pei's avatar
Yan Ru Pei committed
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
            Indexer::None => {}
        }
    }

    pub(crate) async fn get_workers(&self) -> Vec<WorkerId> {
        match self {
            Indexer::KvIndexer(indexer) => {
                let (resp_tx, resp_rx) = oneshot::channel();
                let req = GetWorkersRequest { resp: resp_tx };
                if let Err(e) = indexer.get_workers_sender().send(req).await {
                    tracing::warn!("Failed to send get_workers request: {e}");
                    return Vec::new();
                }
                resp_rx.await.unwrap_or_default()
            }
            Indexer::Concurrent(tpi) => tpi.backend().get_workers(),
287
            Indexer::Remote(_) => Vec::new(),
Yan Ru Pei's avatar
Yan Ru Pei committed
288
289
290
            Indexer::None => Vec::new(),
        }
    }
291
292
}

293
294
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
295
296
297
298
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
299
    indexer: Indexer,
300
    scheduler: KvScheduler<Sel>,
301
    block_size: u32,
302
    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
303
    cancellation_token: tokio_util::sync::CancellationToken,
304
    client: Client,
305
    is_eagle: bool,
306
307
}

308
309
310
311
impl<Sel> KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
312
    #[allow(clippy::too_many_arguments)]
313
    pub async fn new(
314
315
        endpoint: Endpoint,
        client: Client,
316
        mut workers_with_configs: RuntimeConfigWatch,
317
        block_size: u32,
318
        selector: Sel,
319
        kv_router_config: Option<KvRouterConfig>,
320
        worker_type: &'static str,
321
        model_name: Option<String>,
322
        is_eagle: bool,
323
    ) -> Result<Self> {
324
        let kv_router_config = kv_router_config.unwrap_or_default();
325
        kv_router_config.validate()?;
326
        let component = endpoint.component();
327
        let cancellation_token = component.drt().primary_token();
328

329
        let indexer = Indexer::new(component, &kv_router_config, block_size, model_name).await?;
330

331
332
        if !kv_router_config.skip_initial_worker_wait {
            let _ = workers_with_configs
333
                .wait_for(|m| m.len() >= kv_router_config.min_initial_workers)
334
335
                .await
                .map_err(|_| {
336
337
338
339
                    anyhow::anyhow!(
                        "runtime config watch closed before {} workers appeared",
                        kv_router_config.min_initial_workers
                    )
340
341
                })?;
        }
342

343
        let scheduler = KvScheduler::start(
344
            component.clone(),
345
            block_size,
346
            workers_with_configs.clone(),
347
            selector,
348
            &kv_router_config,
349
            worker_type,
350
351
        )
        .await?;
352

353
354
355
356
357
        // Start KV event subscription if needed — skip when using a remote indexer
        // (the standalone indexer handles its own event subscription).
        if kv_router_config.remote_indexer_component.is_some() {
            tracing::info!("Skipping KV event subscription (using remote indexer)");
        } else if kv_router_config.should_subscribe_to_kv_events() {
358
359
            subscriber::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
                .await?;
360
        } else {
361
            tracing::info!(
362
363
364
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
365
            );
366
        }
367

368
        tracing::info!("KV Routing initialized");
369
        Ok(Self {
370
            indexer,
371
            scheduler,
372
            block_size,
373
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
374
            cancellation_token,
375
            client,
376
            is_eagle,
377
        })
378
379
    }

380
381
382
383
384
    /// Returns a reference to the [`Client`] this `KvRouter` was constructed with.
    pub fn client(&self) -> &Client {
        &self.client
    }

385
386
387
388
389
390
391
392
    /// Returns a reference to the [`Indexer`] backing this router.
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    /// Returns the [`KvRouterConfig`] in effect for this router.
    pub fn kv_router_config(&self) -> &KvRouterConfig {
        &self.kv_router_config
    }

393
394
395
396
    /// Returns the `is_eagle` flag this router was constructed with.
    ///
    /// The same flag is passed to block hashing as `BlockHashOptions::is_eagle`.
    pub fn is_eagle(&self) -> bool {
        self.is_eagle
    }

397
398
    pub async fn record_routing_decision(
        &self,
399
        mut tokens_with_hashes: TokensWithHashes,
400
401
402
403
404
405
406
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        self.indexer
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
            .await
    }

407
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
408
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
409
410
411
    /// Now also takes optional context_id for request tracking.
    ///
    /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
412
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
413
    pub async fn find_best_match(
414
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
415
        context_id: Option<&str>,
416
        tokens: &[u32],
417
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
418
        router_config_override: Option<&RouterConfigOverride>,
419
        update_states: bool,
420
        lora_name: Option<String>,
421
        priority_jump: f64,
422
        expected_output_tokens: Option<u32>,
423
        allowed_worker_ids: Option<HashSet<WorkerId>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
424
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
425
426
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
427
        if update_states && context_id.is_none() {
428
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
429
430
        }

431
        let isl_tokens = tokens.len();
432

433
434
435
436
        let block_hashes = tracing::info_span!("kv_router.compute_block_hashes").in_scope(|| {
            compute_block_hash_for_seq(
                tokens,
                self.block_size,
437
438
439
440
441
                BlockHashOptions {
                    block_mm_infos,
                    lora_name: lora_name.as_deref(),
                    is_eagle: Some(self.is_eagle),
                },
442
443
            )
        });
444
        let hash_elapsed = start.elapsed();
445

446
        let overlap_scores = self
447
448
449
450
            .indexer
            .find_matches(block_hashes)
            .instrument(tracing::info_span!("kv_router.find_matches"))
            .await?;
451
        let find_matches_elapsed = start.elapsed();
452

453
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
454
455
456
457
458
        let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
            self.kv_router_config.compute_seq_hashes_for_tracking(
                tokens,
                self.block_size,
                router_config_override,
459
460
461
462
463
                BlockHashOptions {
                    block_mm_infos,
                    lora_name: lora_name.as_deref(),
                    is_eagle: Some(self.is_eagle),
                },
464
465
            )
        });
466
        let seq_hash_elapsed = start.elapsed();
467

468
        let response = self
469
            .scheduler
470
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
471
                context_id.map(|s| s.to_string()),
472
                isl_tokens,
473
                maybe_seq_hashes,
474
                overlap_scores,
475
                router_config_override,
476
                update_states,
477
                lora_name,
478
                priority_jump,
479
                expected_output_tokens,
480
                allowed_worker_ids,
481
            )
482
            .instrument(tracing::info_span!("kv_router.schedule"))
483
            .await?;
484
485
        let total_elapsed = start.elapsed();

486
487
488
489
490
491
492
493
        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
            m.observe(
                hash_elapsed,
                find_matches_elapsed,
                seq_hash_elapsed,
                total_elapsed,
            );
        }
494

495
        #[cfg(feature = "bench")]
496
497
498
499
500
501
502
503
504
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
            find_matches_us = (find_matches_elapsed - hash_elapsed).as_micros() as u64,
            seq_hash_us = (seq_hash_elapsed - find_matches_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - seq_hash_elapsed).as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
505

506
        Ok((response.best_worker, response.overlap_blocks))
507
508
    }

509
510
511
512
513
    /// Register externally-provided workers in the slot tracker.
    ///
    /// Forwards `worker_ids` to the scheduler's `register_workers`.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        self.scheduler.register_workers(worker_ids);
    }

514
    #[allow(clippy::too_many_arguments)]
515
516
517
518
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
519
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
520
        overlap_blocks: u32,
521
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
522
        worker: WorkerWithDpRank,
523
        lora_name: Option<String>,
524
        router_config_override: Option<&RouterConfigOverride>,
525
526
    ) {
        let isl_tokens = tokens.len();
527

528
529
530
531
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
532
533
534
535
536
            BlockHashOptions {
                block_mm_infos,
                lora_name: lora_name.as_deref(),
                is_eagle: Some(self.is_eagle),
            },
537
        );
538

539
540
        if let Err(e) = self
            .scheduler
541
542
543
544
545
            .add_request(SequenceRequest {
                request_id: request_id.clone(),
                token_sequence: maybe_seq_hashes,
                isl: isl_tokens,
                overlap: overlap_blocks,
546
                expected_output_tokens,
Yan Ru Pei's avatar
Yan Ru Pei committed
547
                worker,
548
                lora_name,
549
            })
550
551
552
553
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
554
555
    }

556
    /// Marks the prefill phase of `request_id` as completed in the scheduler.
    ///
    /// Returns whatever the scheduler's `mark_prefill_completed` returns.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.mark_prefill_completed(request_id).await
    }

560
    /// Frees `request_id` in the scheduler.
    ///
    /// Returns whatever the scheduler's `free` returns.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.free(request_id).await
    }
563

564
565
566
567
568
    /// Number of requests currently parked in the scheduler queue.
    ///
    /// The value is read from the scheduler's `pending_count`.
    pub fn pending_count(&self) -> usize {
        self.scheduler.pending_count()
    }

569
570
571
572
573
574
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    ///
    /// The value is read from the scheduler's `worker_type`.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

575
    pub fn add_output_block(
576
577
578
579
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
580
        self.scheduler.add_output_block(request_id, decay_fraction)
581
582
    }

583
    pub fn block_size(&self) -> u32 {
584
585
        self.block_size
    }
586

587
588
589
590
591
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
592
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
593
        worker: WorkerWithDpRank,
594
        lora_name: Option<&str>,
595
    ) -> Result<u32, KvRouterError> {
596
597
598
599
600
601
602
603
604
        let block_hashes = compute_block_hash_for_seq(
            tokens,
            self.block_size,
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
        );
605
606
607
608
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

609
    /// Get potential prefill and decode loads for all workers
610
611
612
613
    pub async fn get_potential_loads(
        &self,
        tokens: &[u32],
        router_config_override: Option<&RouterConfigOverride>,
614
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
615
        lora_name: Option<&str>,
616
    ) -> Result<Vec<PotentialLoad>> {
617
        let isl_tokens = tokens.len();
618
619
620
621
622
623
624
625
626
        let block_hashes = compute_block_hash_for_seq(
            tokens,
            self.block_size,
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
        );
627
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
628

629
630
631
632
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
633
634
635
636
637
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
638
        );
639

640
641
        Ok(self
            .scheduler
642
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores))
643
644
    }

645
646
647
648
    /// Dump all events from the indexer.
    ///
    /// Thin wrapper around the indexer's `dump_events`; its result is returned unchanged.
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
649
650
}

Michael Feil's avatar
Michael Feil committed
651
652
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
653
#[async_trait]
654
655
656
657
658
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
    for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
659
660
661
662
663
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
664
665
666
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
667
668
669
670
            RouterRequest::New {
                tokens,
                block_mm_infos,
            } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
671
                let (best_worker, overlap_blocks) = self
672
673
674
675
676
677
678
679
                    .find_best_match(
                        Some(&context_id),
                        &tokens,
                        block_mm_infos.as_deref(),
                        None,
                        true,
                        None,
                        0.0,
680
                        None,
681
                        None,
682
                    )
Michael Feil's avatar
Michael Feil committed
683
684
685
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
686
687
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
688
689
690
                    overlap_blocks,
                }
            }
691
692
693
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
694
695
696
697
698
699
700
701
702
            RouterRequest::MarkFree { request_id } => {
                let request_id = match request_id.as_deref() {
                    Some(request_id) if !request_id.trim().is_empty() => request_id,
                    _ => &context_id,
                };
                RouterResponse::FreeMarked {
                    success: self.free(request_id).await.is_ok(),
                }
            }
Michael Feil's avatar
Michael Feil committed
703
        };
704
705
706
707
708
709

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
710

711
712
713
714
impl<Sel> Drop for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
Yan Ru Pei's avatar
Yan Ru Pei committed
715
716
717
718
719
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}