kv_router.rs 40.8 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7
8
#[cfg(feature = "bench")]
use std::time::Instant;
9

10
use anyhow::Result;
11
use derive_builder::Builder;
12
use dynamo_runtime::{
13
    component::{Client, Endpoint},
14
    discovery::{DiscoveryQuery, EventTransportKind},
15
    pipeline::{
16
17
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
18
    },
19
    protocols::EndpointId,
20
    protocols::annotated::Annotated,
21
    traits::DistributedRuntimeProvider,
22
23
};
use futures::stream::{self, StreamExt};
24
use rand::Rng;
25
use serde::{Deserialize, Serialize};
26
use serde_json::json;
27

28
29
30
31
32
// Re-export from dynamo-kv-router crate
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::indexer;
pub use dynamo_kv_router::protocols;

33
pub mod prefill_router;
34
pub mod publisher;
35
pub mod recorder;
36
pub mod scheduler;
37
pub mod sequence;
38
pub mod subscriber;
39
pub mod worker_query;
40

41
use indexer::WorkerKvQueryResponse;
42
pub use prefill_router::PrefillRouter;
43
use worker_query::WorkerQueryClient;
44

45
use crate::{
46
    discovery::RuntimeConfigs,
47
    kv_router::{
48
        approx::PruneConfig,
49
        indexer::{KvIndexer, KvIndexerInterface, KvRouterError},
Yan Ru Pei's avatar
Yan Ru Pei committed
50
        protocols::{
51
            DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
52
53
            TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
Yan Ru Pei's avatar
Yan Ru Pei committed
54
        },
55
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
56
        sequence::SequenceError,
57
        subscriber::{start_kv_router_background, start_kv_router_background_event_plane},
58
    },
59
    local_model::runtime_config::ModelRuntimeConfig,
60
    preprocessor::PreprocessedRequest,
61
    protocols::common::llm_backend::LLMEngineOutput,
62
    protocols::common::timing::RequestPhase,
63
64
};

65
66
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Endpoint name used for worker load-metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject on which KV cache events are published (push-based).
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Subject on which KV cache hit-rate metrics are published (push-based).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject on which KV metrics are published (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for prefill events exchanged between router instances.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between router instances.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Storage bucket name used for radix tree snapshots.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// Name of the radix tree snapshot stored in [`RADIX_STATE_BUCKET`].
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Number of most recent KV events kept in a worker's local KV indexer buffer
/// (used by the worker-local kvindexer query).
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

87
88
89
90
91
92
/// Returns the name of the worker KV indexer query endpoint for `dp_rank`.
///
/// Every dp_rank runs its own LocalKvIndexer behind its own query endpoint, which keeps
/// event ordering monotonic per dp_rank.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    ["worker_kv_indexer_query_dp", &dp_rank.to_string()].concat()
}

93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/// Component name under which KV routers register for discovery.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name registered by KV routers under [`KV_ROUTER_COMPONENT`].
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] of the KV router living in `namespace`.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds a [`DiscoveryQuery`] that matches the KV router endpoint in `namespace`.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

115
116
117
118
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
119
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
120
        request: &SchedulingRequest,
121
        block_size: u32,
122
123
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
124

125
126
127
128
129
130
131
132
133
134
/// Router settings that a single request may override.
///
/// The override is handed to the scheduler unchanged; a field left as `None` presumably
/// falls back to the router-wide value in [`KvRouterConfig`] (NOTE(review): confirm in the
/// scheduler).
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request replacement for [`KvRouterConfig::overlap_score_weight`].
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request replacement for [`KvRouterConfig::router_temperature`].
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

135
/// KV Router configuration parameters
136
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
137
138
139
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

140
    pub router_temperature: f64,
141

142
143
    pub use_kv_events: bool,

144
145
    pub router_replica_sync: bool,

146
147
148
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

149
150
151
152
153
    /// Whether to track output blocks during generation (default: false)
    /// When enabled, the router adds placeholder blocks as tokens are generated
    /// and applies fractional decay based on progress toward expected_output_tokens.
    pub router_track_output_blocks: bool,

154
155
156
157
158
    /// Whether to assume KV cache reuse when tracking active blocks (default: true).
    /// When true, computes actual block hashes for sequence tracking.
    /// When false, generates random hashes (assuming no KV cache reuse).
    pub router_assume_kv_reuse: bool,

159
160
161
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

162
    /// Whether to reset the router state on startup (default: false)
163
    pub router_reset_states: bool,
