kv_router.rs 40 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, EventTransportKind},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
21
};
use futures::stream::{self, StreamExt};
22
use rand::Rng;
23
use serde::{Deserialize, Serialize};
24
use serde_json::json;
25

26
27
28
29
30
// Re-export from dynamo-kv-router crate
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::indexer;
pub use dynamo_kv_router::protocols;

31
pub mod prefill_router;
32
pub mod publisher;
33
pub mod recorder;
34
pub mod scheduler;
35
pub mod sequence;
36
pub mod subscriber;
37
pub mod worker_query;
38

39
use indexer::WorkerKvQueryResponse;
40
pub use prefill_router::PrefillRouter;
41
use worker_query::WorkerQueryClient;
42

43
use crate::{
44
    discovery::RuntimeConfigs,
45
    kv_router::{
46
        approx::PruneConfig,
47
        indexer::{KvIndexer, KvIndexerInterface, KvRouterError},
Yan Ru Pei's avatar
Yan Ru Pei committed
48
        protocols::{
49
            DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
50
51
            TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
Yan Ru Pei's avatar
Yan Ru Pei committed
52
        },
53
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
54
        sequence::SequenceError,
55
        subscriber::{start_kv_router_background, start_kv_router_background_event_plane},
56
    },
57
    local_model::runtime_config::ModelRuntimeConfig,
58
    preprocessor::PreprocessedRequest,
59
    protocols::common::llm_backend::LLMEngineOutput,
60
    protocols::common::timing::RequestPhase,
61
62
};

63
64
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject for KV events (push-based publishing).
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Subject for KV hit-rate reports (push-based publishing).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject for KV metrics (push-based publishing).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Worker-local KV indexer buffer size: the 1024 most recent events are kept in the worker buffer.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

85
86
87
88
89
90
/// Builds the name of the worker KV indexer query endpoint for a single dp_rank.
///
/// Every dp_rank owns its own LocalKvIndexer and its own query endpoint, which keeps
/// event ordering monotonic per dp_rank.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    const ENDPOINT_PREFIX: &str = "worker_kv_indexer_query_dp";
    format!("{ENDPOINT_PREFIX}{dp_rank}")
}

91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// Identifiers used for router discovery registration.
/// Component name of the KV router, used by `router_endpoint_id` and `router_discovery_query`.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name of the KV router, used by `router_endpoint_id` and `router_discovery_query`.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the `EndpointId` that identifies the KV router inside `namespace`.
///
/// The component and endpoint names are fixed to `KV_ROUTER_COMPONENT` and
/// `KV_ROUTER_ENDPOINT`.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);

    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the endpoint-level `DiscoveryQuery` that matches the KV router inside `namespace`.
///
/// The component and endpoint names are fixed to `KV_ROUTER_COMPONENT` and
/// `KV_ROUTER_ENDPOINT`.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);

    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

113
114
115
116
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
117
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
118
        request: &SchedulingRequest,
119
        block_size: u32,
120
121
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
122

123
124
125
126
127
128
129
130
131
132
/// Override configuration for router settings that can be specified per-request.
///
/// Both fields are optional and the derived `Default` leaves them unset.
/// NOTE(review): presumably an unset field falls back to the matching `KvRouterConfig`
/// value — the consumer of this struct is not visible here, confirm against the scheduler.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request overlap score weight, if provided.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request router temperature, if provided.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

133
/// KV Router configuration parameters
134
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
135
136
137
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

138
    pub router_temperature: f64,
139

140
141
    pub use_kv_events: bool,

142
143
    pub router_replica_sync: bool,

144
145
146
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

147
148
149
150
151
    /// Whether to track output blocks during generation (default: false)
    /// When enabled, the router adds placeholder blocks as tokens are generated
    /// and applies fractional decay based on progress toward expected_output_tokens.
    pub router_track_output_blocks: bool,

152
153
154
155
156
    /// Whether to assume KV cache reuse when tracking active blocks (default: true).
    /// When true, computes actual block hashes for sequence tracking.
    /// When false, generates random hashes (assuming no KV cache reuse).
    pub router_assume_kv_reuse: bool,

157
158
159
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

160
    /// Whether to reset the router state on startup (default: false)
161
    pub router_reset_states: bool,
162
163
164
165

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

166
    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
167
168
169
170
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
171
172
173
174
175
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
176
            overlap_score_weight: 1.0,
177
            router_temperature: 0.0,
178
            use_kv_events: true,
179
            router_replica_sync: false,
