kv_router.rs 38.7 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7
8
#[cfg(feature = "bench")]
use std::time::Instant;
9

10
use anyhow::Result;
11
use derive_builder::Builder;
12
use dynamo_runtime::{
13
    component::{Client, Endpoint},
14
    discovery::{DiscoveryQuery, EventTransportKind},
15
    pipeline::{
16
17
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
18
    },
19
    protocols::EndpointId,
20
    protocols::annotated::Annotated,
21
    traits::DistributedRuntimeProvider,
22
23
};
use futures::stream::{self, StreamExt};
24
use rand::Rng;
25
use serde::{Deserialize, Serialize};
26
use serde_json::json;
27
use validator::Validate;
28

29
30
31
32
33
// Re-export from dynamo-kv-router crate
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::indexer;
pub use dynamo_kv_router::protocols;

34
pub mod prefill_router;
35
pub mod publisher;
36
pub mod recorder;
37
pub mod scheduler;
38
pub mod sequence;
39
pub mod subscriber;
40
pub mod worker_query;
41

42
use indexer::WorkerKvQueryResponse;
43
pub use prefill_router::PrefillRouter;
44
use worker_query::WorkerQueryClient;
45

46
use crate::{
47
    discovery::RuntimeConfigs,
48
    kv_router::{
49
        approx::PruneConfig,
50
        indexer::{KvIndexer, KvIndexerInterface, KvRouterError},
Yan Ru Pei's avatar
Yan Ru Pei committed
51
        protocols::{
52
            DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
53
54
            TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
Yan Ru Pei's avatar
Yan Ru Pei committed
55
        },
56
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
57
        sequence::SequenceError,
58
        subscriber::{start_kv_router_background, start_kv_router_background_event_plane},
59
    },
60
    local_model::runtime_config::ModelRuntimeConfig,
61
    preprocessor::PreprocessedRequest,
62
    protocols::common::llm_backend::LLMEngineOutput,
63
    protocols::common::timing::RequestPhase,
64
65
};

66
67
// [gluo TODO] These names should not need to be public;
// they should be discovered from the component instead.

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject for KV events (push-based metric publishing).
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Subject for KV hit-rate reports (push-based metric publishing).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject for KV metrics (push-based metric publishing).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name used for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Number of most recent events kept in the worker buffer for the
/// worker-local KV indexer query.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

88
89
90
91
92
93
/// Builds the endpoint name of the worker KV indexer query service for one dp_rank.
///
/// Each dp_rank has its own LocalKvIndexer and query endpoint to ensure
/// per-dp_rank monotonicity. The name is the fixed prefix
/// `worker_kv_indexer_query_dp` followed by the rank.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    let mut endpoint_name = String::from("worker_kv_indexer_query_dp");
    endpoint_name.push_str(&dp_rank.to_string());
    endpoint_name
}

94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/// Component name used for router discovery registration.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name used for router discovery registration.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] of the KV router inside `namespace`.
///
/// The component is always [`KV_ROUTER_COMPONENT`] and the endpoint name is
/// always [`KV_ROUTER_ENDPOINT`]; only the namespace varies.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the endpoint-level [`DiscoveryQuery`] for the KV router inside `namespace`.
///
/// The component is always [`KV_ROUTER_COMPONENT`] and the endpoint is always
/// [`KV_ROUTER_ENDPOINT`]; only the namespace varies.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

116
117
118
119
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
120
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
121
        request: &SchedulingRequest,
122
        block_size: u32,
123
124
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
125

126
/// Override configuration for router settings that can be specified per-request
127
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize, Validate)]
128
129
130
131
132
pub struct RouterConfigOverride {
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    #[builder(default)]
133
    #[validate(range(min = 0.0))]
134
135
136
    pub router_temperature: Option<f64>,
}