164
165
166
167

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

168
    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
169
170
171
172
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
173
174
175
176
177
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
178
            overlap_score_weight: 1.0,
179
            router_temperature: 0.0,
180
            use_kv_events: true,
181
            router_replica_sync: false,
182
            router_track_active_blocks: true,
183
            router_track_output_blocks: false,
184
            router_assume_kv_reuse: true,
185
            router_snapshot_threshold: Some(1000000),
186
            router_reset_states: false,
187
            router_ttl_secs: 120.0,
188
            router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
189
            router_prune_target_ratio: 0.8,
190
191
192
193
194
195
196
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
197
    #[allow(clippy::too_many_arguments)]
198
199
    pub fn new(
        overlap_score_weight: Option<f64>,
200
        temperature: Option<f64>,
201
        use_kv_events: Option<bool>,
202
        replica_sync: Option<bool>,
203
        track_active_blocks: Option<bool>,
204
        track_output_blocks: Option<bool>,
205
        assume_kv_reuse: Option<bool>,
206
207
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
208
209
210
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
211
212
213
214
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
215
            router_temperature: temperature.unwrap_or(default.router_temperature),
216
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
217
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
218
219
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
220
221
            router_track_output_blocks: track_output_blocks
                .unwrap_or(default.router_track_output_blocks),
222
            router_assume_kv_reuse: assume_kv_reuse.unwrap_or(default.router_assume_kv_reuse),
223
224
225
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
226
227
228
229
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
230
231
        }
    }
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262

    /// Compute sequence hashes for active block tracking based on configuration.
    ///
    /// Returns:
    /// - `None` if `router_track_active_blocks` is false
    /// - Random hashes if `router_track_active_blocks` is true but `router_assume_kv_reuse` is false
    /// - Actual sequence hashes if both are true
    pub fn compute_seq_hashes_for_tracking(
        &self,
        tokens: &[u32],
        block_size: u32,
    ) -> Option<Vec<u64>> {
        if !self.router_track_active_blocks {
            return None;
        }

        let num_blocks = tokens.len() / block_size as usize;
        if num_blocks == 0 {
            return Some(Vec::new());
        }

        if self.router_assume_kv_reuse {
            // Compute actual block hashes and sequence hashes
            let block_hashes = compute_block_hash_for_seq(tokens, block_size, None);
            Some(compute_seq_hash_for_block(&block_hashes))
        } else {
            // Generate random hashes (no KV reuse assumed)
            let mut rng = rand::rng();
            Some((0..num_blocks).map(|_| rng.random::<u64>()).collect())
        }
    }
263
264
}

265
pub enum Indexer {
266
267
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
268
    /// Has the ability to persist and snapshot states.
269
    KvIndexer(KvIndexer),
270
271
272
273

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
274
275
276
277
278
279
280
281
282
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
283
284
285
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
286
                tree_sizes: HashMap::new(),
287
            }),
288
289
        }
    }
290
291
292
293

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
294
295
296
297
298
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
299
300
        }
    }
301

302
    async fn process_routing_decision_for_request(
303
        &self,
304
        tokens_with_hashes: &mut TokensWithHashes,
305
306
307
308
309
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
310
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
311
312
313
314
315
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
316
317
}

318
319
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
320
pub struct KvRouter {
321
322
323
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
324
    scheduler: KvScheduler,
325

326
    block_size: u32,
327
328

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
329
330

    cancellation_token: tokio_util::sync::CancellationToken,
331
332

    client: Client,
333
334

    worker_query_client: Option<WorkerQueryClient>,
335
336
337
}

impl KvRouter {
338
    #[allow(clippy::too_many_arguments)]
339
    pub async fn new(
340
341
        endpoint: Endpoint,
        client: Client,
342
        workers_with_configs: Arc<RuntimeConfigs>,
343
        block_size: u32,
344
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
345
        kv_router_config: Option<KvRouterConfig>,
346
        router_id: u64,
347
        worker_type: &'static str,
348
    ) -> Result<Self> {
349
        let kv_router_config = kv_router_config.unwrap_or_default();
350
        let component = endpoint.component();
351
        let cancellation_token = component.drt().primary_token();
352

353
354
355
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
356
        } else {
357
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
358
359
360
361
362
363
364
365
366
367
368
369
370

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
371
                cancellation_token.clone(),
372
                None, // expiration_duration for frequency tracking
373
374
                block_size,
                kv_indexer_metrics,
375
                prune_config,
376
377
            ))
        };
378