180
            router_track_active_blocks: true,
181
            router_track_output_blocks: false,
182
            router_assume_kv_reuse: true,
183
            router_snapshot_threshold: Some(1000000),
184
            router_reset_states: false,
185
            router_ttl_secs: 120.0,
186
            router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
187
            router_prune_target_ratio: 0.8,
188
189
190
191
192
193
194
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
195
    #[allow(clippy::too_many_arguments)]
196
197
    pub fn new(
        overlap_score_weight: Option<f64>,
198
        temperature: Option<f64>,
199
        use_kv_events: Option<bool>,
200
        replica_sync: Option<bool>,
201
        track_active_blocks: Option<bool>,
202
        track_output_blocks: Option<bool>,
203
        assume_kv_reuse: Option<bool>,
204
205
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
206
207
208
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
209
210
211
212
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
213
            router_temperature: temperature.unwrap_or(default.router_temperature),
214
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
215
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
216
217
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
218
219
            router_track_output_blocks: track_output_blocks
                .unwrap_or(default.router_track_output_blocks),
220
            router_assume_kv_reuse: assume_kv_reuse.unwrap_or(default.router_assume_kv_reuse),
221
222
223
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
224
225
226
227
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
228
229
        }
    }
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260

    /// Compute sequence hashes for active block tracking based on configuration.
    ///
    /// Returns:
    /// - `None` if `router_track_active_blocks` is false
    /// - Random hashes if `router_track_active_blocks` is true but `router_assume_kv_reuse` is false
    /// - Actual sequence hashes if both are true
    pub fn compute_seq_hashes_for_tracking(
        &self,
        tokens: &[u32],
        block_size: u32,
    ) -> Option<Vec<u64>> {
        if !self.router_track_active_blocks {
            return None;
        }

        let num_blocks = tokens.len() / block_size as usize;
        if num_blocks == 0 {
            return Some(Vec::new());
        }

        if self.router_assume_kv_reuse {
            // Compute actual block hashes and sequence hashes
            let block_hashes = compute_block_hash_for_seq(tokens, block_size, None);
            Some(compute_seq_hash_for_block(&block_hashes))
        } else {
            // Generate random hashes (no KV reuse assumed)
            let mut rng = rand::rng();
            Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
        }
    }
261
262
}

263
pub enum Indexer {
264
265
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
266
    /// Has the ability to persist and snapshot states.
267
    KvIndexer(KvIndexer),
268
269
270
271

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
272
273
274
275
276
277
278
279
280
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
281
282
283
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
284
                tree_sizes: HashMap::new(),
285
            }),
286
287
        }
    }
288
289
290
291

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
292
293
294
295
296
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
297
298
        }
    }
299

300
    async fn process_routing_decision_for_request(
301
        &self,
302
        tokens_with_hashes: &mut TokensWithHashes,
303
304
305
306
307
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
308
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
309
310
311
312
313
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
314
315
}

316
317
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
318
pub struct KvRouter {
319
320
321
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
322
    scheduler: KvScheduler,
323

324
    block_size: u32,
325
326

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
327
328

    cancellation_token: tokio_util::sync::CancellationToken,
329
330

    client: Client,
331
332

    worker_query_client: Option<WorkerQueryClient>,
333
334
335
}

impl KvRouter {
336
    #[allow(clippy::too_many_arguments)]
337
    pub async fn new(
338
339
        endpoint: Endpoint,
        client: Client,
340
        workers_with_configs: Arc<RuntimeConfigs>,
341
        block_size: u32,
342
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
343
        kv_router_config: Option<KvRouterConfig>,
344
        router_id: u64,
345
        worker_type: &'static str,
346
    ) -> Result<Self> {
347
        let kv_router_config = kv_router_config.unwrap_or_default();
348
        let component = endpoint.component();
349
        let cancellation_token = component.drt().primary_token();
350

351
352
353
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
354
        } else {
355
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
356
357
358
359
360
361
362
363
364
365
366
367
368

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
369
                cancellation_token.clone(),
370
                None, // expiration_duration for frequency tracking
371
372
                block_size,
                kv_indexer_metrics,
373
                prune_config,
374
375
            ))
        };
376

377
378
379
        // Wait for at least one worker with a known runtime config before starting scheduler
        workers_with_configs.subscribe().wait_for_some().await;

380
        let scheduler = KvScheduler::start(
381
            component.clone(),
382
            block_size,
383
            workers_with_configs.clone(),
384
            selector,
385
            kv_router_config.router_replica_sync,
386
            router_id,
387
            worker_type,
388
389
        )
        .await?;