137
/// KV Router configuration parameters
138
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Validate)]
139
pub struct KvRouterConfig {
140
    #[validate(range(min = 0.0))]
141
142
    pub overlap_score_weight: f64,

143
    #[validate(range(min = 0.0))]
144
    pub router_temperature: f64,
145

146
147
    pub use_kv_events: bool,

148
149
    pub router_replica_sync: bool,

150
151
152
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

153
154
155
156
157
    /// Whether to track output blocks during generation (default: false)
    /// When enabled, the router adds placeholder blocks as tokens are generated
    /// and applies fractional decay based on progress toward expected_output_tokens.
    pub router_track_output_blocks: bool,

158
159
160
161
162
    /// Whether to assume KV cache reuse when tracking active blocks (default: true).
    /// When true, computes actual block hashes for sequence tracking.
    /// When false, generates random hashes (assuming no KV cache reuse).
    pub router_assume_kv_reuse: bool,

163
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
164
    #[validate(range(min = 1))]
165
166
    pub router_snapshot_threshold: Option<u32>,

167
    /// Whether to reset the router state on startup (default: false)
168
    pub router_reset_states: bool,
169
170

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
171
    #[validate(range(min = 0.0))]
172
173
    pub router_ttl_secs: f64,

174
    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
175
    #[validate(range(min = 1))]
176
177
178
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
179
    #[validate(range(min = 0.0, max = 1.0))]
180
    pub router_prune_target_ratio: f64,
181
182
183
184
185
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
186
            overlap_score_weight: 1.0,
187
            router_temperature: 0.0,
188
            use_kv_events: true,
189
            router_replica_sync: false,
190
            router_track_active_blocks: true,
191
            router_track_output_blocks: false,
192
            router_assume_kv_reuse: true,
193
            router_snapshot_threshold: Some(1000000),
194
            router_reset_states: false,
195
            router_ttl_secs: 120.0,
196
            router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
197
            router_prune_target_ratio: 0.8,
198
199
200
201
202
        }
    }
}

impl KvRouterConfig {
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
    /// Compute sequence hashes for active block tracking based on configuration.
    ///
    /// Returns:
    /// - `None` if `router_track_active_blocks` is false
    /// - Random hashes if `router_track_active_blocks` is true but `router_assume_kv_reuse` is false
    /// - Actual sequence hashes if both are true
    pub fn compute_seq_hashes_for_tracking(
        &self,
        tokens: &[u32],
        block_size: u32,
    ) -> Option<Vec<u64>> {
        if !self.router_track_active_blocks {
            return None;
        }

        let num_blocks = tokens.len() / block_size as usize;
        if num_blocks == 0 {
            return Some(Vec::new());
        }

        if self.router_assume_kv_reuse {
            // Compute actual block hashes and sequence hashes
            let block_hashes = compute_block_hash_for_seq(tokens, block_size, None);
            Some(compute_seq_hash_for_block(&block_hashes))
        } else {
            // Generate random hashes (no KV reuse assumed)
            let mut rng = rand::rng();
            Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
        }
    }
233
234
}

235
pub enum Indexer {
236
237
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
238
    /// Has the ability to persist and snapshot states.
239
    KvIndexer(KvIndexer),
240
241
242
243

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
244
245
246
247
248
249
250
251
252
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
253
254
255
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
256
                tree_sizes: HashMap::new(),
257
            }),
258
259
        }
    }
260
261
262
263

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
264
265
266
267
268
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
269
270
        }
    }
271

272
    async fn process_routing_decision_for_request(
273
        &self,
274
        tokens_with_hashes: &mut TokensWithHashes,
275
276
277
278
279
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
280
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
281
282
283
284
285
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
286
287
}

288
289
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
290
pub struct KvRouter {
291
292
293
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
294
    scheduler: KvScheduler,
295

296
    block_size: u32,
297
298

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
299
300

    cancellation_token: tokio_util::sync::CancellationToken,
301
302

    client: Client,
303
304

    worker_query_client: Option<WorkerQueryClient>,
305
306
307
}

