kv_router.rs 36.2 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use dashmap::DashMap;
10
use derive_builder::Builder;
11
use dynamo_runtime::{
12
    component::{Client, Endpoint},
13
    discovery::{DiscoveryQuery, watch_and_extract_field},
14
    pipeline::{
15
16
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
17
    },
18
    protocols::EndpointId,
19
    protocols::annotated::Annotated,
20
    traits::DistributedRuntimeProvider,
21
22
};
use futures::stream::{self, StreamExt};
23
use serde::{Deserialize, Serialize};
24
use serde_json::json;
25

26
pub mod approx;
27
pub mod indexer;
28
pub mod prefill_router;
29
30
pub mod protocols;
pub mod publisher;
31
pub mod recorder;
32
pub mod scheduler;
33
pub mod sequence;
34
pub mod subscriber;
35
pub mod worker_query;
36

37
use indexer::WorkerKvQueryResponse;
38
pub use prefill_router::PrefillRouter;
39
use worker_query::WorkerQueryClient;
40

41
42
use crate::{
    kv_router::{
43
        approx::PruneConfig,
44
        indexer::{KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent},
Yan Ru Pei's avatar
Yan Ru Pei committed
45
        protocols::{
46
47
48
            LocalBlockHash, RouterRequest, RouterResponse, TokensWithHashes, WorkerId,
            WorkerSelectionResult, WorkerWithDpRank, compute_block_hash_for_seq,
            compute_seq_hash_for_block,
Yan Ru Pei's avatar
Yan Ru Pei committed
49
        },
50
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
51
        sequence::SequenceError,
52
        subscriber::{start_kv_router_background, start_kv_router_background_nats_core},
53
    },
54
    local_model::runtime_config::ModelRuntimeConfig,
55
    model_card::ModelDeploymentCard,
56
    preprocessor::PreprocessedRequest,
57
    protocols::common::llm_backend::LLMEngineOutput,
58
    protocols::common::timing::RequestPhase,
59
60
};

61
62
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

// ---- Metrics -----------------------------------------------------------------

/// Endpoint name used for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Push-based publishing subject for KV events.
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Push-based publishing subject for KV hit-rate updates.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Push-based publishing subject for KV metrics.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// ---- Router-to-router communication -------------------------------------------

/// Subject for prefill events exchanged between router replicas.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between router replicas.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

// ---- Radix tree snapshot storage ----------------------------------------------

/// Bucket name under which radix-tree snapshots are stored.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// Object/file name of the radix-tree snapshot inside [`RADIX_STATE_BUCKET`].
pub const RADIX_STATE_FILE: &str = "radix-state";

// ---- Worker-local KV indexer queries ------------------------------------------

/// Subject used to query a worker's local KV indexer.
pub const WORKER_KV_INDEXER_QUERY_SUBJECT: &str = "worker_kv_indexer_query";
/// Number of most recent events kept in each worker's buffer.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

// ---- Router discovery registration --------------------------------------------

/// Component name the KV router registers under.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name the KV router registers under.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] under which the KV router is registered in `namespace`.
///
/// Only the namespace varies; the component and endpoint names are always
/// [`KV_ROUTER_COMPONENT`] and [`KV_ROUTER_ENDPOINT`].
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the [`DiscoveryQuery`] that locates KV router instances in `namespace`.
///
/// The query targets the [`KV_ROUTER_COMPONENT`] / [`KV_ROUTER_ENDPOINT`] pair, i.e. the
/// same endpoint described by `router_endpoint_id`.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

106
107
108
109
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
110
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
111
        request: &SchedulingRequest,
112
        block_size: u32,
113
114
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
115

116
117
118
119
120
121
122
123
124
125
/// Router tuning knobs that a single request may override.
///
/// A field left as `None` requests no override for that knob (presumably the router-wide
/// [`KvRouterConfig`] value applies — resolved by the scheduler, not visible here).
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request replacement for [`KvRouterConfig::overlap_score_weight`].
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request replacement for [`KvRouterConfig::router_temperature`].
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

126
/// KV Router configuration parameters
127
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
128
129
130
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

131
    pub router_temperature: f64,
132

133
134
    pub use_kv_events: bool,

135
136
    pub router_replica_sync: bool,

137
138
139
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

140
141
142
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

143
    /// Whether to reset the router state on startup (default: false)
144
    pub router_reset_states: bool,
145
146
147
148

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

149
    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
150
151
152
153
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
154
155
156
157
158
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
159
            overlap_score_weight: 1.0,
160
            router_temperature: 0.0,