390

391
        // Initialize worker query client using namespace abstraction
392
        // (for query/recovery API methods - no lifecycle tracking needed)
393
394
395
396
        // Uses a subscriber from workers_with_configs
        let worker_query_client = worker_query::WorkerQueryClient::new(
            component.clone(),
            workers_with_configs.subscribe(),
397
            None, // No removal channel - query only
398
        );
399
400
        tracing::info!("Worker query client initialized");

401
402
403
404
        // Start KV event subscriber background process (only when use_kv_events is enabled)
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
405
406
407
408
409
410
            let all_local_indexer = workers_with_configs
                .configs
                .iter()
                .filter_map(|r| r.value().as_ref().map(|c| c.enable_local_indexer))
                .all(|b| b);

411
412
413
414
            tracing::info!(
                "Found {} worker(s), starting KV event subscriber",
                workers_with_configs.num_workers()
            );
415

416
417
            let transport_kind = EventTransportKind::from_env_or_default();

418
            // Start subscriber - setup runs synchronously, then spawns background loop internally
419
            if all_local_indexer {
420
421
422
423
424
425
426
427
428
429
                if transport_kind == EventTransportKind::Zmq {
                    if kv_router_config.router_snapshot_threshold.is_some()
                        || kv_router_config.router_reset_states
                    {
                        tracing::warn!(
                            "ZMQ event plane does not support KV snapshots or state reset; ignoring snapshot/reset settings"
                        );
                    }
                } else {
                    tracing::info!(
430
431
                        "All {} workers have local_indexer enabled, using NATS Core subscription",
                        workers_with_configs.num_workers()
432
433
434
435
                    );
                }

                start_kv_router_background_event_plane(
436
437
438
439
440
                    component.clone(),
                    kv_indexer.event_sender(),
                    cancellation_token.clone(),
                    worker_query::WorkerQueryClient::new(
                        component.clone(),
441
                        workers_with_configs.subscribe(),
442
                        Some(kv_indexer.remove_worker_sender()),
443
                    ),
444
                    transport_kind,
445
446
                )
                .await?;
447
            } else {
448
449
450
451
452
                if transport_kind == EventTransportKind::Zmq {
                    tracing::warn!(
                        "Not all workers have local_indexer enabled; falling back to JetStream for durability"
                    );
                }
453
454
455
                tracing::info!(
                    "Not all workers have local_indexer enabled, using JetStream subscription"
                );
456

457
458
                // Convert router_id to string for NATS consumer naming
                let consumer_id = router_id.to_string();
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
                start_kv_router_background(
                    component.clone(),
                    consumer_id,
                    kv_indexer.event_sender(),
                    kv_indexer.remove_worker_sender(),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.get_workers_sender()),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.snapshot_event_sender()),
                    cancellation_token.clone(),
                    kv_router_config.router_snapshot_threshold,
                    kv_router_config.router_reset_states,
                )
                .await?;
475
            }
476
        }
477

478
        tracing::info!("KV Routing initialized");
479
        Ok(Self {
480
            indexer,
481
            scheduler,
482
            block_size,
483
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
484
            cancellation_token,
485
            client,
486
            worker_query_client: Some(worker_query_client),
487
        })
488
489
    }

490
491
492
493
494
    /// Returns a reference to the `Client` this KvRouter was constructed with.
    pub fn client(&self) -> &Client {
        &self.client
    }

495
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
496
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
497
    /// Now also takes optional context_id for request tracking
498
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
499
    pub async fn find_best_match(
500
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
501
        context_id: Option<&str>,
502
        tokens: &[u32],
503
        router_config_override: Option<&RouterConfigOverride>,
504
        update_states: bool,
505
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
506
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
507
508
509
510
511
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

512
        let isl_tokens = tokens.len();
513

514
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
515
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
516

517
518
519
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = self
            .kv_router_config
520
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
521

Yan Ru Pei's avatar
Yan Ru Pei committed
522
        let best_worker = self
523
            .scheduler
524
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
525
                context_id.map(|s| s.to_string()),
526
                isl_tokens,
527
                maybe_seq_hashes,
528
                overlap_scores.clone(),
529
                router_config_override,
530
                update_states,
531
                lora_name,
532
            )
533
            .await?;
534

535
536
        // Note: Routing decision recording (for approximate mode) is now handled
        // by KvPushRouter::generate after select_worker returns.
537

538
539
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
540
            .get(&best_worker)