impl KvRouter {
308
    #[allow(clippy::too_many_arguments)]
309
    pub async fn new(
310
311
        endpoint: Endpoint,
        client: Client,
312
        workers_with_configs: Arc<RuntimeConfigs>,
313
        block_size: u32,
314
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
315
        kv_router_config: Option<KvRouterConfig>,
316
        router_id: u64,
317
        worker_type: &'static str,
318
    ) -> Result<Self> {
319
        let kv_router_config = kv_router_config.unwrap_or_default();
320
        kv_router_config.validate()?;
321
        let component = endpoint.component();
322
        let cancellation_token = component.drt().primary_token();
323

324
325
326
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
327
        } else {
328
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
329
330
331
332
333
334
335
336
337
338
339
340
341

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
342
                cancellation_token.clone(),
343
                None, // expiration_duration for frequency tracking
344
345
                block_size,
                kv_indexer_metrics,
346
                prune_config,
347
348
            ))
        };
349

350
351
352
        // Wait for at least one worker with a known runtime config before starting scheduler
        workers_with_configs.subscribe().wait_for_some().await;

353
        let scheduler = KvScheduler::start(
354
            component.clone(),
355
            block_size,
356
            workers_with_configs.clone(),
357
            selector,
358
            kv_router_config.router_replica_sync,
359
            router_id,
360
            worker_type,
361
362
        )
        .await?;
363

364
        // Initialize worker query client using namespace abstraction
365
        // (for query/recovery API methods - no lifecycle tracking needed)
366
367
368
369
        // Uses a subscriber from workers_with_configs
        let worker_query_client = worker_query::WorkerQueryClient::new(
            component.clone(),
            workers_with_configs.subscribe(),
370
            None, // No removal channel - query only
371
        );
372
373
        tracing::info!("Worker query client initialized");

374
375
376
377
        // Start KV event subscriber background process (only when use_kv_events is enabled)
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
378
379
380
381
382
383
            let all_local_indexer = workers_with_configs
                .configs
                .iter()
                .filter_map(|r| r.value().as_ref().map(|c| c.enable_local_indexer))
                .all(|b| b);

384
385
386
387
            tracing::info!(
                "Found {} worker(s), starting KV event subscriber",
                workers_with_configs.num_workers()
            );
388

389
390
            let transport_kind = EventTransportKind::from_env_or_default();

391
            // Start subscriber - setup runs synchronously, then spawns background loop internally
392
            if all_local_indexer {
393
394
395
396
397
398
399
400
401
402
                if transport_kind == EventTransportKind::Zmq {
                    if kv_router_config.router_snapshot_threshold.is_some()
                        || kv_router_config.router_reset_states
                    {
                        tracing::warn!(
                            "ZMQ event plane does not support KV snapshots or state reset; ignoring snapshot/reset settings"
                        );
                    }
                } else {
                    tracing::info!(
403
404
                        "All {} workers have local_indexer enabled, using NATS Core subscription",
                        workers_with_configs.num_workers()
405
406
407
408
                    );
                }

                start_kv_router_background_event_plane(
409
410
411
412
413
                    component.clone(),
                    kv_indexer.event_sender(),
                    cancellation_token.clone(),
                    worker_query::WorkerQueryClient::new(
                        component.clone(),
414
                        workers_with_configs.subscribe(),
415
                        Some(kv_indexer.remove_worker_sender()),
416
                    ),
417
                    transport_kind,
418
419
                )
                .await?;
420
            } else {
421
422
423
424
425
                if transport_kind == EventTransportKind::Zmq {
                    tracing::warn!(
                        "Not all workers have local_indexer enabled; falling back to JetStream for durability"
                    );
                }
426
427
428
                tracing::info!(
                    "Not all workers have local_indexer enabled, using JetStream subscription"
                );
429

430
431
                // Convert router_id to string for NATS consumer naming
                let consumer_id = router_id.to_string();
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
                start_kv_router_background(
                    component.clone(),
                    consumer_id,
                    kv_indexer.event_sender(),
                    kv_indexer.remove_worker_sender(),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.get_workers_sender()),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.snapshot_event_sender()),
                    cancellation_token.clone(),
                    kv_router_config.router_snapshot_threshold,
                    kv_router_config.router_reset_states,
                )
                .await?;
