kv_router.rs 38.8 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, EventTransportKind},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
21
};
use futures::stream::{self, StreamExt};
22
use rand::Rng;
23
use serde::{Deserialize, Serialize};
24
use serde_json::json;
25

26
27
28
29
30
// Re-export from dynamo-kv-router crate
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::indexer;
pub use dynamo_kv_router::protocols;

31
pub mod prefill_router;
32
pub mod publisher;
33
pub mod recorder;
34
pub mod scheduler;
35
pub mod sequence;
36
pub mod subscriber;
37
pub mod worker_query;
38

39
use indexer::WorkerKvQueryResponse;
40
pub use prefill_router::PrefillRouter;
41
use worker_query::WorkerQueryClient;
42

43
use crate::{
44
    discovery::RuntimeConfigs,
45
    kv_router::{
46
        approx::PruneConfig,
47
        indexer::{KvIndexer, KvIndexerInterface, KvRouterError},
Yan Ru Pei's avatar
Yan Ru Pei committed
48
        protocols::{
49
50
51
            LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
            TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
Yan Ru Pei's avatar
Yan Ru Pei committed
52
        },
53
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
54
        sequence::SequenceError,
55
        subscriber::{start_kv_router_background, start_kv_router_background_event_plane},
56
    },
57
    local_model::runtime_config::ModelRuntimeConfig,
58
    preprocessor::PreprocessedRequest,
59
    protocols::common::llm_backend::LLMEngineOutput,
60
    protocols::common::timing::RequestPhase,
61
62
};

63
64
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Endpoint name for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Push-based metric publishing subject: KV events.
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Push-based metric publishing subject: KV hit rate.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Push-based metric publishing subject: KV metrics.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Inter-router communication subject for prefill events.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Inter-router communication subject for active-sequence events.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Storage bucket for radix tree snapshots.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name of the radix tree snapshot inside [`RADIX_STATE_BUCKET`].
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Endpoint name for querying a worker-local KV indexer.
pub const WORKER_KV_INDEXER_QUERY_ENDPOINT: &str = "worker_kv_indexer_query";
/// Number of most recent events kept in a worker's local buffer.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

/// Component name under which the KV router registers for discovery.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name under which the KV router registers for discovery.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Build the [`EndpointId`] under which the KV router is registered in `namespace`.
///
/// Only the namespace varies; the component and endpoint names are always
/// `KV_ROUTER_COMPONENT` and `KV_ROUTER_ENDPOINT`.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Build the [`DiscoveryQuery`] that locates the KV router endpoint in `namespace`.
///
/// The query targets the fixed `KV_ROUTER_COMPONENT` / `KV_ROUTER_ENDPOINT` pair,
/// the same pair `router_endpoint_id` uses.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

108
109
110
111
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
112
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
113
        request: &SchedulingRequest,
114
        block_size: u32,
115
116
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
117

118
119
120
121
122
123
124
125
126
127
/// Override configuration for router settings that can be specified per-request
///
/// Passed to the scheduler together with a request (see `KvRouter::find_best_match`).
/// NOTE(review): a `None` field presumably falls back to the router-wide
/// `KvRouterConfig` value of the same name — confirm in the scheduler.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Optional per-request override of `KvRouterConfig::overlap_score_weight`.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Optional per-request override of `KvRouterConfig::router_temperature`.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

128
/// KV Router configuration parameters
129
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
130
131
132
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

133
    pub router_temperature: f64,
134

135
136
    pub use_kv_events: bool,

137
138
    pub router_replica_sync: bool,

139
140
141
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

142
143
144
145
146
    /// Whether to track output blocks during generation (default: false)
    /// When enabled, the router adds placeholder blocks as tokens are generated
    /// and applies fractional decay based on progress toward expected_output_tokens.
    pub router_track_output_blocks: bool,

147
148
149
150
151
    /// Whether to assume KV cache reuse when tracking active blocks (default: true).
    /// When true, computes actual block hashes for sequence tracking.
    /// When false, generates random hashes (assuming no KV cache reuse).
    pub router_assume_kv_reuse: bool,