161
            use_kv_events: true,
162
            router_replica_sync: false,
163
            router_track_active_blocks: true,
164
            router_snapshot_threshold: Some(1000000),
165
            router_reset_states: false,
166
            router_ttl_secs: 120.0,
167
            router_max_tree_size: 2usize.pow(20), // 2^20 = 1048576, matches PruneConfig::default()
168
            router_prune_target_ratio: 0.8,
169
170
171
172
173
174
175
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
176
    #[allow(clippy::too_many_arguments)]
177
178
    pub fn new(
        overlap_score_weight: Option<f64>,
179
        temperature: Option<f64>,
180
        use_kv_events: Option<bool>,
181
        replica_sync: Option<bool>,
182
        track_active_blocks: Option<bool>,
183
184
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
185
186
187
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
188
189
190
191
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
192
            router_temperature: temperature.unwrap_or(default.router_temperature),
193
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
194
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
195
196
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
197
198
199
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
200
201
202
203
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
204
205
206
207
        }
    }
}

208
pub enum Indexer {
209
210
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
211
    /// Has the ability to persist and snapshot states.
212
    KvIndexer(KvIndexer),
213
214
215
216

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
217
218
219
220
221
222
223
224
225
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
226
227
228
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
229
                tree_sizes: HashMap::new(),
230
            }),
231
232
        }
    }
233
234
235
236

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
237
238
239
240
241
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
242
243
        }
    }
244

245
    async fn process_routing_decision_for_request(
246
        &self,
247
        tokens_with_hashes: &mut TokensWithHashes,
248
249
250
251
252
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
253
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
254
255
256
257
258
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
259
260
}

261
262
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
263
pub struct KvRouter {
264
265
266
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
267
    scheduler: KvScheduler,
268

269
    block_size: u32,
270
271

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
272
273

    cancellation_token: tokio_util::sync::CancellationToken,
274
275

    client: Client,
276
277

    worker_query_client: Option<WorkerQueryClient>,
278
279
280
281
}

impl KvRouter {
    pub async fn new(
282
283
        endpoint: Endpoint,
        client: Client,
284
        workers_with_configs: Arc<DashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>>,
285
        block_size: u32,
286
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
287
        kv_router_config: Option<KvRouterConfig>,
288
        consumer_id: String,
289
    ) -> Result<Self> {
290
        let kv_router_config = kv_router_config.unwrap_or_default();
291
        let component = endpoint.component();
292
        let cancellation_token = component.drt().primary_token();
293

294
        let instance_ids_rx = client.instance_avail_watcher();
295

296
        // Watch for runtime config updates via discovery interface
297
        // (still needed for WorkerQueryClient and background tasks)
298
        let discovery = component.drt().discovery();
299
        let endpoint_id = endpoint.id();
300
        let discovery_key = DiscoveryQuery::EndpointModels {
301
302
303
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
304
305
        };
        let discovery_stream = discovery
306
            .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
307
308
309
310
311
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
312

313
314
315
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
316
        } else {
317
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
318
319
320
321
322
323
324
325
326
327
328
329
330

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
331
                cancellation_token.clone(),
332
                None, // expiration_duration for frequency tracking
333
334
                block_size,
                kv_indexer_metrics,
335
                prune_config,
336
337
            ))
        };
338

339
        let scheduler = KvScheduler::start(
340
            component.clone(),
341
            block_size,
342
            instance_ids_rx,
343
            workers_with_configs,
344
            selector,
345
            kv_router_config.router_replica_sync,
346
            consumer_id.clone(),
347
348
        )
        .await?;
349

350
351
352
353
354
355
        // Initialize worker query client using namespace abstraction
        // (created before background task so we can use it for startup recovery)
        let worker_query_client =
            worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
        tracing::info!("Worker query client initialized");

356
        // Start KV event subscriber background process (only when use_kv_events is enabled)
357
358
        // We block here until at least one worker runtime config is registered,
        // then spawn the subscriber. This ensures the router is ready before accepting requests.
359
360
361
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
362
363
            let mut runtime_configs_rx_clone = runtime_configs_rx.clone();