448
            }
449
        }
450

451
        tracing::info!("KV Routing initialized");
452
        Ok(Self {
453
            indexer,
454
            scheduler,
455
            block_size,
456
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
457
            cancellation_token,
458
            client,
459
            worker_query_client: Some(worker_query_client),
460
        })
461
462
    }

463
464
465
466
467
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

468
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
469
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
470
    /// Now also takes optional context_id for request tracking
471
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
472
    pub async fn find_best_match(
473
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
474
        context_id: Option<&str>,
475
        tokens: &[u32],
476
        router_config_override: Option<&RouterConfigOverride>,
477
        update_states: bool,
478
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
479
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
480
481
482
        #[cfg(feature = "bench")]
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
483
        if update_states && context_id.is_none() {
484
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
485
486
        }

487
        let isl_tokens = tokens.len();
488

489
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
490
491
        #[cfg(feature = "bench")]
        let hash_elapsed = start.elapsed();
492
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
493
494
        #[cfg(feature = "bench")]
        let find_matches_elapsed = start.elapsed();
495

496
497
498
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = self
            .kv_router_config
499
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
500

Yan Ru Pei's avatar
Yan Ru Pei committed
501
        let best_worker = self
502
            .scheduler
503
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
504
                context_id.map(|s| s.to_string()),
505
                isl_tokens,
506
                maybe_seq_hashes,
507
                overlap_scores.clone(),
508
                router_config_override,
509
                update_states,
510
                lora_name,
511
            )
512
            .await?;
513

514
515
516
517
518
519
520
521
522
523
524
525
526
        #[cfg(feature = "bench")]
        {
            let total_elapsed = start.elapsed();
            tracing::info!(
                isl_tokens,
                hash_us = hash_elapsed.as_micros() as u64,
                find_matches_us = (find_matches_elapsed - hash_elapsed).as_micros() as u64,
                schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
                total_us = total_elapsed.as_micros() as u64,
                "find_best_match completed"
            );
        }

527
528
        // Note: Routing decision recording (for approximate mode) is now handled
        // by KvPushRouter::generate after select_worker returns.
529

530
531
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
532
            .get(&best_worker)
533
534
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
535
        Ok((best_worker, overlap_amount))
536
537
    }

538
    #[allow(clippy::too_many_arguments)]
539
540
541
542
543
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
544
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
545
        worker: WorkerWithDpRank,
546
        lora_name: Option<String>,
547
548
    ) {
        let isl_tokens = tokens.len();
549

550
551
552
        let maybe_seq_hashes = self
            .kv_router_config
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
553

554
555
        if let Err(e) = self
            .scheduler
556
            .add_request(
557
                request_id.clone(),
558
                maybe_seq_hashes,
559
560
                isl_tokens,
                overlap_blocks,
561
                expected_output_tokens,
Yan Ru Pei's avatar
Yan Ru Pei committed
562
                worker,
563
                lora_name,
564
            )
565
566
567
568
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
569
570
    }

571
    /// Tells the scheduler that prefill has completed for `request_id`.
    ///
    /// Returns whatever the scheduler's `mark_prefill_completed` returns.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.mark_prefill_completed(request_id).await
    }

575
    /// Asks the scheduler to free `request_id`.
    ///
    /// Returns whatever the scheduler's `free` returns.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.free(request_id).await
    }
578

579
580
581
582
583
584
    /// Worker type of this router ("prefill" or "decode"), as reported by the scheduler.
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        let scheduler = &self.scheduler;
        scheduler.worker_type()
    }

585
586
587
588
589
590
591
592
593
594
    /// Forwards an output-block update for `request_id`, with an optional decay
    /// fraction, to the scheduler.
    ///
    /// Returns whatever the scheduler's `add_output_block` returns.
    pub async fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.add_output_block(request_id, decay_fraction).await
    }

