kv_router.rs 33.5 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
21
};
use futures::stream::{self, StreamExt};
22
use serde::{Deserialize, Serialize};
23
use serde_json::json;
24

25
26
use crate::protocols::openai::nvext::WorkerIdInfo;

27
pub mod approx;
28
pub mod indexer;
29
pub mod prefill_router;
30
31
pub mod protocols;
pub mod publisher;
32
pub mod recorder;
33
pub mod scheduler;
34
pub mod sequence;
35
pub mod subscriber;
36
pub mod worker_query;
37

38
use indexer::WorkerKvQueryResponse;
39
pub use prefill_router::PrefillRouter;
40
use worker_query::WorkerQueryClient;
41

42
43
use crate::{
    kv_router::{
44
        approx::PruneConfig,
45
        indexer::{
46
47
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
48
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
49
        protocols::{
50
51
            LocalBlockHash, RouterRequest, RouterResponse, WorkerId, WorkerSelectionResult,
            WorkerWithDpRank,
Yan Ru Pei's avatar
Yan Ru Pei committed
52
        },
53
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
54
        sequence::SequenceError,
55
        subscriber::{start_kv_router_background, start_kv_router_background_nats_core},
56
    },
57
    local_model::runtime_config::ModelRuntimeConfig,
58
    model_card::ModelDeploymentCard,
59
    preprocessor::PreprocessedRequest,
60
    protocols::common::llm_backend::LLMEngineOutput,
61
    tokens::SequenceHash,
62
63
};

64
65
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
66
67
68
69
70

// for metric scraping (pull-based)
/// Endpoint name used for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

// for metric publishing (push-based)
71
/// Subject name for KV events (part of the push-based publishing group).
pub const KV_EVENT_SUBJECT: &str = "kv-events";
72
/// Subject name for KV hit-rate publishing (push-based); consumers are outside this file's view.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
73
74
75
76
77
/// Subject name for pushed KV metrics. Note it uses an underscore, unlike the
/// hyphenated `kv-events` / `kv-hit-rate` subjects.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// for inter-router comms
/// Subject name for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject name for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
78

79
80
81
82
// for radix tree snapshot storage
/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File (object) name used for the radix tree snapshot.
pub const RADIX_STATE_FILE: &str = "radix-state";

83
84
85
86
// for worker-local kvindexer query
/// Subject name for querying a worker's local KV indexer.
pub const WORKER_KV_INDEXER_QUERY_SUBJECT: &str = "worker_kv_indexer_query";
/// Number of most recent events retained in a worker's local buffer.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer

87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
// for router discovery registration
/// Component name under which the KV router is registered for discovery
/// (see `router_endpoint_id` / `router_discovery_query`).
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name under which the KV router is registered for discovery.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] that addresses the KV router inside `namespace`.
///
/// The component and endpoint name are fixed to [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`]; only the namespace is caller-supplied.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);

    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the endpoint-level [`DiscoveryQuery`] that locates the KV router inside `namespace`.
///
/// The component and endpoint are fixed to [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`]; only the namespace is caller-supplied.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);

    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

109
110
111
112
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
113
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
114
        request: &SchedulingRequest,
115
        block_size: u32,
116
117
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
118

119
120
121
122
123
124
125
126
127
128
/// Override configuration for router settings that can be specified per-request
///
/// Every field is optional. NOTE(review): presumably `None` means "fall back to the
/// router-wide value in `KvRouterConfig`" — the consuming logic lives in the scheduler,
/// outside this view.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request value for the overlap score weight, if overridden.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request value for the router temperature, if overridden.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

129
/// KV Router configuration parameters
130
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
131
132
133
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

134
    pub router_temperature: f64,
135

136
137
    pub use_kv_events: bool,

138
139
    pub router_replica_sync: bool,

140
141
142
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

143
144
145
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

146
    /// Whether to reset the router state on startup (default: false)
147
    pub router_reset_states: bool,
148
149
150
151
152
153
154
155
156

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024)
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
157
158
159
160
161
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
162
            overlap_score_weight: 1.0,
163
            router_temperature: 0.0,
164
            use_kv_events: true,
165
            router_replica_sync: false,
166
            router_track_active_blocks: true,
167
            router_snapshot_threshold: Some(1000000),
168
            router_reset_states: false,
169
170
171
            router_ttl_secs: 120.0,
            router_max_tree_size: 1024,
            router_prune_target_ratio: 0.8,
172
173
174
175
176
177
178
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
179
    #[allow(clippy::too_many_arguments)]