152
153
154
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

155
    /// Whether to reset the router state on startup (default: false)
156
    pub router_reset_states: bool,
157
158
159
160

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

161
    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
162
163
164
165
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
166
167
168
169
170
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
171
            overlap_score_weight: 1.0,
172
            router_temperature: 0.0,
173
            use_kv_events: true,
174
            router_replica_sync: false,
175
            router_track_active_blocks: true,
176
            router_track_output_blocks: false,
177
            router_assume_kv_reuse: true,
178
            router_snapshot_threshold: Some(1000000),
179
            router_reset_states: false,
180
            router_ttl_secs: 120.0,
181
            router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
182
            router_prune_target_ratio: 0.8,
183
184
185
186
187
188
189
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
190
    #[allow(clippy::too_many_arguments)]
191
192
    pub fn new(
        overlap_score_weight: Option<f64>,
193
        temperature: Option<f64>,
194
        use_kv_events: Option<bool>,
195
        replica_sync: Option<bool>,
196
        track_active_blocks: Option<bool>,
197
        track_output_blocks: Option<bool>,
198
        assume_kv_reuse: Option<bool>,
199
200
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
201
202
203
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
204
205
206
207
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
208
            router_temperature: temperature.unwrap_or(default.router_temperature),
209
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
210
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
211
212
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
213
214
            router_track_output_blocks: track_output_blocks
                .unwrap_or(default.router_track_output_blocks),
215
            router_assume_kv_reuse: assume_kv_reuse.unwrap_or(default.router_assume_kv_reuse),
216
217
218
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
219
220
221
222
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
223
224
        }
    }
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255

    /// Compute sequence hashes for active block tracking based on configuration.
    ///
    /// Returns:
    /// - `None` if `router_track_active_blocks` is false
    /// - Random hashes if `router_track_active_blocks` is true but `router_assume_kv_reuse` is false
    /// - Actual sequence hashes if both are true
    pub fn compute_seq_hashes_for_tracking(
        &self,
        tokens: &[u32],
        block_size: u32,
    ) -> Option<Vec<u64>> {
        if !self.router_track_active_blocks {
            return None;
        }

        let num_blocks = tokens.len() / block_size as usize;
        if num_blocks == 0 {
            return Some(Vec::new());
        }

        if self.router_assume_kv_reuse {
            // Compute actual block hashes and sequence hashes
            let block_hashes = compute_block_hash_for_seq(tokens, block_size, None);
            Some(compute_seq_hash_for_block(&block_hashes))
        } else {
            // Generate random hashes (no KV reuse assumed)
            let mut rng = rand::rng();
            Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
        }
    }
256
257
}

258
pub enum Indexer {
259
260
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
261
    /// Has the ability to persist and snapshot states.
262
    KvIndexer(KvIndexer),
263
264
265
266

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
267
268
269
270
271
272
273
274
275
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
276
277
278
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
279
                tree_sizes: HashMap::new(),
280
            }),
281
282
        }
    }
283
284
285
286

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
287
288
289
290
291
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
292
293
        }
    }
294

295
    async fn process_routing_decision_for_request(
296
        &self,
297
        tokens_with_hashes: &mut TokensWithHashes,
298
299
300
301
302
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
303
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
304
305
306
307
308
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
309
310
}

311
312
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
313
pub struct KvRouter {
314
315
316
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
317
    scheduler: KvScheduler,
318

319
    block_size: u32,
320
321

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
322
323

    cancellation_token: tokio_util::sync::CancellationToken,
324
325

    client: Client,
326
327

    worker_query_client: Option<WorkerQueryClient>,
328
329
330
331
}

impl KvRouter {
    pub async fn new(
332
333
        endpoint: Endpoint,
        client: Client,
334
        workers_with_configs: Arc<RuntimeConfigs>,
335
        block_size: u32,
336
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
337
        kv_router_config: Option<KvRouterConfig>,
338
        router_id: u64,
339
    ) -> Result<Self> {
340
        let kv_router_config = kv_router_config.unwrap_or_default();
341
        let component = endpoint.component();
342
        let cancellation_token = component.drt().primary_token();
343

344
345
346
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
347
        } else {
348
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
349
350
351
352
353
354
355
356
357
358
359
360
361

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
362
                cancellation_token.clone(),
363
                None, // expiration_duration for frequency tracking
364
365
                block_size,
                kv_indexer_metrics,
366
                prune_config,
367
368
            ))
        };