379
380
381
        // Wait for at least one worker with a known runtime config before starting scheduler
        workers_with_configs.subscribe().wait_for_some().await;

382
        let scheduler = KvScheduler::start(
383
            component.clone(),
384
            block_size,
385
            workers_with_configs.clone(),
386
            selector,
387
            kv_router_config.router_replica_sync,
388
            router_id,
389
            worker_type,
390
391
        )
        .await?;
392

393
        // Initialize worker query client using namespace abstraction
394
        // (for query/recovery API methods - no lifecycle tracking needed)
395
396
397
398
        // Uses a subscriber from workers_with_configs
        let worker_query_client = worker_query::WorkerQueryClient::new(
            component.clone(),
            workers_with_configs.subscribe(),
399
            None, // No removal channel - query only
400
        );
401
402
        tracing::info!("Worker query client initialized");

403
404
405
406
        // Start KV event subscriber background process (only when use_kv_events is enabled)
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
407
408
409
410
411
412
            let all_local_indexer = workers_with_configs
                .configs
                .iter()
                .filter_map(|r| r.value().as_ref().map(|c| c.enable_local_indexer))
                .all(|b| b);

413
414
415
416
            tracing::info!(
                "Found {} worker(s), starting KV event subscriber",
                workers_with_configs.num_workers()
            );
417

418
419
            let transport_kind = EventTransportKind::from_env_or_default();

420
            // Start subscriber - setup runs synchronously, then spawns background loop internally
421
            if all_local_indexer {
422
423
424
425
426
427
428
429
430
431
                if transport_kind == EventTransportKind::Zmq {
                    if kv_router_config.router_snapshot_threshold.is_some()
                        || kv_router_config.router_reset_states
                    {
                        tracing::warn!(
                            "ZMQ event plane does not support KV snapshots or state reset; ignoring snapshot/reset settings"
                        );
                    }
                } else {
                    tracing::info!(
432
433
                        "All {} workers have local_indexer enabled, using NATS Core subscription",
                        workers_with_configs.num_workers()
434
435
436
437
                    );
                }

                start_kv_router_background_event_plane(
438
439
440
441
442
                    component.clone(),
                    kv_indexer.event_sender(),
                    cancellation_token.clone(),
                    worker_query::WorkerQueryClient::new(
                        component.clone(),
443
                        workers_with_configs.subscribe(),
444
                        Some(kv_indexer.remove_worker_sender()),
445
                    ),
446
                    transport_kind,
447
448
                )
                .await?;
449
            } else {
450
451
452
453
454
                if transport_kind == EventTransportKind::Zmq {
                    tracing::warn!(
                        "Not all workers have local_indexer enabled; falling back to JetStream for durability"
                    );
                }
455
456
457
                tracing::info!(
                    "Not all workers have local_indexer enabled, using JetStream subscription"
                );
458

459
460
                // Convert router_id to string for NATS consumer naming
                let consumer_id = router_id.to_string();
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
                start_kv_router_background(
                    component.clone(),
                    consumer_id,
                    kv_indexer.event_sender(),
                    kv_indexer.remove_worker_sender(),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.get_workers_sender()),
                    kv_router_config
                        .router_snapshot_threshold
                        .map(|_| kv_indexer.snapshot_event_sender()),
                    cancellation_token.clone(),
                    kv_router_config.router_snapshot_threshold,
                    kv_router_config.router_reset_states,
                )
                .await?;
477
            }
478
        }
479

480
        tracing::info!("KV Routing initialized");
481
        Ok(Self {
482
            indexer,
483
            scheduler,
484
            block_size,
485
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
486
            cancellation_token,
487
            client,
488
            worker_query_client: Some(worker_query_client),
489
        })
490
491
    }

492
493
494
495
496
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

497
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
498
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
499
    /// Now also takes optional context_id for request tracking
500
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
501
    pub async fn find_best_match(
502
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
503
        context_id: Option<&str>,
504
        tokens: &[u32],
505
        router_config_override: Option<&RouterConfigOverride>,
506
        update_states: bool,
507
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
508
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
509
510
511
        #[cfg(feature = "bench")]
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
512
513
514
515
516
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

517
        let isl_tokens = tokens.len();
518

519
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
520
521
        #[cfg(feature = "bench")]
        let hash_elapsed = start.elapsed();
522
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
523
524
        #[cfg(feature = "bench")]
        let find_matches_elapsed = start.elapsed();
525

526
527
528
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = self
            .kv_router_config
529
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
530

Yan Ru Pei's avatar
Yan Ru Pei committed
531
        let best_worker = self
532
            .scheduler
533
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
534
                context_id.map(|s| s.to_string()),
535
                isl_tokens,
536
                maybe_seq_hashes,
537
                overlap_scores.clone(),
538
                router_config_override,
539
                update_states,
540
                lora_name,
541
            )