595
    pub fn block_size(&self) -> u32 {
596
597
        self.block_size
    }
598

599
600
601
602
603
604
605
606
607
608
609
610
    /// Computes the overlap for a token sequence on one specific worker.
    ///
    /// Hashes `tokens` into blocks, queries the indexer, and returns the score the
    /// indexer reports for `worker`, or 0 when it reports none.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
        worker: WorkerWithDpRank,
    ) -> Result<u32, KvRouterError> {
        let hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
        let matches = self.indexer.find_matches(hashes).await?;
        let cached_blocks = matches.scores.get(&worker).map_or(0, |blocks| *blocks);
        Ok(cached_blocks)
    }

611
612
613
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
614
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
615
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
616

617
618
        let maybe_seq_hashes = self
            .kv_router_config
619
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
620

621
622
        Ok(self
            .scheduler
623
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
624
625
626
            .await)
    }

627
628
629
630
    /// Dumps all events held by the router's indexer.
    ///
    /// # Panics
    ///
    /// Panics when the router has no indexer (`Indexer::None`).
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        let indexer = &self.indexer;
        indexer.dump_events().await
    }
631
632
633
634
635
636

    /// Query a specific worker's local KV indexer for its events
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
637
        dp_rank: DpRank,
638
639
640
641
642
643
644
645
646
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;

        query_client
647
            .query_worker(worker_id, dp_rank, start_event_id, end_event_id)
648
649
650
            .await
    }

651
    /// Recover missed KV events from a specific worker's dp_rank.
652
653
654
655
656
657
658
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
659
    /// * `dp_rank` - The data parallel rank to recover from
660
661
662
663
664
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
665
        dp_rank: DpRank,
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

681
682
683
        query_client
            .recover_from_worker(worker_id, dp_rank, start_event_id, end_event_id, &event_tx)
            .await
684
    }
685
686
}

Michael Feil's avatar
Michael Feil committed
687
688
// NOTE: KvRouter works like a PushRouter, but without the reverse-proxy functionality;
// instead it serves a contract of three request types (New, MarkPrefill, MarkFree).
689
690
691
692
693
694
695
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
696
697
698
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
699
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
700
                let (best_worker, overlap_blocks) = self
701
                    .find_best_match(Some(&context_id), &tokens, None, true, None)
Michael Feil's avatar
Michael Feil committed
702
703
704
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
705
706
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
707
708
709
                    overlap_blocks,
                }
            }
710
711
712
713
714
715
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
716
        };
717
718
719
720
721
722

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
723
724

pub struct KvPushRouter {
725
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
726
    pub chooser: Arc<KvRouter>,
727
728
}

729
730
731
732
733
734
735
/// Outcome of selecting a worker for a request.
struct WorkerSelection {
    /// Instance ID of the selected worker.
    instance_id: u64,
    /// dp_rank selected for the request.
    dp_rank: u32,
    /// Overlap amount associated with the selection.
    overlap_amount: u32,
}