369

370
371
372
        // Wait for at least one worker with a known runtime config before starting scheduler
        workers_with_configs.subscribe().wait_for_some().await;

373
        let scheduler = KvScheduler::start(
374
            component.clone(),
375
            block_size,
376
            workers_with_configs.clone(),
377
            selector,
378
            kv_router_config.router_replica_sync,
379
            router_id,
380
381
        )
        .await?;
382

383
384
        // Initialize worker query client using namespace abstraction
        // (created before background task so we can use it for startup recovery)
385
386
387
388
389
        // Uses a subscriber from workers_with_configs
        let worker_query_client = worker_query::WorkerQueryClient::new(
            component.clone(),
            workers_with_configs.subscribe(),
        );
390
391
        tracing::info!("Worker query client initialized");

392
393
394
395
        // Start KV event subscriber background process (only when use_kv_events is enabled)
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
396
397
398
399
400
401
            let all_local_indexer = workers_with_configs
                .configs
                .iter()
                .filter_map(|r| r.value().as_ref().map(|c| c.enable_local_indexer))
                .all(|b| b);

402
403
404
405
            tracing::info!(
                "Found {} worker(s), starting KV event subscriber",
                workers_with_configs.num_workers()
            );
406

407
408
            let transport_kind = EventTransportKind::from_env_or_default();

409
            // Start subscriber - setup runs synchronously, then spawns background loop internally
410
            if all_local_indexer {
411
412
413
414
415
416
417
418
419
420
                if transport_kind == EventTransportKind::Zmq {
                    if kv_router_config.router_snapshot_threshold.is_some()
                        || kv_router_config.router_reset_states
                    {
                        tracing::warn!(
                            "ZMQ event plane does not support KV snapshots or state reset; ignoring snapshot/reset settings"
                        );
                    }
                } else {
                    tracing::info!(
421
422
                        "All {} workers have local_indexer enabled, using NATS Core subscription",
                        workers_with_configs.num_workers()
423
424
425
426
                    );
                }

                start_kv_router_background_event_plane(
427
428
429
430
431
432
                    component.clone(),
                    kv_indexer.event_sender(),
                    kv_indexer.remove_worker_sender(),
                    cancellation_token.clone(),
                    worker_query::WorkerQueryClient::new(
                        component.clone(),
433
                        workers_with_configs.subscribe(),
434
                    ),
435
                    transport_kind,
436
437
                )
                .await?;
438
            } else {
439
440
441
442
443
                if transport_kind == EventTransportKind::Zmq {
                    tracing::warn!(
                        "Not all workers have local_indexer enabled; falling back to JetStream for durability"
                    );
                }
444
445
446
                tracing::info!(
                    "Not all workers have local_indexer enabled, using JetStream subscription"
                );
447

448
449
                // Convert router_id to string for NATS consumer naming
                let consumer_id = router_id.to_string();
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
                start_kv_router_background(
                    component.clone(),
                    consumer_id,
                    kv_indexer.event_sender(),
                    kv_indexer.remove_worker_sender(),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.get_workers_sender()),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.snapshot_event_sender()),
                    cancellation_token.clone(),
                    kv_router_config.router_snapshot_threshold,
                    kv_router_config.router_reset_states,
                )
                .await?;
466
            }
467
        }
468

469
        tracing::info!("KV Routing initialized");
470
        Ok(Self {
471
            indexer,
472
            scheduler,
473
            block_size,
474
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
475
            cancellation_token,
476
            client,
477
            worker_query_client: Some(worker_query_client),
478
        })
479
480
    }

481
482
483
484
485
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

486
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
487
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
488
489
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
490
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
491
        context_id: Option<&str>,
492
        tokens: &[u32],
493
        router_config_override: Option<&RouterConfigOverride>,
