kv_router.rs 35.8 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
21
};
use futures::stream::{self, StreamExt};
22
use rand::Rng;
23
use serde::{Deserialize, Serialize};
24
use serde_json::json;
25

26
pub mod approx;
27
pub mod indexer;
28
pub mod prefill_router;
29
30
pub mod protocols;
pub mod publisher;
31
pub mod recorder;
32
pub mod scheduler;
33
pub mod sequence;
34
pub mod subscriber;
35
pub mod worker_query;
36

37
use indexer::WorkerKvQueryResponse;
38
pub use prefill_router::PrefillRouter;
39
use worker_query::WorkerQueryClient;
40

41
use crate::{
42
    discovery::RuntimeConfigsWithNotify,
43
    kv_router::{
44
        approx::PruneConfig,
45
        indexer::{KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent},
Yan Ru Pei's avatar
Yan Ru Pei committed
46
        protocols::{
47
48
49
            LocalBlockHash, RouterRequest, RouterResponse, TokensWithHashes, WorkerId,
            WorkerSelectionResult, WorkerWithDpRank, compute_block_hash_for_seq,
            compute_seq_hash_for_block,
Yan Ru Pei's avatar
Yan Ru Pei committed
50
        },
51
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
52
        sequence::SequenceError,
53
        subscriber::{start_kv_router_background, start_kv_router_background_nats_core},
54
    },
55
    local_model::runtime_config::ModelRuntimeConfig,
56
    model_card::ModelDeploymentCard,
57
    preprocessor::PreprocessedRequest,
58
    protocols::common::llm_backend::LLMEngineOutput,
59
    protocols::common::timing::RequestPhase,
60
61
};