364
365
366
367
368
369
370
371
            // Wait for at least one worker runtime config to be registered
            tracing::info!("Waiting for at least one worker runtime config to be registered...");
            let (all_local_indexer, count) = loop {
                {
                    let configs = runtime_configs_rx_clone.borrow();
                    if !configs.is_empty() {
                        let all_local_indexer = configs.values().all(|c| c.enable_local_indexer);
                        break (all_local_indexer, configs.len());
372
                    }
373
                }
374

375
376
377
378
379
380
381
382
383
384
                // Wait for changes to runtime_configs
                tokio::select! {
                    _ = cancellation_token.cancelled() => {
                        tracing::debug!("KvRouter startup cancelled while waiting for workers");
                        anyhow::bail!("KvRouter startup cancelled");
                    }
                    result = runtime_configs_rx_clone.changed() => {
                        if result.is_err() {
                            tracing::debug!("Runtime configs channel closed");
                            anyhow::bail!("Runtime configs channel closed before any workers registered");
385
386
                        }
                    }
387
388
389
                }
            };
            tracing::info!("Found {count} worker runtime config(s), starting KV event subscriber");
390

391
392
393
394
395
396
            // Clone everything needed for the background subscriber task
            let component_clone = component.clone();
            let kv_indexer_clone = kv_indexer.clone();
            let cancellation_token_clone = cancellation_token.clone();
            let worker_query_client_clone =
                worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
397

398
399
400
401
402
403
404
405
            // Spawn subscriber as background task (long-running)
            if all_local_indexer {
                // All workers have local_indexer enabled - use NATS Core
                tracing::info!(
                    "All {count} workers have local_indexer enabled, using NATS Core subscription"
                );

                tokio::spawn(async move {
406
                    if let Err(e) = start_kv_router_background_nats_core(
407
                        component_clone,
408
409
                        kv_indexer_clone.event_sender(),
                        kv_indexer_clone.remove_worker_sender(),
410
                        cancellation_token_clone,
411
412
413
414
415
416
                        worker_query_client_clone,
                    )
                    .await
                    {
                        tracing::error!("Failed to start NATS Core subscriber: {e}");
                    }
417
418
419
420
421
422
                });
            } else {
                // Not all workers have local_indexer - use JetStream
                tracing::info!(
                    "Not all workers have local_indexer enabled, using JetStream subscription"
                );
423

424
                tokio::spawn(async move {
425
                    if let Err(e) = start_kv_router_background(
426
                        component_clone,
427
428
429
430
431
432
433
434
435
                        consumer_id,
                        kv_indexer_clone.event_sender(),
                        kv_indexer_clone.remove_worker_sender(),
                        kv_router_config
                            .router_snapshot_threshold
                            .map(|_| kv_indexer_clone.get_workers_sender()),
                        kv_router_config
                            .router_snapshot_threshold
                            .map(|_| kv_indexer_clone.snapshot_event_sender()),
436
                        cancellation_token_clone,
437
438
439
440
441
442
443
                        kv_router_config.router_snapshot_threshold,
                        kv_router_config.router_reset_states,
                    )
                    .await
                    {
                        tracing::error!("Failed to start JetStream subscriber: {e}");
                    }
444
445
                });
            }
446
        }
447

448
        tracing::info!("KV Routing initialized");
449
        Ok(Self {
450
            indexer,
451
            scheduler,
452
            block_size,
453
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
454
            cancellation_token,
455
            client,
456
            worker_query_client: Some(worker_query_client),
457
        })
458
459
    }

460
461
462
463
464
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

465
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
466
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
467
468
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
469
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
470
        context_id: Option<&str>,
471
        tokens: &[u32],
472
        router_config_override: Option<&RouterConfigOverride>,
473
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
474
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
475
476
477
478
479
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

480
        let isl_tokens = tokens.len();
481

482
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
483
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
484

485
486
487
488
489
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = self
            .kv_router_config
            .router_track_active_blocks
            .then(|| compute_seq_hash_for_block(&block_hashes));
490

Yan Ru Pei's avatar
Yan Ru Pei committed
491
        let best_worker = self
492
            .scheduler
493
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
494
                context_id.map(|s| s.to_string()),
495
                isl_tokens,
496
                maybe_seq_hashes,
497
                overlap_scores.clone(),
498
                router_config_override,
499
                update_states,
500
            )
501
            .await?;
502

503
504
        // Note: Routing decision recording (for approximate mode) is now handled
        // by KvPushRouter::generate after select_worker returns.
505

506
507
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
508
            .get(&best_worker)
509
510
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
511
        Ok((best_worker, overlap_amount))
512
513
    }

514
515
516
517
518
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
519
        worker: WorkerWithDpRank,
520
521
    ) {
        let isl_tokens = tokens.len();
522
523

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
524
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
525
526
            compute_seq_hash_for_block(&block_hashes)
        });
527

528
529
        if let Err(e) = self
            .scheduler
