kv_router.rs 32.6 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
21
};
use futures::stream::{self, StreamExt};
22
use serde::{Deserialize, Serialize};
23
use serde_json::json;
24

25
pub mod approx;
26
pub mod indexer;
27
pub mod prefill_router;
28
29
pub mod protocols;
pub mod publisher;
30
pub mod recorder;
31
pub mod scheduler;
32
pub mod sequence;
33
pub mod subscriber;
34
pub mod worker_query;
35

36
use indexer::WorkerKvQueryResponse;
37
pub use prefill_router::PrefillRouter;
38
use worker_query::WorkerQueryClient;
39

40
41
use crate::{
    kv_router::{
42
        approx::PruneConfig,
43
        indexer::{
44
45
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
46
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
47
        protocols::{
48
49
            LocalBlockHash, RouterRequest, RouterResponse, WorkerId, WorkerSelectionResult,
            WorkerWithDpRank,
Yan Ru Pei's avatar
Yan Ru Pei committed
50
        },
51
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
52
        sequence::SequenceError,
53
        subscriber::{start_kv_router_background, start_kv_router_background_nats_core},
54
    },
55
    local_model::runtime_config::ModelRuntimeConfig,
56
    model_card::ModelDeploymentCard,
57
    preprocessor::PreprocessedRequest,
58
    protocols::common::llm_backend::LLMEngineOutput,
59
    protocols::common::timing::RequestPhase,
60
    tokens::SequenceHash,
61
62
};

63
64
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

// ---- metric scraping (pull-based) ----

/// Endpoint name that is scraped for worker load metrics.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

// ---- metric / event publishing (push-based) ----

/// Subject carrying KV cache events.
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Subject carrying KV hit-rate messages.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject carrying KV metrics messages.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// ---- inter-router communication ----

/// Subject for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

// ---- radix tree snapshot storage ----

/// Bucket that holds radix tree snapshots.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// Name under which the radix tree snapshot is stored.
pub const RADIX_STATE_FILE: &str = "radix-state";

// ---- worker-local KV indexer queries ----

/// Subject used to query a worker's local KV indexer.
pub const WORKER_KV_INDEXER_QUERY_SUBJECT: &str = "worker_kv_indexer_query";
/// Number of most recent events kept in a worker's local event buffer.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

// ---- router discovery registration ----

/// Component name under which KV routers register.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name under which KV routers register.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] that a KV router uses inside `namespace`.
///
/// The component and endpoint names are always [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`].
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the endpoint-scoped [`DiscoveryQuery`] that locates KV routers in `namespace`.
///
/// Uses the same fixed component/endpoint names as [`router_endpoint_id`].
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

108
109
110
111
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
112
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
113
        request: &SchedulingRequest,
114
        block_size: u32,
115
116
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
117

118
119
120
121
122
123
124
125
126
127
/// Router settings that a single request may override.
///
/// Each field mirrors the [`KvRouterConfig`] field of the same name; a `None`
/// value means that setting is not overridden for the request.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request replacement for `KvRouterConfig::overlap_score_weight`.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request replacement for `KvRouterConfig::router_temperature`.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

128
/// KV Router configuration parameters
129
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
130
131
132
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

133
    pub router_temperature: f64,
134

135
136
    pub use_kv_events: bool,

137
138
    pub router_replica_sync: bool,

139
140
141
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

142
143
144
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

145
    /// Whether to reset the router state on startup (default: false)
146
    pub router_reset_states: bool,
147
148
149
150
151
152
153
154
155

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024)
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
156
157
158
159
160
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
161
            overlap_score_weight: 1.0,
162
            router_temperature: 0.0,
163
            use_kv_events: true,
164
            router_replica_sync: false,
165
            router_track_active_blocks: true,
166
            router_snapshot_threshold: Some(1000000),
167
            router_reset_states: false,
168
169
170
            router_ttl_secs: 120.0,
            router_max_tree_size: 1024,
            router_prune_target_ratio: 0.8,
171
172
173
174
175
176
177
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
178
    #[allow(clippy::too_many_arguments)]
179
180
    pub fn new(
        overlap_score_weight: Option<f64>,
181
        temperature: Option<f64>,
182
        use_kv_events: Option<bool>,
183
        replica_sync: Option<bool>,
184
        track_active_blocks: Option<bool>,
185
186
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
187
188
189
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
190
191
192
193
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
194
            router_temperature: temperature.unwrap_or(default.router_temperature),
195
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
196
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
197
198
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
199
200
201
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
202
203
204
205
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
206
207
208
209
        }
    }
}