62
63
// [gluo TODO] These shouldn't need to be public;
// they should be discovered from the component.

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject used for publishing KV cache events (push-based).
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Subject used for publishing KV hit-rate metrics (push-based).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject used for publishing KV metrics (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject used for inter-router communication about prefill events.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject used for inter-router communication about active sequences.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name used for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Endpoint name used to query a worker-local KV indexer.
pub const WORKER_KV_INDEXER_QUERY_ENDPOINT: &str = "worker_kv_indexer_query";
/// Number of most recent events kept in the worker-side buffer.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
// for router discovery registration
/// Component name under which the KV router is addressed; used by
/// `router_endpoint_id` and `router_discovery_query`.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name under which the KV router is addressed; used by
/// `router_endpoint_id` and `router_discovery_query`.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] that addresses the KV router inside `namespace`.
///
/// The component and endpoint names are fixed to [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`]; only the namespace varies.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the [`DiscoveryQuery`] that matches the KV router endpoint inside `namespace`.
///
/// The component and endpoint names are fixed to [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`]; only the namespace varies.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

107
108
109
110
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
111
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
112
        request: &SchedulingRequest,
113
        block_size: u32,
114
115
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
116

117
118
119
120
121
122
123
124
125
126
/// Override configuration for router settings that can be specified per-request
///
/// Both fields are optional and default to `None` (via `Default` and the
/// `#[builder(default)]` attributes).
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request value for the overlap score weight.
    /// NOTE(review): presumably `None` keeps the router-wide
    /// `KvRouterConfig::overlap_score_weight` — confirm in the scheduler.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request value for the router temperature.
    /// NOTE(review): presumably `None` keeps the router-wide
    /// `KvRouterConfig::router_temperature` — confirm in the scheduler.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

127
/// KV Router configuration parameters
128
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
129
130
131
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

132
    pub router_temperature: f64,
133

134
135
    pub use_kv_events: bool,

136
137
    pub router_replica_sync: bool,

138
139
140
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

141
142
143
144
145
    /// Whether to assume KV cache reuse when tracking active blocks (default: true).
    /// When true, computes actual block hashes for sequence tracking.
    /// When false, generates random hashes (assuming no KV cache reuse).
    pub router_assume_kv_reuse: bool,

146
147
148
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

149
    /// Whether to reset the router state on startup (default: false)
150
    pub router_reset_states: bool,
151
152
153
154

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

155
    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
156
157
158
159
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
160
161
162
163
164
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
165
            overlap_score_weight: 1.0,
166
            router_temperature: 0.0,
167
            use_kv_events: true,
168
            router_replica_sync: false,
169
            router_track_active_blocks: true,
170
            router_assume_kv_reuse: true,
171
            router_snapshot_threshold: Some(1000000),
172
            router_reset_states: false,
173
            router_ttl_secs: 120.0,
174
            router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
175
            router_prune_target_ratio: 0.8,
176
177
178
179
180
181
182
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
183
    #[allow(clippy::too_many_arguments)]
184
185
    pub fn new(
        overlap_score_weight: Option<f64>,
186
        temperature: Option<f64>,
187
        use_kv_events: Option<bool>,
188
        replica_sync: Option<bool>,
189
        track_active_blocks: Option<bool>,
190
        assume_kv_reuse: Option<bool>,
191
192
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
193
194
195
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
196
197
198
199
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
200
            router_temperature: temperature.unwrap_or(default.router_temperature),
201
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
202
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
203
204
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
205
            router_assume_kv_reuse: assume_kv_reuse.unwrap_or(default.router_assume_kv_reuse),
206
207
208
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
209
210
211
212
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
213
214
        }
    }
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245

    /// Produce the per-block sequence hashes used for active block tracking.
    ///
    /// The result depends on the configuration:
    /// - `None` when `router_track_active_blocks` is disabled.
    /// - An empty vector when `tokens` does not fill a single block.
    /// - Real sequence hashes when `router_assume_kv_reuse` is enabled.
    /// - One random `u64` per complete block otherwise.
    ///
    /// # Panics
    ///
    /// Panics on integer division by zero if tracking is enabled and
    /// `block_size` is 0.
    pub fn compute_seq_hashes_for_tracking(
        &self,
        tokens: &[u32],
        block_size: u32,
    ) -> Option<Vec<u64>> {
        if !self.router_track_active_blocks {
            return None;
        }

        // Only complete blocks count; a trailing partial block is ignored.
        let full_blocks = tokens.len() / block_size as usize;

        let hashes: Vec<u64> = match (full_blocks, self.router_assume_kv_reuse) {
            (0, _) => Vec::new(),
            (_, true) => {
                // KV reuse assumed: derive sequence hashes from the real block hashes.
                let per_block = compute_block_hash_for_seq(tokens, block_size, None);
                compute_seq_hash_for_block(&per_block)
            }
            (count, false) => {
                // No KV reuse assumed: every block gets a fresh random hash.
                let mut rng = rand::rng();
                (0..count).map(|_| rng.random::<u64>()).collect()
            }
        };

        Some(hashes)
    }
246
247
}

248
pub enum Indexer {
249
250
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
251
    /// Has the ability to persist and snapshot states.
252
    KvIndexer(KvIndexer),
253
254
255
256

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
257
258
259
260
261
262
263
264
265
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
266
267
268
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
269
                tree_sizes: HashMap::new(),
270
            }),
271
272
        }
    }
273
274
275
276

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
277
278
279
280
281
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
282
283
        }
    }
284

285
    async fn process_routing_decision_for_request(
286
        &self,
287
        tokens_with_hashes: &mut TokensWithHashes,
288
289
290
291
292
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
293
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
294
295
296
297
298
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
299
300
}

301
302
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
303
pub struct KvRouter {
304
305
306
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
307
    scheduler: KvScheduler,
308

309
    block_size: u32,
310
311

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
312
313

    cancellation_token: tokio_util::sync::CancellationToken,
314
315

    client: Client,
316
317

    worker_query_client: Option<WorkerQueryClient>,
318
319
320
321
}