530
            .add_request(
531
                request_id.clone(),
532
                maybe_seq_hashes,
533
534
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
535
                worker,
536
            )
537
538
539
540
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
541
542
    }

543
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
544
        self.scheduler.mark_prefill_completed(request_id).await
545
546
    }

547
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
548
        self.scheduler.free(request_id).await
549
    }
550

551
    pub fn block_size(&self) -> u32 {
552
553
        self.block_size
    }
554

555
556
557
558
559
560
561
562
563
564
565
566
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
        worker: WorkerWithDpRank,
    ) -> Result<u32, KvRouterError> {
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

567
568
569
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
570
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
571
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
572

573
574
575
576
        let maybe_seq_hashes = self
            .kv_router_config
            .router_track_active_blocks
            .then(|| compute_seq_hash_for_block(&block_hashes));
577

578
579
        Ok(self
            .scheduler
580
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
581
582
583
            .await)
    }

584
585
586
587
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643

    /// Query a specific worker's local KV indexer for its events
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;

        query_client
            .query_worker(worker_id, start_event_id, end_event_id)
            .await
    }

    /// Recover missed KV events from a specific worker.
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

        subscriber::recover_from_worker(
            query_client,
            worker_id,
            start_event_id,
            end_event_id,
            &event_tx,
        )
        .await
    }
644
645
}

Michael Feil's avatar
Michael Feil committed
646
647
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
648
649
650
651
652
653
654
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
655
656
657
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
658
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
659
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
660
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
661
662
663
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
664
665
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
666
667
668
                    overlap_blocks,
                }
            }
669
670
671
672
673
674
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
675
        };
676
677
678
679
680
681

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
682
683

pub struct KvPushRouter {
684
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
685
    pub chooser: Arc<KvRouter>,
686
687
}

688
689
690
691
692
693
694
/// Outcome of worker selection: the chosen worker instance, its data-parallel rank, and
/// how many of the request's blocks overlap with that worker's cache.
struct WorkerSelection {
    /// Id of the selected worker instance.
    instance_id: u64,
    /// Data-parallel rank on the selected worker.
    dp_rank: u32,
    /// Overlapping (already cached) blocks on the selected worker.
    overlap_amount: u32,
}