210
pub enum Indexer {
211
212
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
213
    /// Has the ability to persist and snapshot states.
214
    KvIndexer(KvIndexer),
215
216
217
218

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
219
220
221
222
223
224
225
226
227
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
228
229
230
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
231
                tree_sizes: HashMap::new(),
232
            }),
233
234
        }
    }
235
236
237
238

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
239
240
241
242
243
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
244
245
        }
    }
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261

    async fn process_routing_decision(
        &self,
        worker: WorkerWithDpRank,
        local_hashes: Vec<LocalBlockHash>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
                    .process_routing_decision(worker, local_hashes, sequence_hashes)
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
262
263
}

264
265
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
266
pub struct KvRouter {
267
268
269
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
270
    scheduler: KvScheduler,
271

272
    block_size: u32,
273
274

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
275
276

    cancellation_token: tokio_util::sync::CancellationToken,
277
278

    client: Client,
279
280

    worker_query_client: Option<WorkerQueryClient>,
281
282
283
284
}

impl KvRouter {
    pub async fn new(
285
286
        endpoint: Endpoint,
        client: Client,
287
        block_size: u32,
288
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
289
        kv_router_config: Option<KvRouterConfig>,
290
        consumer_id: String,
291
    ) -> Result<Self> {
292
        let kv_router_config = kv_router_config.unwrap_or_default();
293
        let component = endpoint.component();
294
        let cancellation_token = component.drt().primary_token();
295

296
        let instance_ids_rx = client.instance_avail_watcher();
297

298
299
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
300
        let endpoint_id = endpoint.id();
301
        let discovery_key = DiscoveryQuery::EndpointModels {
302
303
304
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
305
306
        };
        let discovery_stream = discovery
307
            .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
308
309
310
311
312
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
313

314
315
316
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
317
        } else {
318
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
319
320
321
322
323
324
325
326
327
328
329
330
331

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
332
                cancellation_token.clone(),
333
                None, // expiration_duration for frequency tracking
334
335
                block_size,
                kv_indexer_metrics,
336
                prune_config,
337
338
            ))
        };
339

340
        let scheduler = KvScheduler::start(
341
            component.clone(),
342
            block_size,
343
            instance_ids_rx,
344
            runtime_configs_rx.clone(),
345
            selector,
346
            kv_router_config.router_replica_sync,
347
            consumer_id.clone(),
348
349
        )
        .await?;
350

351
352
353
354
355
356
        // Initialize worker query client using namespace abstraction
        // (created before background task so we can use it for startup recovery)
        let worker_query_client =
            worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
        tracing::info!("Worker query client initialized");

357
        // Start KV event subscriber background process (only when use_kv_events is enabled)
358
359
        // This is spawned as a background task to avoid blocking router startup.
        // The task waits for runtime_configs to determine whether to use NATS Core or JetStream.
360
361
362
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
            // Clone everything needed for the background task
            let component_clone = component.clone();
            let kv_indexer_clone = kv_indexer.clone();
            let cancellation_token_clone = cancellation_token.clone();
            let mut runtime_configs_rx_clone = runtime_configs_rx.clone();
            let worker_query_client_clone =
                worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());

            tokio::spawn(async move {
                // Wait for runtime_configs to have at least one entry
                let (all_local_indexer, count) = loop {
                    {
                        let configs = runtime_configs_rx_clone.borrow();
                        if !configs.is_empty() {
                            let all_local_indexer =
                                configs.values().all(|c| c.enable_local_indexer);
                            break (all_local_indexer, configs.len());
                        }
                    }

                    // Wait for changes to runtime_configs
                    tokio::select! {
                        _ = cancellation_token_clone.cancelled() => {
                            tracing::debug!("Subscriber selection task cancelled");
                            return;
                        }
                        result = runtime_configs_rx_clone.changed() => {
                            if result.is_err() {
                                tracing::debug!("Runtime configs channel closed");
                                return;
                            }
                        }
                    }
                };

                if all_local_indexer {
                    // All workers have local_indexer enabled - use NATS Core
400
                    tracing::info!(
401
                        "All {count} workers have local_indexer enabled, using NATS Core subscription"
402
                    );
403
404
405
406
407
408
409
410
411
412
413
414

                    if let Err(e) = start_kv_router_background_nats_core(
                        component_clone.clone(),
                        kv_indexer_clone.event_sender(),
                        kv_indexer_clone.remove_worker_sender(),
                        cancellation_token_clone.clone(),
                        worker_query_client_clone,
                    )
                    .await
                    {
                        tracing::error!("Failed to start NATS Core subscriber: {e}");
                    }
415
                } else {
416
                    // Not all workers have local_indexer - use JetStream
417
                    tracing::info!(
418
                        "Not all workers have local_indexer enabled, using JetStream subscription"
419
                    );
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439

                    if let Err(e) = start_kv_router_background(
                        component_clone.clone(),
                        consumer_id,
                        kv_indexer_clone.event_sender(),
                        kv_indexer_clone.remove_worker_sender(),
                        kv_router_config
                            .router_snapshot_threshold
                            .map(|_| kv_indexer_clone.get_workers_sender()),
                        kv_router_config
                            .router_snapshot_threshold
                            .map(|_| kv_indexer_clone.snapshot_event_sender()),
                        cancellation_token_clone.clone(),
                        kv_router_config.router_snapshot_threshold,
                        kv_router_config.router_reset_states,
                    )
                    .await
                    {
                        tracing::error!("Failed to start JetStream subscriber: {e}");
                    }
440
                }
441
            });