541
542
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
543
        Ok((best_worker, overlap_amount))
544
545
    }

546
    #[allow(clippy::too_many_arguments)]
547
548
549
550
551
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
552
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
553
        worker: WorkerWithDpRank,
554
        lora_name: Option<String>,
555
556
    ) {
        let isl_tokens = tokens.len();
557

558
559
560
        let maybe_seq_hashes = self
            .kv_router_config
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
561

562
563
        if let Err(e) = self
            .scheduler
564
            .add_request(
565
                request_id.clone(),
566
                maybe_seq_hashes,
567
568
                isl_tokens,
                overlap_blocks,
569
                expected_output_tokens,
Yan Ru Pei's avatar
Yan Ru Pei committed
570
                worker,
571
                lora_name,
572
            )
573
574
575
576
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
577
578
    }

579
    /// Tells the scheduler that prefill has completed for `request_id`.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.mark_prefill_completed(request_id).await
    }

583
    /// Asks the scheduler to free the state it holds for `request_id`.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.free(request_id).await
    }
586

587
588
589
590
591
592
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    ///
    /// The value is read back from the scheduler.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

593
594
595
596
597
598
599
600
601
602
    /// Forwards an output-block update for `request_id` to the scheduler.
    ///
    /// `decay_fraction` is passed through unchanged.
    /// NOTE(review): presumably this is the fractional decay described on
    /// `KvRouterConfig::router_track_output_blocks` — confirm against the scheduler.
    pub async fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.scheduler
            .add_output_block(request_id, decay_fraction)
            .await
    }

603
    pub fn block_size(&self) -> u32 {
604
605
        self.block_size
    }
606

607
608
609
610
611
612
613
614
615
616
617
618
    /// Computes how many blocks of `tokens` the indexer reports as overlapping for `worker`.
    ///
    /// Returns 0 when the indexer has no score for that worker.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
        worker: WorkerWithDpRank,
    ) -> Result<u32, KvRouterError> {
        let hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
        let matches = self.indexer.find_matches(hashes).await?;

        let cached_blocks = matches.scores.get(&worker).map_or(0, |&blocks| blocks);
        Ok(cached_blocks)
    }

619
620
621
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
622
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
623
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
624

625
626
        let maybe_seq_hashes = self
            .kv_router_config
627
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
628

629
630
        Ok(self
            .scheduler
631
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
632
633
634
            .await)
    }

635
636
637
638
    /// Dump all events from the indexer
    ///
    /// # Panics
    ///
    /// Panics when the router has no indexer (`Indexer::None`, i.e. it was created with
    /// `overlap_score_weight` set to 0), because `Indexer::dump_events` panics in that case.
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
639
640
641
642
643
644

    /// Query a specific worker's local KV indexer for its events
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
645
        dp_rank: DpRank,
646
647
648
649
650
651
652
653
654
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;

        query_client
655
            .query_worker(worker_id, dp_rank, start_event_id, end_event_id)
656
657
658
            .await
    }

659
    /// Recover missed KV events from a specific worker's dp_rank.
660
661
662
663
664
665
666
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
667
    /// * `dp_rank` - The data parallel rank to recover from
668
669
670
671
672
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
673
        dp_rank: DpRank,
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

689
690
691
        query_client
            .recover_from_worker(worker_id, dp_rank, start_event_id, end_event_id, &event_tx)
            .await
692
    }
693
694
}

Michael Feil's avatar
Michael Feil committed
695
696
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
697
698
699
700
701
702
703
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
704
705
706
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
707
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
708
                let (best_worker, overlap_blocks) = self
709
                    .find_best_match(Some(&context_id), &tokens, None, true, None)
Michael Feil's avatar
Michael Feil committed
710
711
712
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
713
714
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
715
716
717
                    overlap_blocks,
                }
            }
718
719
720
721
722
723
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
724
        };
725
726
727
728
729
730

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
731
732

pub struct KvPushRouter {
733
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
734
    pub chooser: Arc<KvRouter>,
735
736
}

737
738
739
740
741
742
743
/// Result of worker selection containing instance ID, dp_rank, and overlap amount.
struct WorkerSelection {
    /// ID of the selected worker instance.
    instance_id: u64,
    /// Data parallel rank on the selected worker.
    dp_rank: u32,
    /// Overlap amount for the selected worker.
    /// NOTE(review): presumably counted in blocks, as returned by `find_best_match` — confirm.
    overlap_amount: u32,
}