736
737
impl KvPushRouter {
    pub fn new(
738
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
739
740
741
742
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
743
744
745
746
747
748
749
750
751
752
753
754
755
756

    /// Select a worker for the request, either using a preselected worker or finding the best match.
    ///
    /// When `is_query_only` is false and `handle_local_updates` is true, this also registers
    /// the request with the scheduler via `add_request`.
    async fn select_worker(
        &self,
        context_id: &str,
        request: &PreprocessedRequest,
        phase: RequestPhase,
        is_query_only: bool,
        handle_local_updates: bool,
    ) -> Result<WorkerSelection, Error> {
        let routing = request.routing.as_ref();
757
        let lora_name = routing.and_then(|r| r.lora_name.clone());
758
759
        let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
        let expected_output_tokens = routing.and_then(|r| r.expected_output_tokens);
760

761
        // Get pre-selected worker based on phase, with backend_instance_id as fallback
762
        let preselected_id = match phase {
763
764
765
766
767
768
769
            RequestPhase::Prefill => {
                routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Decode => {
                routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
770
771
772
        };

        let Some(id) = preselected_id else {
773
774
775
776
777
778
779
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !is_query_only,
780
                    lora_name,
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
                )
                .await?;

            return Ok(WorkerSelection {
                instance_id: best_worker.worker_id,
                dp_rank: best_worker.dp_rank,
                overlap_amount,
            });
        };

        tracing::debug!(
            worker_id = id,
            dp_rank = dp_rank,
            ?phase,
            "Routing to specified worker"
        );

        let worker = WorkerWithDpRank::new(id, dp_rank);
        let overlap_blocks = self
            .chooser
            .get_overlap_blocks(&request.token_ids, worker)
            .await?;

        if !is_query_only && handle_local_updates {
            self.chooser
                .add_request(
                    context_id.to_string(),
                    &request.token_ids,
                    overlap_blocks,
810
                    expected_output_tokens,
811
                    worker,
812
                    lora_name,
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
                )
                .await;
        } else {
            tracing::debug!(
                request_id = %context_id,
                worker_id = id,
                dp_rank = dp_rank,
                "Skipping add_request - query or handled externally"
            );
        }

        Ok(WorkerSelection {
            instance_id: id,
            dp_rank,
            overlap_amount: overlap_blocks,
        })
    }
830
831
832
}

#[async_trait]
833
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
834
835
    for KvPushRouter
{
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
855
856
    async fn generate(
        &self,
857
        request: SingleIn<PreprocessedRequest>,
858
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
859
860
861
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

862
863
864
        // Simple query-only detection: presence of query_instance_id annotation means query-only mode
        let is_query_only = request.get_annotation_value("query_instance_id").is_some();

865
        // Determine if this router should handle local state updates (add_request, free, etc.)
866
867
868
869
870
871
        // Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
        // an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
        let handle_local_updates = request
            .routing
            .as_ref()
            .and_then(|r| r.enable_local_updates)
872
873
            .unwrap_or(true);

874
875
876
        // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
        let phase = request
            .tracker
877
            .as_ref()
878
879
880
881
            .map(|t| t.phase())
            .unwrap_or(RequestPhase::Aggregated);

        let block_size = self.chooser.block_size() as usize;
882
883
884
885
886
887
888
889
890
891
892
893
894
895
        let selection = self
            .select_worker(
                &context_id,
                &request,
                phase,
                is_query_only,
                handle_local_updates,
            )
            .await?;
        let WorkerSelection {
            instance_id,
            dp_rank,
            overlap_amount,
        } = selection;
896

897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
        // In approximate mode (use_kv_events=false), record the routing decision
        // so the indexer can track cache state based on routing decisions.
        // This covers both pre-selected workers and find_best_match selections.
        if !is_query_only && !self.chooser.kv_router_config.use_kv_events {
            let worker = WorkerWithDpRank::new(instance_id, dp_rank);
            let mut tokens_with_hashes =
                TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size);
            if let Err(e) = self
                .chooser
                .indexer
                .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
                .await
            {
                tracing::warn!(
                    request_id = %context_id,
                    worker_id = instance_id,
                    dp_rank = dp_rank,
                    error = %e,
                    "Failed to record routing decision in approximate mode"
                );
            }
        }

920
921
922
        // Record metrics in tracker: KV hit rate, worker ID, and worker type based on phase.
        // Worker type is stored at routing time to avoid expensive MDC lookups when
        // updating Prometheus metrics (TTFT/ITL) later in the response stream.
923
924
925
        if let Some(ref tracker) = request.tracker {
            let isl_blocks = request.token_ids.len().div_ceil(block_size);
            tracker.record_kv_hit(overlap_amount, isl_blocks);
926
            tracker.record_worker_full(instance_id, dp_rank, self.chooser.worker_type());
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
        }

        // Handle query-only requests: early return with worker info
        if is_query_only {
            let stream_context = request.context().clone();
            // Tracker is always created for query-only requests (delta generator enables tracking
            // when query_instance_id annotation is present)
            let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());

            tracing::trace!(
                ?phase,
                worker_id = instance_id,
                ?worker_id_info,
                "Returning worker selection (query-only mode)"
            );

943
944
945
946
947
948
949
950
951
            let output = LLMEngineOutput {
                disaggregated_params: Some(json!({
                    "worker_id": worker_id_info,
                    "token_ids": request.token_ids
                })),
                ..Default::default()
            };
            let response = Annotated::from_data(output);
            let stream = stream::iter(vec![response]);
952
953
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
954
955

        // Route to worker
956
957
958
959
960
961
962
963
        let isl_tokens = request.token_ids.len();
        let expected_output_tokens = request
            .routing
            .as_ref()
            .and_then(|r| r.expected_output_tokens);
        let track_output_blocks =
            self.chooser.kv_router_config.router_track_output_blocks && handle_local_updates;

964
        let (mut backend_input, context) = request.into_parts();
965
        backend_input.routing_mut().dp_rank = Some(dp_rank);
966
967
        let updated_request = context.map(|_| backend_input);

968
        let chooser = self.chooser.clone();
969
970
971
972
        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let context_for_monitoring = stream_context.clone();

973
974
975
        // Wrap stream with lifecycle management (mark_prefill_completed, free)
        // Only perform these operations if handle_local_updates is true.
        // When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
976
977
978
        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;

979
980
981
982
            // Output block tracking state
            let mut cumulative_osl: usize = 0;
            let mut current_total_blocks = isl_tokens.div_ceil(block_size);

983
984
985
986
987
988
989
            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
990
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
991

992
                    item = response_stream.next() => {
993
                        let Some(item) = item else {
994
995
                            break;
                        };
996

997
                        if handle_local_updates && !prefill_marked {
998
999
1000
1001
1002
1003
1004
1005
1006
1007
                            // Only mark prefill completed when we receive actual tokens,
                            // not empty bootstrap info (token_ids: []) from disaggregated prefill
                            let has_tokens = item.data.as_ref()
                                .map(|d| !d.token_ids.is_empty())
                                .unwrap_or(false);
                            if has_tokens {
                                if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
                                    tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
                                }
                                prefill_marked = true;
1008
                            }
1009
                        }
1010

1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
                        // Track output blocks if enabled
                        if track_output_blocks {
                            let new_tokens = item.data.as_ref()
                                .map(|d| d.token_ids.len())
                                .unwrap_or(0);
                            cumulative_osl += new_tokens;

                            let new_total_blocks = (isl_tokens + cumulative_osl).div_ceil(block_size);
                            if new_total_blocks > current_total_blocks {
                                // New block boundary crossed - add output block with decay
                                // Clamp eot to min 1 to avoid division by zero, and result to min 0.0
                                let decay_fraction = expected_output_tokens.map(|eot| {
                                    (1.0 - (cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
                                });
                                if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
                                    tracing::warn!(
                                        "Failed to add output block for request {context_id}: {e}"
                                    );
                                }
                                current_total_blocks = new_total_blocks;
                            }
                        }

1034
                        yield item;
1035
                    }
1036
1037
                }
            }
1038

1039
1040
1041
1042
1043
            // Only call free() if we handle local updates.
            // When handle_local_updates=false, external caller handles cleanup via C FFI.
            if handle_local_updates
                && let Err(e) = chooser.free(&context_id).await
            {
1044
                tracing::warn!("Failed to free request {context_id}: {e}");
1045
            }
1046
1047
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
1048
1049
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
1050
1051
1052
1053
1054
1055
1056

impl Drop for KvRouter {
    /// Logs that the router is going away, then cancels its cancellation token.
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");

        // Borrow just the token; no other field is needed for teardown.
        let Self {
            cancellation_token, ..
        } = self;
        cancellation_token.cancel();
    }
}