180
181
    pub fn new(
        overlap_score_weight: Option<f64>,
182
        temperature: Option<f64>,
183
        use_kv_events: Option<bool>,
184
        replica_sync: Option<bool>,
185
        track_active_blocks: Option<bool>,
186
187
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
188
189
190
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
191
192
193
194
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
195
            router_temperature: temperature.unwrap_or(default.router_temperature),
196
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
197
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
198
199
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
200
201
202
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
203
204
205
206
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
207
208
209
210
        }
    }
}

211
pub enum Indexer {
212
213
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
214
    /// Has the ability to persist and snapshot states.
215
    KvIndexer(KvIndexer),
216
217
218
219

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
220
221
222
223
224
225
226
227
228
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
229
230
231
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
232
                tree_sizes: HashMap::new(),
233
            }),
234
235
        }
    }
236
237
238
239

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
240
241
242
243
244
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
245
246
        }
    }
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262

    async fn process_routing_decision(
        &self,
        worker: WorkerWithDpRank,
        local_hashes: Vec<LocalBlockHash>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
                    .process_routing_decision(worker, local_hashes, sequence_hashes)
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
263
264
}

265
266
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
267
pub struct KvRouter {
268
269
270
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
271
    scheduler: KvScheduler,
272

273
    block_size: u32,
274
275

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
276
277

    cancellation_token: tokio_util::sync::CancellationToken,
278
279

    client: Client,
280
281

    worker_query_client: Option<WorkerQueryClient>,
282
283
284
285
}

impl KvRouter {
    pub async fn new(
286
287
        endpoint: Endpoint,
        client: Client,
288
        block_size: u32,
289
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
290
        kv_router_config: Option<KvRouterConfig>,
291
        consumer_id: String,
292
    ) -> Result<Self> {
293
        let kv_router_config = kv_router_config.unwrap_or_default();
294
        let component = endpoint.component();
295
        let cancellation_token = component.drt().primary_token();
296

297
        let instance_ids_rx = client.instance_avail_watcher();
298

299
300
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
301
        let endpoint_id = endpoint.id();
302
        let discovery_key = DiscoveryQuery::EndpointModels {
303
304
305
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
306
307
        };
        let discovery_stream = discovery
308
            .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
309
310
311
312
313
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
314

315
316
317
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
318
        } else {
319
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
320
321
322
323
324
325
326
327
328
329
330
331
332

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
333
                cancellation_token.clone(),
334
                None, // expiration_duration for frequency tracking
335
336
                block_size,
                kv_indexer_metrics,
337
                prune_config,
338
339
            ))
        };
340

341
        let scheduler = KvScheduler::start(
342
            component.clone(),
343
            block_size,
344
            instance_ids_rx,
345
            runtime_configs_rx.clone(),
346
            selector,
347
            kv_router_config.router_replica_sync,
348
            consumer_id.clone(),
349
350
        )
        .await?;
351

352
353
354
355
356
357
        // Initialize worker query client using namespace abstraction
        // (created before background task so we can use it for startup recovery)
        let worker_query_client =
            worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
        tracing::info!("Worker query client initialized");

358
        // Start KV event subscriber background process (only when use_kv_events is enabled)
359
360
        // This is spawned as a background task to avoid blocking router startup.
        // The task waits for runtime_configs to determine whether to use NATS Core or JetStream.
361
362
363
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
            // Clone everything needed for the background task
            let component_clone = component.clone();
            let kv_indexer_clone = kv_indexer.clone();
            let cancellation_token_clone = cancellation_token.clone();
            let mut runtime_configs_rx_clone = runtime_configs_rx.clone();
            let worker_query_client_clone =
                worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());

            tokio::spawn(async move {
                // Wait for runtime_configs to have at least one entry
                let (all_local_indexer, count) = loop {
                    {
                        let configs = runtime_configs_rx_clone.borrow();
                        if !configs.is_empty() {
                            let all_local_indexer =
                                configs.values().all(|c| c.enable_local_indexer);
                            break (all_local_indexer, configs.len());
                        }
                    }

                    // Wait for changes to runtime_configs
                    tokio::select! {
                        _ = cancellation_token_clone.cancelled() => {
                            tracing::debug!("Subscriber selection task cancelled");
                            return;
                        }
                        result = runtime_configs_rx_clone.changed() => {
                            if result.is_err() {
                                tracing::debug!("Runtime configs channel closed");
                                return;
                            }
                        }
                    }
                };

                if all_local_indexer {
                    // All workers have local_indexer enabled - use NATS Core
401
                    tracing::info!(
402
                        "All {count} workers have local_indexer enabled, using NATS Core subscription"
403
                    );
404
405
406
407
408
409
410
411
412
413
414
415

                    if let Err(e) = start_kv_router_background_nats_core(
                        component_clone.clone(),
                        kv_indexer_clone.event_sender(),
                        kv_indexer_clone.remove_worker_sender(),
                        cancellation_token_clone.clone(),
                        worker_query_client_clone,
                    )
                    .await
                    {
                        tracing::error!("Failed to start NATS Core subscriber: {e}");
                    }
416
                } else {
417
                    // Not all workers have local_indexer - use JetStream
418
                    tracing::info!(
419
                        "Not all workers have local_indexer enabled, using JetStream subscription"
420
                    );
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440

                    if let Err(e) = start_kv_router_background(
                        component_clone.clone(),
                        consumer_id,
                        kv_indexer_clone.event_sender(),
                        kv_indexer_clone.remove_worker_sender(),
                        kv_router_config
                            .router_snapshot_threshold
                            .map(|_| kv_indexer_clone.get_workers_sender()),
                        kv_router_config
                            .router_snapshot_threshold
                            .map(|_| kv_indexer_clone.snapshot_event_sender()),
                        cancellation_token_clone.clone(),
                        kv_router_config.router_snapshot_threshold,
                        kv_router_config.router_reset_states,
                    )
                    .await
                    {
                        tracing::error!("Failed to start JetStream subscriber: {e}");
                    }
441
                }
442
            });