442
        }
443

444
        tracing::info!("KV Routing initialized");
445
        Ok(Self {
446
            indexer,
447
            scheduler,
448
            block_size,
449
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
450
            cancellation_token,
451
            client,
452
            worker_query_client: Some(worker_query_client),
453
        })
454
455
    }

456
457
458
459
460
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

461
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
462
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
463
464
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
465
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
466
        context_id: Option<&str>,
467
        tokens: &[u32],
468
        router_config_override: Option<&RouterConfigOverride>,
469
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
470
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
471
472
473
474
475
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

476
        let isl_tokens = tokens.len();
477

478
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
479
480
481
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
482

483
        // Determine who needs seq_hashes
484
        let needs_process_routing = !self.kv_router_config.use_kv_events;
485
486
487
488
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
489
            match (needs_process_routing, scheduler_needs_it) {
490
491
492
493
494
495
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
496
        let best_worker = self
497
            .scheduler
498
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
499
                context_id.map(|s| s.to_string()),
500
                isl_tokens,
501
                maybe_seq_hashes_2,
502
                overlap_scores.clone(),
503
                router_config_override,
504
                update_states,
505
            )
506
            .await?;
507

508
509
510
        // Process routing decision when not using KV events (approximate mode with TTL/pruning)
        if needs_process_routing {
            self.indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
511
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
512
513
                .await?;
        }
514

515
516
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
517
            .get(&best_worker)
518
519
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
520
        Ok((best_worker, overlap_amount))
521
522
    }

523
524
525
526
527
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
528
        worker: WorkerWithDpRank,
529
530
    ) {
        let isl_tokens = tokens.len();
531
532

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
533
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
534
535
            compute_seq_hash_for_block(&block_hashes)
        });
536

537
538
        if let Err(e) = self
            .scheduler
539
            .add_request(
540
                request_id.clone(),
541
                maybe_seq_hashes,
542
543
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
544
                worker,
545
            )
546
547
548
549
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
550
551
    }

552
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
553
        self.scheduler.mark_prefill_completed(request_id).await
554
555
    }

556
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
557
        self.scheduler.free(request_id).await
558
    }
559

560
    pub fn block_size(&self) -> u32 {
561
562
        self.block_size
    }
563

564
565
566
567
568
569
570
571
572
    /// Get the disaggregated endpoint for a worker, if available.
    /// Used to look up bootstrap host/port for prefill workers.
    pub async fn get_disaggregated_endpoint(
        &self,
        worker_id: u64,
    ) -> Option<crate::local_model::runtime_config::DisaggregatedEndpoint> {
        self.scheduler.get_disaggregated_endpoint(worker_id).await
    }

573
574
575
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
576
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
577
578
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

579
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
580
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
581
582
583
            compute_seq_hash_for_block(&block_hashes)
        });

584
585
        Ok(self
            .scheduler
586
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
587
588
589
            .await)
    }