542
            .await?;
543

544
545
546
547
548
549
550
551
552
553
554
555
556
        #[cfg(feature = "bench")]
        {
            let total_elapsed = start.elapsed();
            tracing::info!(
                isl_tokens,
                hash_us = hash_elapsed.as_micros() as u64,
                find_matches_us = (find_matches_elapsed - hash_elapsed).as_micros() as u64,
                schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
                total_us = total_elapsed.as_micros() as u64,
                "find_best_match completed"
            );
        }

557
558
        // Note: Routing decision recording (for approximate mode) is now handled
        // by KvPushRouter::generate after select_worker returns.
559

560
561
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
562
            .get(&best_worker)
563
564
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
565
        Ok((best_worker, overlap_amount))
566
567
    }

568
    #[allow(clippy::too_many_arguments)]
569
570
571
572
573
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
574
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
575
        worker: WorkerWithDpRank,
576
        lora_name: Option<String>,
577
578
    ) {
        let isl_tokens = tokens.len();
579

580
581
582
        let maybe_seq_hashes = self
            .kv_router_config
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
583

584
585
        if let Err(e) = self
            .scheduler
586
            .add_request(
587
                request_id.clone(),
588
                maybe_seq_hashes,
589
590
                isl_tokens,
                overlap_blocks,
591
                expected_output_tokens,
Yan Ru Pei's avatar
Yan Ru Pei committed
592
                worker,
593
                lora_name,
594
            )
595
596
597
598
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
599
600
    }

601
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
602
        self.scheduler.mark_prefill_completed(request_id).await
603
604
    }

605
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
606
        self.scheduler.free(request_id).await
607
    }
608

609
610
611
612
613
614
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

615
616
617
618
619
620
621
622
623
624
    pub async fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.scheduler
            .add_output_block(request_id, decay_fraction)
            .await
    }

625
    pub fn block_size(&self) -> u32 {
626
627
        self.block_size
    }
628

629
630
631
632
633
634
635
636
637
638
639
640
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
        worker: WorkerWithDpRank,
    ) -> Result<u32, KvRouterError> {
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

641
642
643
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
644
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
645
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
646

647
648
        let maybe_seq_hashes = self
            .kv_router_config
649
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
650

651
652
        Ok(self
            .scheduler
653
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
654
655
656
            .await)
    }

657
658
659
660
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
661
662
663
664
665
666

    /// Query a specific worker's local KV indexer for its events
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
667
        dp_rank: DpRank,
668
669
670
671
672
673
674
675
676
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;

        query_client
677
            .query_worker(worker_id, dp_rank, start_event_id, end_event_id)
678
679
680
            .await
    }

681
    /// Recover missed KV events from a specific worker's dp_rank.
682
683
684
685
686
687
688
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
689
    /// * `dp_rank` - The data parallel rank to recover from
690
691
692
693
694
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
695
        dp_rank: DpRank,
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

711
712
713
        query_client
            .recover_from_worker(worker_id, dp_rank, start_event_id, end_event_id, &event_tx)
            .await
714
    }
715
716
}

Michael Feil's avatar
Michael Feil committed
717
718
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
719
720
721
722
723
724
725
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
726
727
728
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
729
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
730
                let (best_worker, overlap_blocks) = self
731
                    .find_best_match(Some(&context_id), &tokens, None, true, None)
Michael Feil's avatar
Michael Feil committed
732
733
734
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
735
736
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
737
738
739
                    overlap_blocks,
                }
            }
740
741
742
743
744
745
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
746
        };
747
748
749
750
751
752

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
753
754

pub struct KvPushRouter {
755
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
756
    pub chooser: Arc<KvRouter>,
757
758
}

759
760
761
762
763
764
765
/// Outcome of picking a worker for a request.
struct WorkerSelection {
    /// Id of the chosen worker instance.
    instance_id: u64,
    /// Data-parallel rank on that worker.
    dp_rank: u32,
    /// Cached-prefix overlap on the chosen worker, in blocks.
    overlap_amount: u32,
}