744
745
impl KvPushRouter {
    pub fn new(
746
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
747
748
749
750
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765

    /// Select a worker for the request, either using a preselected worker or finding the best match.
    ///
    /// When `is_query_only` is false and `handle_local_updates` is true, this also registers
    /// the request with the scheduler via `add_request`.
    async fn select_worker(
        &self,
        context_id: &str,
        request: &PreprocessedRequest,
        phase: RequestPhase,
        is_query_only: bool,
        handle_local_updates: bool,
    ) -> Result<WorkerSelection, Error> {
        let routing = request.routing.as_ref();

766
767
768
        // Extract LORA name from routing hints
        let lora_name = routing.and_then(|r| r.lora_name.clone());

769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
        // Get pre-selected worker based on phase, with backend_instance_id as fallback
        let Some(id) = (match phase {
            RequestPhase::Prefill => {
                routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Decode => {
                routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
        }) else {
            // No preselected worker - find the best match
            // Don't update states if this is a query-only request
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !is_query_only,
788
                    lora_name,
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
                )
                .await?;

            return Ok(WorkerSelection {
                instance_id: best_worker.worker_id,
                dp_rank: best_worker.dp_rank,
                overlap_amount,
            });
        };

        // Route to pre-selected or explicitly specified worker
        let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
        tracing::debug!(
            worker_id = id,
            dp_rank = dp_rank,
            ?phase,
            "Routing to specified worker"
        );

        // Compute actual overlap blocks by querying the indexer
        let worker = WorkerWithDpRank::new(id, dp_rank);
        let overlap_blocks = self
            .chooser
            .get_overlap_blocks(&request.token_ids, worker)
            .await?;

815
816
817
818
819
820
        // Extract expected_output_tokens from routing hints
        let expected_output_tokens = request
            .routing
            .as_ref()
            .and_then(|r| r.expected_output_tokens);

821
822
823
824
825
826
827
        // Perform add_request if this router handles local updates
        if !is_query_only && handle_local_updates {
            self.chooser
                .add_request(
                    context_id.to_string(),
                    &request.token_ids,
                    overlap_blocks,
828
                    expected_output_tokens,
829
                    worker,
830
                    lora_name,
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
                )
                .await;
        } else {
            tracing::debug!(
                request_id = %context_id,
                worker_id = id,
                dp_rank = dp_rank,
                "Skipping add_request - query or handled externally"
            );
        }

        Ok(WorkerSelection {
            instance_id: id,
            dp_rank,
            overlap_amount: overlap_blocks,
        })
    }
848
849
850
}

#[async_trait]
851
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
852
853
    for KvPushRouter
{
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
873
874
    async fn generate(
        &self,
875
        request: SingleIn<PreprocessedRequest>,
876
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
877
878
879
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

880
881
882
        // Simple query-only detection: presence of query_instance_id annotation means query-only mode
        let is_query_only = request.get_annotation_value("query_instance_id").is_some();

883
        // Determine if this router should handle local state updates (add_request, free, etc.)
884
885
886
887
888
889
        // Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
        // an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
        let handle_local_updates = request
            .routing
            .as_ref()
            .and_then(|r| r.enable_local_updates)
890
891
            .unwrap_or(true);

892
893
894
        // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
        let phase = request
            .tracker
895
            .as_ref()
896
897
898
899
            .map(|t| t.phase())
            .unwrap_or(RequestPhase::Aggregated);

        let block_size = self.chooser.block_size() as usize;
900
901
902
903
904
905
906
907
908
909
910
911
912
913
        let selection = self
            .select_worker(
                &context_id,
                &request,
                phase,
                is_query_only,
                handle_local_updates,
            )
            .await?;
        let WorkerSelection {
            instance_id,
            dp_rank,
            overlap_amount,
        } = selection;
914

915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
        // In approximate mode (use_kv_events=false), record the routing decision
        // so the indexer can track cache state based on routing decisions.
        // This covers both pre-selected workers and find_best_match selections.
        if !is_query_only && !self.chooser.kv_router_config.use_kv_events {
            let worker = WorkerWithDpRank::new(instance_id, dp_rank);
            let mut tokens_with_hashes =
                TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size);
            if let Err(e) = self
                .chooser
                .indexer
                .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
                .await
            {
                tracing::warn!(
                    request_id = %context_id,
                    worker_id = instance_id,
                    dp_rank = dp_rank,
                    error = %e,
                    "Failed to record routing decision in approximate mode"
                );
            }
        }

938
939
940
        // Record metrics in tracker: KV hit rate, worker ID, and worker type based on phase.
        // Worker type is stored at routing time to avoid expensive MDC lookups when
        // updating Prometheus metrics (TTFT/ITL) later in the response stream.
941
942
943
        if let Some(ref tracker) = request.tracker {
            let isl_blocks = request.token_ids.len().div_ceil(block_size);
            tracker.record_kv_hit(overlap_amount, isl_blocks);
944
            tracker.record_worker_full(instance_id, dp_rank, self.chooser.worker_type());
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
        }

        // Handle query-only requests: early return with worker info
        if is_query_only {
            let stream_context = request.context().clone();
            // Tracker is always created for query-only requests (delta generator enables tracking
            // when query_instance_id annotation is present)
            let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());

            tracing::trace!(
                ?phase,
                worker_id = instance_id,
                ?worker_id_info,
                "Returning worker selection (query-only mode)"
            );

961
962
963
964
965
966
967
968
969
            let output = LLMEngineOutput {
                disaggregated_params: Some(json!({
                    "worker_id": worker_id_info,
                    "token_ids": request.token_ids
                })),
                ..Default::default()
            };
            let response = Annotated::from_data(output);
            let stream = stream::iter(vec![response]);
970
971
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
972
973

        // Route to worker
974
975
976
977
978
979
980
981
        let isl_tokens = request.token_ids.len();
        let expected_output_tokens = request
            .routing
            .as_ref()
            .and_then(|r| r.expected_output_tokens);
        let track_output_blocks =
            self.chooser.kv_router_config.router_track_output_blocks && handle_local_updates;

982
        let (mut backend_input, context) = request.into_parts();
983
        backend_input.routing_mut().dp_rank = Some(dp_rank);
984
985
        let updated_request = context.map(|_| backend_input);

986
        let chooser = self.chooser.clone();
987
988
989
990
        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let context_for_monitoring = stream_context.clone();

991
992
993
        // Wrap stream with lifecycle management (mark_prefill_completed, free)
        // Only perform these operations if handle_local_updates is true.
        // When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
994
995
996
        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;

997
998
999
1000
            // Output block tracking state
            let mut cumulative_osl: usize = 0;
            let mut current_total_blocks = isl_tokens.div_ceil(block_size);

1001
1002
1003
1004
1005
1006
1007
            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
1008
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
1009

1010
                    item = response_stream.next() => {
1011
                        let Some(item) = item else {
1012
1013
                            break;
                        };
1014

1015
                        if handle_local_updates && !prefill_marked {
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
                            // Only mark prefill completed when we receive actual tokens,
                            // not empty bootstrap info (token_ids: []) from disaggregated prefill
                            let has_tokens = item.data.as_ref()
                                .map(|d| !d.token_ids.is_empty())
                                .unwrap_or(false);
                            if has_tokens {
                                if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
                                    tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
                                }
                                prefill_marked = true;
1026
                            }
1027
                        }
1028

1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
                        // Track output blocks if enabled
                        if track_output_blocks {
                            let new_tokens = item.data.as_ref()
                                .map(|d| d.token_ids.len())
                                .unwrap_or(0);
                            cumulative_osl += new_tokens;

                            let new_total_blocks = (isl_tokens + cumulative_osl).div_ceil(block_size);
                            if new_total_blocks > current_total_blocks {
                                // New block boundary crossed - add output block with decay
                                // Clamp eot to min 1 to avoid division by zero, and result to min 0.0
                                let decay_fraction = expected_output_tokens.map(|eot| {
                                    (1.0 - (cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
                                });
                                if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
                                    tracing::warn!(
                                        "Failed to add output block for request {context_id}: {e}"
                                    );
                                }
                                current_total_blocks = new_total_blocks;
                            }
                        }

1052
                        yield item;
1053
                    }
1054
1055
                }
            }
1056

1057
1058
1059
1060
1061
            // Only call free() if we handle local updates.
            // When handle_local_updates=false, external caller handles cleanup via C FFI.
            if handle_local_updates
                && let Err(e) = chooser.free(&context_id).await
            {
1062
                tracing::warn!("Failed to free request {context_id}: {e}");
1063
            }
1064
1065
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
1066
1067
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
1068
1069
1070
1071
1072
1073
1074

impl Drop for KvRouter {
    /// Cancels the router's cancellation token when the router goes away.
    ///
    /// An info-level message is logged first; per that message, the cancellation is
    /// meant to stop the router's background tasks.
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");

        let token = &self.cancellation_token;
        token.cancel();
    }
}