590
591
592
593
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649

    /// Query a specific worker's local KV indexer for its events
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;

        query_client
            .query_worker(worker_id, start_event_id, end_event_id)
            .await
    }

    /// Recover missed KV events from a specific worker.
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

        subscriber::recover_from_worker(
            query_client,
            worker_id,
            start_event_id,
            end_event_id,
            &event_tx,
        )
        .await
    }
650
651
}

Michael Feil's avatar
Michael Feil committed
652
653
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
654
655
656
657
658
659
660
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
661
662
663
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
664
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
665
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
666
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
667
668
669
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
670
671
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
672
673
674
                    overlap_blocks,
                }
            }
675
676
677
678
679
680
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
681
        };
682
683
684
685
686
687

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
688
689

pub struct KvPushRouter {
690
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
691
    pub chooser: Arc<KvRouter>,
692
693
694
695
}

impl KvPushRouter {
    pub fn new(
696
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
697
698
699
700
701
702
703
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
704
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
705
706
    for KvPushRouter
{
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
726
727
    async fn generate(
        &self,
728
        request: SingleIn<PreprocessedRequest>,
729
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
730
731
732
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

733
734
735
736
737
738
        // Simple query-only detection: presence of query_instance_id annotation means query-only mode
        let is_query_only = request.get_annotation_value("query_instance_id").is_some();

        // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
        let phase = request
            .tracker
739
            .as_ref()
740
741
742
            .map(|t| t.phase())
            .unwrap_or(RequestPhase::Aggregated);

743
744
        // Get pre-selected worker based on phase, with backend_instance_id as fallback
        let routing = request.routing.as_ref();
745
        let preselected = match phase {
746
747
748
749
750
751
752
            RequestPhase::Prefill => {
                routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Decode => {
                routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
            }
            RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
753
        };
754

755
        let block_size = self.chooser.block_size() as usize;
756
757
758
759
760
761
762
763
764
        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = preselected {
            // Route to pre-selected or explicitly specified worker
            let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
            tracing::debug!(
                worker_id = id,
                dp_rank = dp_rank,
                ?phase,
                "Routing to specified worker"
            );
765

766
767
768
769
770
771
772
773
774
775
776
            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size(), None);
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            if !is_query_only {
                self.chooser
                    .add_request(
                        context_id.clone(),
777
                        &request.token_ids,
778
779
                        overlap_blocks,
                        worker,
780
                    )
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
                    .await;
            }
            (id, dp_rank, overlap_blocks)
        } else {
            // Find the best worker match
            // Don't update states if this is a query-only request
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !is_query_only,
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };
798

799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
        // Record metrics in tracker: KV hit rate and worker ID based on phase
        if let Some(ref tracker) = request.tracker {
            let isl_blocks = request.token_ids.len().div_ceil(block_size);
            tracker.record_kv_hit(overlap_amount, isl_blocks);
            tracker.record_worker(instance_id);
        }

        // Handle query-only requests: early return with worker info
        if is_query_only {
            let stream_context = request.context().clone();
            // Tracker is always created for query-only requests (delta generator enables tracking
            // when query_instance_id annotation is present)
            let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());

            tracing::trace!(
                ?phase,
                worker_id = instance_id,
                ?worker_id_info,
                "Returning worker selection (query-only mode)"
            );

820
821
822
823
824
825
826
827
828
            let output = LLMEngineOutput {
                disaggregated_params: Some(json!({
                    "worker_id": worker_id_info,
                    "token_ids": request.token_ids
                })),
                ..Default::default()
            };
            let response = Annotated::from_data(output);
            let stream = stream::iter(vec![response]);
829
830
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
831
832

        // Route to worker
833
        let (mut backend_input, context) = request.into_parts();
834
        backend_input.routing_mut().dp_rank = Some(dp_rank);
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
852
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
853

854
                    item = response_stream.next() => {
855
                        let Some(item) = item else {
856
857
                            break;
                        };
858

859
860
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
861
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
862
                            }
863
                            prefill_marked = true;
864
                        }
865

866
                        yield item;
867
                    }
868
869
                }
            }
870

871
            if let Err(e) = chooser.free(&context_id).await {
872
                tracing::warn!("Failed to free request {context_id}: {e}");
873
            }
874
875
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
876
877
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
878
879
880
881
882
883
884

/// Cancels the router's cancellation token when the `KvRouter` is dropped.
///
/// The token is cancelled so that tasks observing it can shut down.
impl Drop for KvRouter {
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        let token = &self.cancellation_token;
        token.cancel();
    }
}