impl KvRouter {
    pub async fn new(
322
323
        endpoint: Endpoint,
        client: Client,
324
        workers_with_configs: Arc<RuntimeConfigsWithNotify>,
325
        block_size: u32,
326
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
327
        kv_router_config: Option<KvRouterConfig>,
328
        consumer_id: String,
329
    ) -> Result<Self> {
330
        let kv_router_config = kv_router_config.unwrap_or_default();
331
        let component = endpoint.component();
332
        let cancellation_token = component.drt().primary_token();
333

334
        // Watch for runtime config updates via discovery interface
335
        // (still needed for WorkerQueryClient and background tasks)
336
        let discovery = component.drt().discovery();
337
        let endpoint_id = endpoint.id();
338
        let discovery_key = DiscoveryQuery::EndpointModels {
339
340
341
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
342
343
        };
        let discovery_stream = discovery
344
            .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
345
346
347
348
349
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
350

351
352
353
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
354
        } else {
355
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
356
357
358
359
360
361
362
363
364
365
366
367
368

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
369
                cancellation_token.clone(),
370
                None, // expiration_duration for frequency tracking
371
372
                block_size,
                kv_indexer_metrics,
373
                prune_config,
374
375
            ))
        };
376

377
        let scheduler = KvScheduler::start(
378
            component.clone(),
379
            block_size,
380
            workers_with_configs.clone(),
381
            selector,
382
            kv_router_config.router_replica_sync,
383
            consumer_id.clone(),
384
385
        )
        .await?;
386

387
388
389
390
391
392
        // Initialize worker query client using namespace abstraction
        // (created before background task so we can use it for startup recovery)
        let worker_query_client =
            worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
        tracing::info!("Worker query client initialized");

393
        // Start KV event subscriber background process (only when use_kv_events is enabled)
394
        // model_manager.get_or_create_runtime_config_watcher() guarantees at least one worker exists.
395
396
397
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
            // model_manager guarantees workers_with_configs is populated
            // Wait for at least one worker before starting the subscriber
            while workers_with_configs.configs.is_empty() {
                tracing::info!("KV router waiting for at least one worker...");
                workers_with_configs.notify.notified().await;
            }

            let count = workers_with_configs.configs.len();
            let all_local_indexer = workers_with_configs
                .configs
                .iter()
                .filter_map(|r| r.value().as_ref().map(|c| c.enable_local_indexer))
                .all(|b| b);

            tracing::info!("Found {count} worker(s), starting KV event subscriber");
413

414
            // Start subscriber - setup runs synchronously, then spawns background loop internally
415
416
417
418
419
            if all_local_indexer {
                tracing::info!(
                    "All {count} workers have local_indexer enabled, using NATS Core subscription"
                );

420
421
422
423
424
425
426
427
428
429
430
                start_kv_router_background_nats_core(
                    component.clone(),
                    kv_indexer.event_sender(),
                    kv_indexer.remove_worker_sender(),
                    cancellation_token.clone(),
                    worker_query::WorkerQueryClient::new(
                        component.clone(),
                        runtime_configs_rx.clone(),
                    ),
                )
                .await?;
431
432
433
434
            } else {
                tracing::info!(
                    "Not all workers have local_indexer enabled, using JetStream subscription"
                );
435

436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
                start_kv_router_background(
                    component.clone(),
                    consumer_id,
                    kv_indexer.event_sender(),
                    kv_indexer.remove_worker_sender(),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.get_workers_sender()),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.snapshot_event_sender()),
                    cancellation_token.clone(),
                    kv_router_config.router_snapshot_threshold,
                    kv_router_config.router_reset_states,
                )
                .await?;
452
            }
453
        }
454

455
        tracing::info!("KV Routing initialized");
456
        Ok(Self {
457
            indexer,
458
            scheduler,
459
            block_size,
460
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
461
            cancellation_token,
462
            client,
463
            worker_query_client: Some(worker_query_client),
464
        })
465
466
    }

467
468
469
470
471
    /// Get a reference to the client used by this KvRouter
    ///
    /// This is the `Client` that was passed to `KvRouter::new`.
    pub fn client(&self) -> &Client {
        &self.client
    }

472
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
473
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
474
475
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
476
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
477
        context_id: Option<&str>,
478
        tokens: &[u32],
479
        router_config_override: Option<&RouterConfigOverride>,