494
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
495
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
496
497
498
499
500
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

501
        let isl_tokens = tokens.len();
502

503
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
504
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
505

506
507
508
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = self
            .kv_router_config
509
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
510

Yan Ru Pei's avatar
Yan Ru Pei committed
511
        let best_worker = self
512
            .scheduler
513
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
514
                context_id.map(|s| s.to_string()),
515
                isl_tokens,
516
                maybe_seq_hashes,
517
                overlap_scores.clone(),
518
                router_config_override,
519
                update_states,
520
            )
521
            .await?;
522

523
524
        // Note: Routing decision recording (for approximate mode) is now handled
        // by KvPushRouter::generate after select_worker returns.
525

526
527
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
528
            .get(&best_worker)
529
530
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
531
        Ok((best_worker, overlap_amount))
532
533
    }

534
535
536
537
538
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
539
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
540
        worker: WorkerWithDpRank,
541
542
    ) {
        let isl_tokens = tokens.len();
543

544
545
546
        let maybe_seq_hashes = self
            .kv_router_config
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
547

548
549
        if let Err(e) = self
            .scheduler
550
            .add_request(
551
                request_id.clone(),
552
                maybe_seq_hashes,
553
554
                isl_tokens,
                overlap_blocks,
555
                expected_output_tokens,
Yan Ru Pei's avatar
Yan Ru Pei committed
556
                worker,
557
            )
558
559
560
561
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
562
563
    }

564
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
565
        self.scheduler.mark_prefill_completed(request_id).await
566
567
    }

568
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
569
        self.scheduler.free(request_id).await
570
    }
571

572
573
574
575
576
577
578
579
580
581
    pub async fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.scheduler
            .add_output_block(request_id, decay_fraction)
            .await
    }

582
    pub fn block_size(&self) -> u32 {
583
584
        self.block_size
    }
585

586
587
588
589
590
591
592
593
594
595
596
597
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
        worker: WorkerWithDpRank,
    ) -> Result<u32, KvRouterError> {
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

598
599
600
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
601
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
602
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
603

604
605
        let maybe_seq_hashes = self
            .kv_router_config
606
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
607

608
609
        Ok(self
            .scheduler
610
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
611
612
613
            .await)
    }

614
615
616
617
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673

    /// Query a specific worker's local KV indexer for its events
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;

        query_client
            .query_worker(worker_id, start_event_id, end_event_id)
            .await
    }

    /// Recover missed KV events from a specific worker.
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

        subscriber::recover_from_worker(
            query_client,
            worker_id,
            start_event_id,
            end_event_id,
            &event_tx,
        )
        .await
    }
674
675
}

Michael Feil's avatar
Michael Feil committed
676
677
// NOTE: KvRouter behaves like a PushRouter, but without the reverse-proxy functionality;
// it serves a contract of three request types (New, MarkPrefill, MarkFree).
678
679
680
681
682
683
684
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
685
686
687
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
688
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
689
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
690
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
691
692
693
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
694
695
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
696
697
698
                    overlap_blocks,
                }
            }
699
700
701
702
703
704
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
705
        };
706
707
708
709
710
711

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
712
713

pub struct KvPushRouter {
714
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
715
    pub chooser: Arc<KvRouter>,
716
717
}

718
719
720
721
722
723
724
/// Outcome of choosing a worker: which instance and rank, and how much of the
/// request it already has cached.
struct WorkerSelection {
    /// Instance (worker) id that should serve the request.
    instance_id: u64,
    /// DP rank on that worker.
    dp_rank: u32,
    /// Overlap with the worker's KV cache, in blocks.
    overlap_amount: u32,
}