443
        }
444

445
        tracing::info!("KV Routing initialized");
446
        Ok(Self {
447
            indexer,
448
            scheduler,
449
            block_size,
450
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
451
            cancellation_token,
452
            client,
453
            worker_query_client: Some(worker_query_client),
454
        })
455
456
    }

457
458
459
460
461
    /// Get a reference to the client used by this KvRouter
    ///
    /// This is the same `Client` that was passed to `KvRouter::new`.
    pub fn client(&self) -> &Client {
        &self.client
    }

462
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
463
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
464
465
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
466
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
467
        context_id: Option<&str>,
468
        tokens: &[u32],
469
        router_config_override: Option<&RouterConfigOverride>,
470
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
471
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
472
473
474
475
476
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

477
        let isl_tokens = tokens.len();
478

479
480
481
482
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
483

484
        // Determine who needs seq_hashes
485
        let needs_process_routing = !self.kv_router_config.use_kv_events;
486
487
488
489
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
490
            match (needs_process_routing, scheduler_needs_it) {
491
492
493
494
495
496
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
497
        let best_worker = self
498
            .scheduler
499
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
500
                context_id.map(|s| s.to_string()),
501
                isl_tokens,
502
                maybe_seq_hashes_2,
503
                overlap_scores.clone(),
504
                router_config_override,
505
                update_states,
506
            )
507
            .await?;
508

509
510
511
        // Process routing decision when not using KV events (approximate mode with TTL/pruning)
        if needs_process_routing {
            self.indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
512
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
513
514
                .await?;
        }
515

516
517
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
518
            .get(&best_worker)
519
520
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
521
        Ok((best_worker, overlap_amount))
522
523
    }

524
525
526
527
528
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
529
        worker: WorkerWithDpRank,
530
531
    ) {
        let isl_tokens = tokens.len();
532
533
534
535
536

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
537

538
539
        if let Err(e) = self
            .scheduler
540
            .add_request(
541
                request_id.clone(),
542
                maybe_seq_hashes,
543
544
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
545
                worker,
546
            )
547
548
549
550
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
551
552
    }

553
    /// Tells the scheduler that prefill has completed for `request_id`.
    ///
    /// Thin delegate; any [`SequenceError`] from the scheduler is returned unchanged.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.mark_prefill_completed(request_id).await
    }

557
    /// Tells the scheduler to free `request_id`.
    ///
    /// Thin delegate; any [`SequenceError`] from the scheduler is returned unchanged.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.free(request_id).await
    }
560

561
    pub fn block_size(&self) -> u32 {
562
563
        self.block_size
    }
564

565
566
567
568
569
570
571
572
573
    /// Get the disaggregated endpoint for a worker, if available.
    /// Used to look up bootstrap host/port for prefill workers.
    ///
    /// Delegates to the scheduler; returns `None` when it has no endpoint for `worker_id`.
    pub async fn get_disaggregated_endpoint(
        &self,
        worker_id: u64,
    ) -> Option<crate::local_model::runtime_config::DisaggregatedEndpoint> {
        self.scheduler.get_disaggregated_endpoint(worker_id).await
    }

574
575
576
577
578
579
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

580
581
582
583
584
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

585
586
        Ok(self
            .scheduler
587
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
588
589
590
            .await)
    }