480
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
481
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
482
483
484
485
486
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

487
        let isl_tokens = tokens.len();
488

489
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
490
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
491

492
493
494
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = self
            .kv_router_config
495
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
496

Yan Ru Pei's avatar
Yan Ru Pei committed
497
        let best_worker = self
498
            .scheduler
499
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
500
                context_id.map(|s| s.to_string()),
501
                isl_tokens,
502
                maybe_seq_hashes,
503
                overlap_scores.clone(),
504
                router_config_override,
505
                update_states,
506
            )
507
            .await?;
508

509
510
        // Note: Routing decision recording (for approximate mode) is now handled
        // by KvPushRouter::generate after select_worker returns.
511

512
513
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
514
            .get(&best_worker)
515
516
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
517
        Ok((best_worker, overlap_amount))
518
519
    }

520
521
522
523
524
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
525
        worker: WorkerWithDpRank,
526
527
    ) {
        let isl_tokens = tokens.len();
528

529
530
531
        let maybe_seq_hashes = self
            .kv_router_config
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
532

533
534
        if let Err(e) = self
            .scheduler
535
            .add_request(
536
                request_id.clone(),
537
                maybe_seq_hashes,
538
539
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
540
                worker,
541
            )
542
543
544
545
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
546
547
    }

548
    /// Tells the scheduler that prefill has finished for `request_id`.
    ///
    /// # Errors
    ///
    /// Propagates the [`SequenceError`] returned by the scheduler.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.mark_prefill_completed(request_id).await
    }

552
    /// Tells the scheduler to free the state tracked for `request_id`.
    ///
    /// # Errors
    ///
    /// Propagates the [`SequenceError`] returned by the scheduler.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.free(request_id).await
    }
555

556
    pub fn block_size(&self) -> u32 {
557
558
        self.block_size
    }
559

560
561
562
563
564
565
566
567
568
569
570
571
    /// Report how many blocks of `tokens` are already cached on `worker`.
    ///
    /// Hashes the tokens into blocks, asks the indexer for per-worker overlap
    /// scores, and returns the score recorded for `worker` (0 when the worker
    /// has no entry).
    ///
    /// # Errors
    ///
    /// Propagates the [`KvRouterError`] returned by the indexer lookup.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
        worker: WorkerWithDpRank,
    ) -> Result<u32, KvRouterError> {
        let hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
        let matches = self.indexer.find_matches(hashes).await?;

        let cached_blocks = matches.scores.get(&worker).map_or(0, |count| *count);
        Ok(cached_blocks)
    }

572
573
574
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
575
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
576
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
577

578
579
        let maybe_seq_hashes = self
            .kv_router_config
580
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
581

582
583
        Ok(self
            .scheduler
584
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
585
586
587
            .await)
    }

588
589
590
591
    /// Dump all events from the indexer
    ///
    /// # Panics
    ///
    /// Panics if the router was built without an indexer (`Indexer::None`,
    /// i.e. `overlap_score_weight` was 0), because `Indexer::dump_events`
    /// panics in that case.
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647

    /// Fetch events from one worker's local KV indexer.
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    ///
    /// # Errors
    ///
    /// Fails if no worker query client is available, or if the query itself fails.
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let Some(query_client) = self.worker_query_client.as_ref() else {
            return Err(anyhow::anyhow!(
                "Worker query client not available (NATS required)"
            ));
        };

        let response = query_client.query_worker(worker_id, start_event_id, end_event_id);
        response.await
    }

    /// Recover missed KV events from a specific worker.
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

        subscriber::recover_from_worker(
            query_client,
            worker_id,
            start_event_id,
            end_event_id,
            &event_tx,
        )
        .await
    }
648
649
}

Michael Feil's avatar
Michael Feil committed
650
651
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
652
653
654
655
656
657
658
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
659
660
661
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
662
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
663
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
664
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
665
666
667
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
668
669
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
670
671
672
                    overlap_blocks,
                }
            }
673
674
675
676
677
678
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
679
        };