695
696
impl KvPushRouter {
    pub fn new(
697
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
698
699
700
701
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786

    /// Select a worker for the request, either using a preselected worker or finding the best match.
    ///
    /// When `is_query_only` is false and `handle_local_updates` is true, this also registers
    /// the request with the scheduler via `add_request`.
    async fn select_worker(
        &self,
        context_id: &str,
        request: &PreprocessedRequest,
        phase: RequestPhase,
        is_query_only: bool,
        handle_local_updates: bool,
    ) -> Result<WorkerSelection, Error> {
        let routing = request.routing.as_ref();

        // Get pre-selected worker based on phase, with backend_instance_id as fallback
        let Some(id) = (match phase {
            RequestPhase::Prefill => {
                routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Decode => {
                routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
        }) else {
            // No preselected worker - find the best match
            // Don't update states if this is a query-only request
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !is_query_only,
                )
                .await?;

            return Ok(WorkerSelection {
                instance_id: best_worker.worker_id,
                dp_rank: best_worker.dp_rank,
                overlap_amount,
            });
        };

        // Route to pre-selected or explicitly specified worker
        let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
        tracing::debug!(
            worker_id = id,
            dp_rank = dp_rank,
            ?phase,
            "Routing to specified worker"
        );

        // Compute actual overlap blocks by querying the indexer
        let worker = WorkerWithDpRank::new(id, dp_rank);
        let overlap_blocks = self
            .chooser
            .get_overlap_blocks(&request.token_ids, worker)
            .await?;

        // Perform add_request if this router handles local updates
        if !is_query_only && handle_local_updates {
            self.chooser
                .add_request(
                    context_id.to_string(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
        } else {
            tracing::debug!(
                request_id = %context_id,
                worker_id = id,
                dp_rank = dp_rank,
                "Skipping add_request - query or handled externally"
            );
        }

        Ok(WorkerSelection {
            instance_id: id,
            dp_rank,
            overlap_amount: overlap_blocks,
        })
    }
787
788
789
}

#[async_trait]
790
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
791
792
    for KvPushRouter
{
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
812
813
    async fn generate(
        &self,
814
        request: SingleIn<PreprocessedRequest>,
815
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
816
817
818
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

819
820
821
        // Simple query-only detection: presence of query_instance_id annotation means query-only mode
        let is_query_only = request.get_annotation_value("query_instance_id").is_some();

822
        // Determine if this router should handle local state updates (add_request, free, etc.)
823
824
825
826
827
828
        // Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
        // an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
        let handle_local_updates = request
            .routing
            .as_ref()
            .and_then(|r| r.enable_local_updates)
829
830
            .unwrap_or(true);

831
832
833
        // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
        let phase = request
            .tracker
834
            .as_ref()
835
836
837
838
            .map(|t| t.phase())
            .unwrap_or(RequestPhase::Aggregated);

        let block_size = self.chooser.block_size() as usize;
839
840
841
842
843
844
845
846
847
848
849
850
851
852
        let selection = self
            .select_worker(
                &context_id,
                &request,
                phase,
                is_query_only,
                handle_local_updates,
            )
            .await?;
        let WorkerSelection {
            instance_id,
            dp_rank,
            overlap_amount,
        } = selection;
853

854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
        // In approximate mode (use_kv_events=false), record the routing decision
        // so the indexer can track cache state based on routing decisions.
        // This covers both pre-selected workers and find_best_match selections.
        if !is_query_only && !self.chooser.kv_router_config.use_kv_events {
            let worker = WorkerWithDpRank::new(instance_id, dp_rank);
            let mut tokens_with_hashes =
                TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size);
            if let Err(e) = self
                .chooser
                .indexer
                .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
                .await
            {
                tracing::warn!(
                    request_id = %context_id,
                    worker_id = instance_id,
                    dp_rank = dp_rank,
                    error = %e,
                    "Failed to record routing decision in approximate mode"
                );
            }
        }

877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
        // Record metrics in tracker: KV hit rate and worker ID based on phase
        if let Some(ref tracker) = request.tracker {
            let isl_blocks = request.token_ids.len().div_ceil(block_size);
            tracker.record_kv_hit(overlap_amount, isl_blocks);
            tracker.record_worker(instance_id);
        }

        // Handle query-only requests: early return with worker info
        if is_query_only {
            let stream_context = request.context().clone();
            // Tracker is always created for query-only requests (delta generator enables tracking
            // when query_instance_id annotation is present)
            let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());

            tracing::trace!(
                ?phase,
                worker_id = instance_id,
                ?worker_id_info,
                "Returning worker selection (query-only mode)"
            );

898
899
900
901
902
903
904
905
906
            let output = LLMEngineOutput {
                disaggregated_params: Some(json!({
                    "worker_id": worker_id_info,
                    "token_ids": request.token_ids
                })),
                ..Default::default()
            };
            let response = Annotated::from_data(output);
            let stream = stream::iter(vec![response]);
907
908
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
909
910

        // Route to worker
911
        let (mut backend_input, context) = request.into_parts();
912
        backend_input.routing_mut().dp_rank = Some(dp_rank);
913
914
        let updated_request = context.map(|_| backend_input);

915
        let chooser = self.chooser.clone();
916
917
918
919
        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let context_for_monitoring = stream_context.clone();

920
921
922
        // Wrap stream with lifecycle management (mark_prefill_completed, free)
        // Only perform these operations if handle_local_updates is true.
        // When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
923
924
925
926
927
928
929
930
931
932
        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
933
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
934

935
                    item = response_stream.next() => {
936
                        let Some(item) = item else {
937
938
                            break;
                        };
939

940
                        if handle_local_updates && !prefill_marked {
941
942
943
944
945
946
947
948
949
950
                            // Only mark prefill completed when we receive actual tokens,
                            // not empty bootstrap info (token_ids: []) from disaggregated prefill
                            let has_tokens = item.data.as_ref()
                                .map(|d| !d.token_ids.is_empty())
                                .unwrap_or(false);
                            if has_tokens {
                                if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
                                    tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
                                }
                                prefill_marked = true;
951
                            }
952
                        }
953

954
                        yield item;
955
                    }
956
957
                }
            }
958

959
960
961
962
963
            // Only call free() if we handle local updates.
            // When handle_local_updates=false, external caller handles cleanup via C FFI.
            if handle_local_updates
                && let Err(e) = chooser.free(&context_id).await
            {
964
                tracing::warn!("Failed to free request {context_id}: {e}");
965
            }
966
967
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
968
969
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
970
971
972
973
974
975
976

impl Drop for KvRouter {
    /// Cancels this router's `cancellation_token` when the router goes away.
    ///
    /// Presumably background tasks spawned for the router watch this token and shut down
    /// once it is cancelled. TODO(review): confirm against the task spawn sites.
    fn drop(&mut self) {
        let token = &self.cancellation_token;
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        token.cancel();
    }
}