591
592
593
594
    /// Dump all events from the indexer
    ///
    /// # Panics
    ///
    /// Panics if the router has no indexer (`Indexer::None`, which `KvRouter::new`
    /// selects when `overlap_score_weight` is 0).
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650

    /// Query a specific worker's local KV indexer for its events
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;

        query_client
            .query_worker(worker_id, start_event_id, end_event_id)
            .await
    }

    /// Recover missed KV events from a specific worker.
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

        subscriber::recover_from_worker(
            query_client,
            worker_id,
            start_event_id,
            end_event_id,
            &event_tx,
        )
        .await
    }
651
652
}

Michael Feil's avatar
Michael Feil committed
653
654
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
655
656
657
658
659
660
661
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
662
663
664
665
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
666
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
667
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
668
669
670
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
671
672
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
673
674
675
                    overlap_blocks,
                }
            }
676
677
678
679
680
681
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
682
        };
683
684
685
686
687
688

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
689
690

pub struct KvPushRouter {
691
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
692
    pub chooser: Arc<KvRouter>,
693
694
695
696
}

impl KvPushRouter {
    pub fn new(
697
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
698
699
700
701
702
703
704
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
705
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
706
707
    for KvPushRouter
{
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
727
728
    async fn generate(
        &self,
729
        request: SingleIn<PreprocessedRequest>,
730
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

        // Check if this is a query_instance_id request first
        let query_instance_id = request.has_annotation("query_instance_id");

        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
            // If instance_id is set, use it and compute actual overlap
            let dp_rank = request.dp_rank.unwrap_or(0);
            if query_instance_id {
                tracing::debug!(
                    "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                );
            }

            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            self.chooser
                .add_request(
                    context_id.clone(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
            (id, dp_rank, overlap_blocks)
        } else {
            // Otherwise, find the best match
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !query_instance_id, // Don't update states if query_instance_id
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };

        // if request has the annotation "query_instance_id",
        // then the request will not be routed to the worker,
        // and instead the worker_instance_id will be returned.
        let stream_context = request.context().clone();
        if query_instance_id {
            let instance_id_str = instance_id.to_string();
            let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;

            // Return the tokens in nvext.token_data format
            let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
            tracing::trace!(
                "Tokens requested in the response through the query_instance_id annotation: {:?}",
                response_tokens
            );
            let stream = stream::iter(vec![response, response_tokens]);
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
        let (mut backend_input, context) = request.into_parts();
        backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
        backend_input.dp_rank = Some(dp_rank);
796

797
798
        // Get prefill worker ID from prefill_result if available
        // In aggregated mode, prefill_result is None, so we use decode_worker_id for both
799
        let decode_worker_id = instance_id;
800
801
802
803
804
805
806
807
808
809
        let prefill_worker_id = backend_input
            .prefill_result
            .as_ref()
            .and_then(|prefill_result| {
                prefill_result
                    .disaggregated_params
                    .get("worker_id")
                    .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
                    .and_then(|info| info.prefill_worker_id)
            })
810
811
            .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker

812
813
814
815
816
817
818
819
820
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;
821
            let mut first_item = true;
822
823
824
825
826
827
828
829

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
830
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
831

832
                    item = response_stream.next() => {
833
                        let Some(mut item) = item else {
834
835
                            break;
                        };
836

837
838
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
839
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
840
                            }
841
                            prefill_marked = true;
842
                        }
843

844
845
846
847
848
                        // Always inject worker_id in first item's disaggregated_params
                        // This is needed for:
                        // 1. PrefillRouter to know which prefill worker was chosen
                        // 2. Client response when extra_fields contains "worker_id"
                        if first_item {
849
                            first_item = false;
850
851
852
853
854
855

                            let Some(ref mut data) = item.data else {
                                yield item;
                                continue;
                            };

856
                            // prefill_worker_id comes from prefill_result.disaggregated_params or falls back to instance_id
857
                            // decode_worker_id is always the current instance_id
858
859
860
861
862
863
                            let worker_id_info = WorkerIdInfo {
                                prefill_worker_id,
                                decode_worker_id: Some(decode_worker_id),
                            };
                            let worker_id_json = serde_json::to_value(&worker_id_info)
                                .expect("WorkerIdInfo serialization should not fail");
864
865
866
867
868
869

                            if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
                                obj.insert("worker_id".to_string(), worker_id_json);
                            } else {
                                data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
                            }
870
                        }
871
872

                        yield item;
873
                    }
874
875
                }
            }
876

877
            if let Err(e) = chooser.free(&context_id).await {
878
                tracing::warn!("Failed to free request {context_id}: {e}");
879
            }
880
881
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
882
883
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
884
885
886
887
888
889
890

/// Cancels the router's cancellation token when the router is dropped.
impl Drop for KvRouter {
    /// Logs the teardown and then triggers `cancellation_token`.
    fn drop(&mut self) {
        let token = &self.cancellation_token;
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        token.cancel();
    }
}