680
681
682
683
684
685

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
686
687

pub struct KvPushRouter {
688
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
689
    pub chooser: Arc<KvRouter>,
690
691
}

692
693
694
695
696
697
698
/// Result of worker selection containing instance ID, dp_rank, and overlap amount.
struct WorkerSelection {
    /// Id of the selected worker instance.
    instance_id: u64,
    /// Data-parallel rank on the selected worker.
    dp_rank: u32,
    /// Overlap with the worker's KV cache, in blocks.
    overlap_amount: u32,
}

699
700
impl KvPushRouter {
    pub fn new(
701
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
702
703
704
705
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790

    /// Choose the worker that will serve `request`.
    ///
    /// If the request's routing hints name a worker for the current `phase`
    /// (falling back to `backend_instance_id`), that worker is used and its
    /// overlap is looked up in the indexer. Otherwise the best match is
    /// computed by the chooser.
    ///
    /// For a pre-selected worker, the request is registered with the scheduler
    /// via `add_request` only when `is_query_only` is false and
    /// `handle_local_updates` is true. For a best-match selection, state
    /// updates are requested from the chooser whenever `is_query_only` is false.
    async fn select_worker(
        &self,
        context_id: &str,
        request: &PreprocessedRequest,
        phase: RequestPhase,
        is_query_only: bool,
        handle_local_updates: bool,
    ) -> Result<WorkerSelection, Error> {
        let routing = request.routing.as_ref();

        // Phase-specific worker id first, backend_instance_id as the fallback.
        let preselected = match phase {
            RequestPhase::Prefill => {
                routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Decode => {
                routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
        };

        let Some(id) = preselected else {
            // Nothing pre-selected: let the chooser pick.
            // Query-only requests must not update router state.
            let update_states = !is_query_only;
            let (chosen, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    update_states,
                )
                .await?;

            return Ok(WorkerSelection {
                instance_id: chosen.worker_id,
                dp_rank: chosen.dp_rank,
                overlap_amount,
            });
        };

        // A worker was named explicitly; default to dp_rank 0 when unspecified.
        let dp_rank = match routing.and_then(|r| r.dp_rank) {
            Some(rank) => rank,
            None => 0,
        };
        tracing::debug!(
            worker_id = id,
            dp_rank = dp_rank,
            ?phase,
            "Routing to specified worker"
        );

        // Ask the indexer how much of the request is already cached there.
        let target = WorkerWithDpRank::new(id, dp_rank);
        let cached_blocks = self
            .chooser
            .get_overlap_blocks(&request.token_ids, target)
            .await?;

        if is_query_only || !handle_local_updates {
            tracing::debug!(
                request_id = %context_id,
                worker_id = id,
                dp_rank = dp_rank,
                "Skipping add_request - query or handled externally"
            );
        } else {
            // This router owns the bookkeeping, so register the request.
            self.chooser
                .add_request(
                    context_id.to_string(),
                    &request.token_ids,
                    cached_blocks,
                    target,
                )
                .await;
        }

        Ok(WorkerSelection {
            instance_id: id,
            dp_rank,
            overlap_amount: cached_blocks,
        })
    }
791
792
793
}

#[async_trait]
794
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
795
796
    for KvPushRouter
{
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
816
817
    async fn generate(
        &self,
818
        request: SingleIn<PreprocessedRequest>,
819
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
820
821
822
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

823
824
825
        // Simple query-only detection: presence of query_instance_id annotation means query-only mode
        let is_query_only = request.get_annotation_value("query_instance_id").is_some();

826
        // Determine if this router should handle local state updates (add_request, free, etc.)
827
828
829
830
831
832
        // Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
        // an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
        let handle_local_updates = request
            .routing
            .as_ref()
            .and_then(|r| r.enable_local_updates)
833
834
            .unwrap_or(true);

835
836
837
        // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
        let phase = request
            .tracker
838
            .as_ref()
839
840
841
842
            .map(|t| t.phase())
            .unwrap_or(RequestPhase::Aggregated);

        let block_size = self.chooser.block_size() as usize;
843
844
845
846
847
848
849
850
851
852
853
854
855
856
        let selection = self
            .select_worker(
                &context_id,
                &request,
                phase,
                is_query_only,
                handle_local_updates,
            )
            .await?;
        let WorkerSelection {
            instance_id,
            dp_rank,
            overlap_amount,
        } = selection;
857

858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
        // In approximate mode (use_kv_events=false), record the routing decision
        // so the indexer can track cache state based on routing decisions.
        // This covers both pre-selected workers and find_best_match selections.
        if !is_query_only && !self.chooser.kv_router_config.use_kv_events {
            let worker = WorkerWithDpRank::new(instance_id, dp_rank);
            let mut tokens_with_hashes =
                TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size);
            if let Err(e) = self
                .chooser
                .indexer
                .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
                .await
            {
                tracing::warn!(
                    request_id = %context_id,
                    worker_id = instance_id,
                    dp_rank = dp_rank,
                    error = %e,
                    "Failed to record routing decision in approximate mode"
                );
            }
        }

881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
        // Record metrics in tracker: KV hit rate and worker ID based on phase
        if let Some(ref tracker) = request.tracker {
            let isl_blocks = request.token_ids.len().div_ceil(block_size);
            tracker.record_kv_hit(overlap_amount, isl_blocks);
            tracker.record_worker(instance_id);
        }

        // Handle query-only requests: early return with worker info
        if is_query_only {
            let stream_context = request.context().clone();
            // Tracker is always created for query-only requests (delta generator enables tracking
            // when query_instance_id annotation is present)
            let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());

            tracing::trace!(
                ?phase,
                worker_id = instance_id,
                ?worker_id_info,
                "Returning worker selection (query-only mode)"
            );

902
903
904
905
906
907
908
909
910
            let output = LLMEngineOutput {
                disaggregated_params: Some(json!({
                    "worker_id": worker_id_info,
                    "token_ids": request.token_ids
                })),
                ..Default::default()
            };
            let response = Annotated::from_data(output);
            let stream = stream::iter(vec![response]);
911
912
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
913
914

        // Route to worker
915
        let (mut backend_input, context) = request.into_parts();
916
        backend_input.routing_mut().dp_rank = Some(dp_rank);
917
918
        let updated_request = context.map(|_| backend_input);

919
        let chooser = self.chooser.clone();
920
921
922
923
        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let context_for_monitoring = stream_context.clone();

924
925
926
        // Wrap stream with lifecycle management (mark_prefill_completed, free)
        // Only perform these operations if handle_local_updates is true.
        // When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
927
928
929
930
931
932
933
934
935
936
        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
937
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
938

939
                    item = response_stream.next() => {
940
                        let Some(item) = item else {
941
942
                            break;
                        };
943

944
                        if handle_local_updates && !prefill_marked {
945
946
947
948
949
950
951
952
953
954
                            // Only mark prefill completed when we receive actual tokens,
                            // not empty bootstrap info (token_ids: []) from disaggregated prefill
                            let has_tokens = item.data.as_ref()
                                .map(|d| !d.token_ids.is_empty())
                                .unwrap_or(false);
                            if has_tokens {
                                if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
                                    tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
                                }
                                prefill_marked = true;
955
                            }
956
                        }
957

958
                        yield item;
959
                    }
960
961
                }
            }
962

963
964
965
966
967
            // Only call free() if we handle local updates.
            // When handle_local_updates=false, external caller handles cleanup via C FFI.
            if handle_local_updates
                && let Err(e) = chooser.free(&context_id).await
            {
968
                tracing::warn!("Failed to free request {context_id}: {e}");
969
            }
970
971
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
972
973
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
974
975
976
977
978
979
980

impl Drop for KvRouter {
    /// Logs the teardown and then cancels the router's cancellation token.
    ///
    /// Per the log message, cancelling the token is what stops the router's
    /// background tasks.
    fn drop(&mut self) {
        let shutdown_token = &self.cancellation_token;
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        shutdown_token.cancel();
    }
}