766
767
impl KvPushRouter {
    pub fn new(
768
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
769
770
771
772
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787

    /// Select a worker for the request, either using a preselected worker or finding the best match.
    ///
    /// When `is_query_only` is false and `handle_local_updates` is true, this also registers
    /// the request with the scheduler via `add_request`.
    async fn select_worker(
        &self,
        context_id: &str,
        request: &PreprocessedRequest,
        phase: RequestPhase,
        is_query_only: bool,
        handle_local_updates: bool,
    ) -> Result<WorkerSelection, Error> {
        let routing = request.routing.as_ref();

788
789
790
        // Extract LORA name from routing hints
        let lora_name = routing.and_then(|r| r.lora_name.clone());

791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
        // Get pre-selected worker based on phase, with backend_instance_id as fallback
        let Some(id) = (match phase {
            RequestPhase::Prefill => {
                routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Decode => {
                routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
        }) else {
            // No preselected worker - find the best match
            // Don't update states if this is a query-only request
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !is_query_only,
810
                    lora_name,
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
                )
                .await?;

            return Ok(WorkerSelection {
                instance_id: best_worker.worker_id,
                dp_rank: best_worker.dp_rank,
                overlap_amount,
            });
        };

        // Route to pre-selected or explicitly specified worker
        let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
        tracing::debug!(
            worker_id = id,
            dp_rank = dp_rank,
            ?phase,
            "Routing to specified worker"
        );

        // Compute actual overlap blocks by querying the indexer
        let worker = WorkerWithDpRank::new(id, dp_rank);
        let overlap_blocks = self
            .chooser
            .get_overlap_blocks(&request.token_ids, worker)
            .await?;

837
838
839
840
841
842
        // Extract expected_output_tokens from routing hints
        let expected_output_tokens = request
            .routing
            .as_ref()
            .and_then(|r| r.expected_output_tokens);

843
844
845
846
847
848
849
        // Perform add_request if this router handles local updates
        if !is_query_only && handle_local_updates {
            self.chooser
                .add_request(
                    context_id.to_string(),
                    &request.token_ids,
                    overlap_blocks,
850
                    expected_output_tokens,
851
                    worker,
852
                    lora_name,
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
                )
                .await;
        } else {
            tracing::debug!(
                request_id = %context_id,
                worker_id = id,
                dp_rank = dp_rank,
                "Skipping add_request - query or handled externally"
            );
        }

        Ok(WorkerSelection {
            instance_id: id,
            dp_rank,
            overlap_amount: overlap_blocks,
        })
    }
870
871
872
}

#[async_trait]
873
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
874
875
    for KvPushRouter
{
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
895
896
    async fn generate(
        &self,
897
        request: SingleIn<PreprocessedRequest>,
898
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
899
900
901
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

902
903
904
        // Simple query-only detection: presence of query_instance_id annotation means query-only mode
        let is_query_only = request.get_annotation_value("query_instance_id").is_some();

905
        // Determine if this router should handle local state updates (add_request, free, etc.)
906
907
908
909
910
911
        // Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
        // an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
        let handle_local_updates = request
            .routing
            .as_ref()
            .and_then(|r| r.enable_local_updates)
912
913
            .unwrap_or(true);

914
915
916
        // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
        let phase = request
            .tracker
917
            .as_ref()
918
919
920
921
            .map(|t| t.phase())
            .unwrap_or(RequestPhase::Aggregated);

        let block_size = self.chooser.block_size() as usize;
922
923
924
925
926
927
928
929
930
931
932
933
934
935
        let selection = self
            .select_worker(
                &context_id,
                &request,
                phase,
                is_query_only,
                handle_local_updates,
            )
            .await?;
        let WorkerSelection {
            instance_id,
            dp_rank,
            overlap_amount,
        } = selection;
936

937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
        // In approximate mode (use_kv_events=false), record the routing decision
        // so the indexer can track cache state based on routing decisions.
        // This covers both pre-selected workers and find_best_match selections.
        if !is_query_only && !self.chooser.kv_router_config.use_kv_events {
            let worker = WorkerWithDpRank::new(instance_id, dp_rank);
            let mut tokens_with_hashes =
                TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size);
            if let Err(e) = self
                .chooser
                .indexer
                .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
                .await
            {
                tracing::warn!(
                    request_id = %context_id,
                    worker_id = instance_id,
                    dp_rank = dp_rank,
                    error = %e,
                    "Failed to record routing decision in approximate mode"
                );
            }
        }

960
961
962
        // Record metrics in tracker: KV hit rate, worker ID, and worker type based on phase.
        // Worker type is stored at routing time to avoid expensive MDC lookups when
        // updating Prometheus metrics (TTFT/ITL) later in the response stream.
963
964
965
        if let Some(ref tracker) = request.tracker {
            let isl_blocks = request.token_ids.len().div_ceil(block_size);
            tracker.record_kv_hit(overlap_amount, isl_blocks);
966
            tracker.record_worker_full(instance_id, dp_rank, self.chooser.worker_type());
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
        }

        // Handle query-only requests: early return with worker info
        if is_query_only {
            let stream_context = request.context().clone();
            // Tracker is always created for query-only requests (delta generator enables tracking
            // when query_instance_id annotation is present)
            let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());

            tracing::trace!(
                ?phase,
                worker_id = instance_id,
                ?worker_id_info,
                "Returning worker selection (query-only mode)"
            );

983
984
985
986
987
988
989
990
991
            let output = LLMEngineOutput {
                disaggregated_params: Some(json!({
                    "worker_id": worker_id_info,
                    "token_ids": request.token_ids
                })),
                ..Default::default()
            };
            let response = Annotated::from_data(output);
            let stream = stream::iter(vec![response]);
992
993
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
994
995

        // Route to worker
996
997
998
999
1000
1001
1002
1003
        let isl_tokens = request.token_ids.len();
        let expected_output_tokens = request
            .routing
            .as_ref()
            .and_then(|r| r.expected_output_tokens);
        let track_output_blocks =
            self.chooser.kv_router_config.router_track_output_blocks && handle_local_updates;

1004
        let (mut backend_input, context) = request.into_parts();
1005
        backend_input.routing_mut().dp_rank = Some(dp_rank);
1006
1007
        let updated_request = context.map(|_| backend_input);

1008
        let chooser = self.chooser.clone();
1009
1010
1011
1012
        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let context_for_monitoring = stream_context.clone();

1013
1014
1015
        // Wrap stream with lifecycle management (mark_prefill_completed, free)
        // Only perform these operations if handle_local_updates is true.
        // When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
1016
1017
1018
        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;

1019
1020
1021
1022
            // Output block tracking state
            let mut cumulative_osl: usize = 0;
            let mut current_total_blocks = isl_tokens.div_ceil(block_size);

1023
1024
1025
1026
1027
1028
1029
            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
1030
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
1031

1032
                    item = response_stream.next() => {
1033
                        let Some(item) = item else {
1034
1035
                            break;
                        };
1036

1037
                        if handle_local_updates && !prefill_marked {
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
                            // Only mark prefill completed when we receive actual tokens,
                            // not empty bootstrap info (token_ids: []) from disaggregated prefill
                            let has_tokens = item.data.as_ref()
                                .map(|d| !d.token_ids.is_empty())
                                .unwrap_or(false);
                            if has_tokens {
                                if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
                                    tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
                                }
                                prefill_marked = true;
1048
                            }
1049
                        }
1050

1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
                        // Track output blocks if enabled
                        if track_output_blocks {
                            let new_tokens = item.data.as_ref()
                                .map(|d| d.token_ids.len())
                                .unwrap_or(0);
                            cumulative_osl += new_tokens;

                            let new_total_blocks = (isl_tokens + cumulative_osl).div_ceil(block_size);
                            if new_total_blocks > current_total_blocks {
                                // New block boundary crossed - add output block with decay
                                // Clamp eot to min 1 to avoid division by zero, and result to min 0.0
                                let decay_fraction = expected_output_tokens.map(|eot| {
                                    (1.0 - (cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
                                });
                                if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
                                    tracing::warn!(
                                        "Failed to add output block for request {context_id}: {e}"
                                    );
                                }
                                current_total_blocks = new_total_blocks;
                            }
                        }

1074
                        yield item;
1075
                    }
1076
1077
                }
            }
1078

1079
1080
1081
1082
1083
            // Only call free() if we handle local updates.
            // When handle_local_updates=false, external caller handles cleanup via C FFI.
            if handle_local_updates
                && let Err(e) = chooser.free(&context_id).await
            {
1084
                tracing::warn!("Failed to free request {context_id}: {e}");
1085
            }
1086
1087
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
1088
1089
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
1090
1091
1092
1093
1094
1095
1096

impl Drop for KvRouter {
    /// Signal the router's cancellation token when the router goes away, so that
    /// anything observing that token can stop.
    fn drop(&mut self) {
        let token = &self.cancellation_token;
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        token.cancel();
    }
}