725
726
impl KvPushRouter {
    pub fn new(
727
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
728
729
730
731
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791

    /// Select a worker for the request, either using a preselected worker or finding the best match.
    ///
    /// When `is_query_only` is false and `handle_local_updates` is true, this also registers
    /// the request with the scheduler via `add_request`.
    async fn select_worker(
        &self,
        context_id: &str,
        request: &PreprocessedRequest,
        phase: RequestPhase,
        is_query_only: bool,
        handle_local_updates: bool,
    ) -> Result<WorkerSelection, Error> {
        let routing = request.routing.as_ref();

        // Get pre-selected worker based on phase, with backend_instance_id as fallback
        let Some(id) = (match phase {
            RequestPhase::Prefill => {
                routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Decode => {
                routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
        }) else {
            // No preselected worker - find the best match
            // Don't update states if this is a query-only request
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !is_query_only,
                )
                .await?;

            return Ok(WorkerSelection {
                instance_id: best_worker.worker_id,
                dp_rank: best_worker.dp_rank,
                overlap_amount,
            });
        };

        // Route to pre-selected or explicitly specified worker
        let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
        tracing::debug!(
            worker_id = id,
            dp_rank = dp_rank,
            ?phase,
            "Routing to specified worker"
        );

        // Compute actual overlap blocks by querying the indexer
        let worker = WorkerWithDpRank::new(id, dp_rank);
        let overlap_blocks = self
            .chooser
            .get_overlap_blocks(&request.token_ids, worker)
            .await?;

792
793
794
795
796
797
        // Extract expected_output_tokens from routing hints
        let expected_output_tokens = request
            .routing
            .as_ref()
            .and_then(|r| r.expected_output_tokens);

798
799
800
801
802
803
804
        // Perform add_request if this router handles local updates
        if !is_query_only && handle_local_updates {
            self.chooser
                .add_request(
                    context_id.to_string(),
                    &request.token_ids,
                    overlap_blocks,
805
                    expected_output_tokens,
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
                    worker,
                )
                .await;
        } else {
            tracing::debug!(
                request_id = %context_id,
                worker_id = id,
                dp_rank = dp_rank,
                "Skipping add_request - query or handled externally"
            );
        }

        Ok(WorkerSelection {
            instance_id: id,
            dp_rank,
            overlap_amount: overlap_blocks,
        })
    }
824
825
826
}

#[async_trait]
827
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
828
829
    for KvPushRouter
{
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
849
850
    async fn generate(
        &self,
851
        request: SingleIn<PreprocessedRequest>,
852
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
853
854
855
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

856
857
858
        // Simple query-only detection: presence of query_instance_id annotation means query-only mode
        let is_query_only = request.get_annotation_value("query_instance_id").is_some();

859
        // Determine if this router should handle local state updates (add_request, free, etc.)
860
861
862
863
864
865
        // Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
        // an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
        let handle_local_updates = request
            .routing
            .as_ref()
            .and_then(|r| r.enable_local_updates)
866
867
            .unwrap_or(true);

868
869
870
        // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
        let phase = request
            .tracker
871
            .as_ref()
872
873
874
875
            .map(|t| t.phase())
            .unwrap_or(RequestPhase::Aggregated);

        let block_size = self.chooser.block_size() as usize;
876
877
878
879
880
881
882
883
884
885
886
887
888
889
        let selection = self
            .select_worker(
                &context_id,
                &request,
                phase,
                is_query_only,
                handle_local_updates,
            )
            .await?;
        let WorkerSelection {
            instance_id,
            dp_rank,
            overlap_amount,
        } = selection;
890

891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
        // In approximate mode (use_kv_events=false), record the routing decision
        // so the indexer can track cache state based on routing decisions.
        // This covers both pre-selected workers and find_best_match selections.
        if !is_query_only && !self.chooser.kv_router_config.use_kv_events {
            let worker = WorkerWithDpRank::new(instance_id, dp_rank);
            let mut tokens_with_hashes =
                TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size);
            if let Err(e) = self
                .chooser
                .indexer
                .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
                .await
            {
                tracing::warn!(
                    request_id = %context_id,
                    worker_id = instance_id,
                    dp_rank = dp_rank,
                    error = %e,
                    "Failed to record routing decision in approximate mode"
                );
            }
        }

914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
        // Record metrics in tracker: KV hit rate and worker ID based on phase
        if let Some(ref tracker) = request.tracker {
            let isl_blocks = request.token_ids.len().div_ceil(block_size);
            tracker.record_kv_hit(overlap_amount, isl_blocks);
            tracker.record_worker(instance_id);
        }

        // Handle query-only requests: early return with worker info
        if is_query_only {
            let stream_context = request.context().clone();
            // Tracker is always created for query-only requests (delta generator enables tracking
            // when query_instance_id annotation is present)
            let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());

            tracing::trace!(
                ?phase,
                worker_id = instance_id,
                ?worker_id_info,
                "Returning worker selection (query-only mode)"
            );

935
936
937
938
939
940
941
942
943
            let output = LLMEngineOutput {
                disaggregated_params: Some(json!({
                    "worker_id": worker_id_info,
                    "token_ids": request.token_ids
                })),
                ..Default::default()
            };
            let response = Annotated::from_data(output);
            let stream = stream::iter(vec![response]);
944
945
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
946
947

        // Route to worker
948
949
950
951
952
953
954
955
        let isl_tokens = request.token_ids.len();
        let expected_output_tokens = request
            .routing
            .as_ref()
            .and_then(|r| r.expected_output_tokens);
        let track_output_blocks =
            self.chooser.kv_router_config.router_track_output_blocks && handle_local_updates;

956
        let (mut backend_input, context) = request.into_parts();
957
        backend_input.routing_mut().dp_rank = Some(dp_rank);
958
959
        let updated_request = context.map(|_| backend_input);

960
        let chooser = self.chooser.clone();
961
962
963
964
        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let context_for_monitoring = stream_context.clone();

965
966
967
        // Wrap stream with lifecycle management (mark_prefill_completed, free)
        // Only perform these operations if handle_local_updates is true.
        // When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
968
969
970
        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;

971
972
973
974
            // Output block tracking state
            let mut cumulative_osl: usize = 0;
            let mut current_total_blocks = isl_tokens.div_ceil(block_size);

975
976
977
978
979
980
981
            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
982
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
983

984
                    item = response_stream.next() => {
985
                        let Some(item) = item else {
986
987
                            break;
                        };
988

989
                        if handle_local_updates && !prefill_marked {
990
991
992
993
994
995
996
997
998
999
                            // Only mark prefill completed when we receive actual tokens,
                            // not empty bootstrap info (token_ids: []) from disaggregated prefill
                            let has_tokens = item.data.as_ref()
                                .map(|d| !d.token_ids.is_empty())
                                .unwrap_or(false);
                            if has_tokens {
                                if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
                                    tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
                                }
                                prefill_marked = true;
1000
                            }
1001
                        }
1002

1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
                        // Track output blocks if enabled
                        if track_output_blocks {
                            let new_tokens = item.data.as_ref()
                                .map(|d| d.token_ids.len())
                                .unwrap_or(0);
                            cumulative_osl += new_tokens;

                            let new_total_blocks = (isl_tokens + cumulative_osl).div_ceil(block_size);
                            if new_total_blocks > current_total_blocks {
                                // New block boundary crossed - add output block with decay
                                // Clamp eot to min 1 to avoid division by zero, and result to min 0.0
                                let decay_fraction = expected_output_tokens.map(|eot| {
                                    (1.0 - (cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
                                });
                                if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
                                    tracing::warn!(
                                        "Failed to add output block for request {context_id}: {e}"
                                    );
                                }
                                current_total_blocks = new_total_blocks;
                            }
                        }

1026
                        yield item;
1027
                    }
1028
1029
                }
            }
1030

1031
1032
1033
1034
1035
            // Only call free() if we handle local updates.
            // When handle_local_updates=false, external caller handles cleanup via C FFI.
            if handle_local_updates
                && let Err(e) = chooser.free(&context_id).await
            {
1036
                tracing::warn!("Failed to free request {context_id}: {e}");
1037
            }
1038
1039
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
1040
1041
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
1042
1043
1044
1045
1046
1047
1048

impl Drop for KvRouter {
    /// Logs the shutdown, then cancels `cancellation_token`. Per the log message, this
    /// stops the router's background tasks.
    fn drop(&mut self) {
        let token = &self.cancellation_token;
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        token.cancel();
    }
}