kv_router.rs 38.2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
21
};
use futures::stream::{self, StreamExt};
22
use serde::{Deserialize, Serialize};
23
use serde_json::json;
24

25
26
use crate::protocols::openai::nvext::WorkerIdInfo;

27
pub mod approx;
28
pub mod indexer;
29
pub mod prefill_router;
30
31
pub mod protocols;
pub mod publisher;
32
pub mod recorder;
33
pub mod scheduler;
34
pub mod sequence;
35
pub mod subscriber;
36
pub mod worker_query;
37

38
use indexer::WorkerKvQueryResponse;
39
pub use prefill_router::PrefillRouter;
40
use worker_query::WorkerQueryClient;
41

42
43
use crate::{
    kv_router::{
44
        approx::PruneConfig,
45
        indexer::{
46
47
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
48
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
49
        protocols::{
50
51
            LocalBlockHash, RouterRequest, RouterResponse, WorkerId, WorkerSelectionResult,
            WorkerWithDpRank,
Yan Ru Pei's avatar
Yan Ru Pei committed
52
        },
53
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
54
        sequence::SequenceError,
55
        subscriber::{start_kv_router_background, start_kv_router_background_nats_core},
56
    },
57
    local_model::runtime_config::ModelRuntimeConfig,
58
    model_card::ModelDeploymentCard,
59
    preprocessor::PreprocessedRequest,
60
    protocols::common::llm_backend::LLMEngineOutput,
61
    tokens::SequenceHash,
62
63
};

// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Endpoint name used for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject for push-based publishing of KV cache events.
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Subject for push-based publishing of KV hit-rate metrics.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject for push-based publishing of KV metrics.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for prefill events exchanged between router instances.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between router instances.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Storage bucket name used for radix-tree snapshots.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name used for radix-tree snapshots.
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Subject for querying a worker-local KV indexer.
pub const WORKER_KV_INDEXER_QUERY_SUBJECT: &str = "worker_kv_indexer_query";
/// Number of most recent events stored in a worker's buffer.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

/// Component name the KV router registers under for discovery.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name the KV router registers under for discovery.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] of the KV router inside `namespace`.
///
/// The component and endpoint names are always [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`]; only the namespace varies.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
/// Specifies the type of worker being queried when using the `query_instance_id` annotation.
/// This tells the router which worker pool to select from and what type of operation is intended.
///
/// Query instance types for worker selection
/// - "prefill" → select a prefill worker (disaggregated serving)
/// - "decode" → select a decode worker (disaggregated serving)
///
/// Note: Empty value ("query_instance_id:") is handled by PrefillRouter for disagg orchestration
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum QueryInstanceType {
    /// Query for a prefill worker (disaggregated serving)
    Prefill,
    /// Query for a decode worker (disaggregated serving)
    Decode,
}

impl std::fmt::Display for QueryInstanceType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QueryInstanceType::Prefill => write!(f, "prefill"),
            QueryInstanceType::Decode => write!(f, "decode"),
        }
    }
}

impl std::str::FromStr for QueryInstanceType {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "prefill" => Ok(QueryInstanceType::Prefill),
            "decode" => Ok(QueryInstanceType::Decode),
            _ => Err(format!(
                "Invalid QueryInstanceType: '{s}'. Expected 'prefill' or 'decode'"
            )),
        }
    }
}

140
141
142
143
144
145
146
147
148
/// Creates a DiscoveryQuery for the KV router in the given namespace.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    DiscoveryQuery::Endpoint {
        namespace,
        component: KV_ROUTER_COMPONENT.to_string(),
        endpoint: KV_ROUTER_ENDPOINT.to_string(),
    }
}

149
150
151
152
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
153
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
154
        request: &SchedulingRequest,
155
        block_size: u32,
156
157
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
158

159
160
161
162
163
164
165
166
167
168
/// Override configuration for router settings that can be specified per-request
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    #[builder(default)]
    pub router_temperature: Option<f64>,
}

169
/// KV Router configuration parameters
170
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
171
172
173
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

174
    pub router_temperature: f64,
175

176
177
    pub use_kv_events: bool,

178
179
    pub router_replica_sync: bool,

180
181
182
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

183
184
185
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

186
    /// Whether to reset the router state on startup (default: false)
187
    pub router_reset_states: bool,
188
189
190
191
192
193
194
195
196

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024)
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
197
198
199
200
201
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
202
            overlap_score_weight: 1.0,
203
            router_temperature: 0.0,
204
            use_kv_events: true,
205
            router_replica_sync: false,
206
            router_track_active_blocks: true,
207
            router_snapshot_threshold: Some(1000000),
208
            router_reset_states: false,
209
210
211
            router_ttl_secs: 120.0,
            router_max_tree_size: 1024,
            router_prune_target_ratio: 0.8,
212
213
214
215
216
217
218
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
219
    #[allow(clippy::too_many_arguments)]
220
221
    pub fn new(
        overlap_score_weight: Option<f64>,
222
        temperature: Option<f64>,
223
        use_kv_events: Option<bool>,
224
        replica_sync: Option<bool>,
225
        track_active_blocks: Option<bool>,
226
227
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
228
229
230
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
231
232
233
234
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
235
            router_temperature: temperature.unwrap_or(default.router_temperature),
236
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
237
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
238
239
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
240
241
242
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
243
244
245
246
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
247
248
249
250
        }
    }
}

251
pub enum Indexer {
252
253
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
254
    /// Has the ability to persist and snapshot states.
255
    KvIndexer(KvIndexer),
256
257
258
259

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
260
261
262
263
264
265
266
267
268
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
269
270
271
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
272
                tree_sizes: HashMap::new(),
273
            }),
274
275
        }
    }
276
277
278
279

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
280
281
282
283
284
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
285
286
        }
    }
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302

    async fn process_routing_decision(
        &self,
        worker: WorkerWithDpRank,
        local_hashes: Vec<LocalBlockHash>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
                    .process_routing_decision(worker, local_hashes, sequence_hashes)
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
303
304
}

305
306
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
307
pub struct KvRouter {
308
309
310
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
311
    scheduler: KvScheduler,
312

313
    block_size: u32,
314
315

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
316
317

    cancellation_token: tokio_util::sync::CancellationToken,
318
319

    client: Client,
320
321

    worker_query_client: Option<WorkerQueryClient>,
322
323
324
325
}

impl KvRouter {
    pub async fn new(
326
327
        endpoint: Endpoint,
        client: Client,
328
        block_size: u32,
329
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
330
        kv_router_config: Option<KvRouterConfig>,
331
        consumer_id: String,
332
    ) -> Result<Self> {
333
        let kv_router_config = kv_router_config.unwrap_or_default();
334
        let component = endpoint.component();
335
        let cancellation_token = component.drt().primary_token();
336

337
        let instance_ids_rx = client.instance_avail_watcher();
338

339
340
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
341
        let endpoint_id = endpoint.id();
342
        let discovery_key = DiscoveryQuery::EndpointModels {
343
344
345
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
346
347
        };
        let discovery_stream = discovery
348
            .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
349
350
351
352
353
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
354

355
356
357
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
358
        } else {
359
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
360
361
362
363
364
365
366
367
368
369
370
371
372

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
373
                cancellation_token.clone(),
374
                None, // expiration_duration for frequency tracking
375
376
                block_size,
                kv_indexer_metrics,
377
                prune_config,
378
379
            ))
        };
380

381
        let scheduler = KvScheduler::start(
382
            component.clone(),
383
            block_size,
384
            instance_ids_rx,
385
            runtime_configs_rx.clone(),
386
            selector,
387
            kv_router_config.router_replica_sync,
388
            consumer_id.clone(),
389
390
        )
        .await?;
391

392
393
394
395
396
397
        // Initialize worker query client using namespace abstraction
        // (created before background task so we can use it for startup recovery)
        let worker_query_client =
            worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
        tracing::info!("Worker query client initialized");

398
        // Start KV event subscriber background process (only when use_kv_events is enabled)
399
400
        // This is spawned as a background task to avoid blocking router startup.
        // The task waits for runtime_configs to determine whether to use NATS Core or JetStream.
401
402
403
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
            // Clone everything needed for the background task
            let component_clone = component.clone();
            let kv_indexer_clone = kv_indexer.clone();
            let cancellation_token_clone = cancellation_token.clone();
            let mut runtime_configs_rx_clone = runtime_configs_rx.clone();
            let worker_query_client_clone =
                worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());

            tokio::spawn(async move {
                // Wait for runtime_configs to have at least one entry
                let (all_local_indexer, count) = loop {
                    {
                        let configs = runtime_configs_rx_clone.borrow();
                        if !configs.is_empty() {
                            let all_local_indexer =
                                configs.values().all(|c| c.enable_local_indexer);
                            break (all_local_indexer, configs.len());
                        }
                    }

                    // Wait for changes to runtime_configs
                    tokio::select! {
                        _ = cancellation_token_clone.cancelled() => {
                            tracing::debug!("Subscriber selection task cancelled");
                            return;
                        }
                        result = runtime_configs_rx_clone.changed() => {
                            if result.is_err() {
                                tracing::debug!("Runtime configs channel closed");
                                return;
                            }
                        }
                    }
                };

                if all_local_indexer {
                    // All workers have local_indexer enabled - use NATS Core
441
                    tracing::info!(
442
                        "All {count} workers have local_indexer enabled, using NATS Core subscription"
443
                    );
444
445
446
447
448
449
450
451
452
453
454
455

                    if let Err(e) = start_kv_router_background_nats_core(
                        component_clone.clone(),
                        kv_indexer_clone.event_sender(),
                        kv_indexer_clone.remove_worker_sender(),
                        cancellation_token_clone.clone(),
                        worker_query_client_clone,
                    )
                    .await
                    {
                        tracing::error!("Failed to start NATS Core subscriber: {e}");
                    }
456
                } else {
457
                    // Not all workers have local_indexer - use JetStream
458
                    tracing::info!(
459
                        "Not all workers have local_indexer enabled, using JetStream subscription"
460
                    );
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480

                    if let Err(e) = start_kv_router_background(
                        component_clone.clone(),
                        consumer_id,
                        kv_indexer_clone.event_sender(),
                        kv_indexer_clone.remove_worker_sender(),
                        kv_router_config
                            .router_snapshot_threshold
                            .map(|_| kv_indexer_clone.get_workers_sender()),
                        kv_router_config
                            .router_snapshot_threshold
                            .map(|_| kv_indexer_clone.snapshot_event_sender()),
                        cancellation_token_clone.clone(),
                        kv_router_config.router_snapshot_threshold,
                        kv_router_config.router_reset_states,
                    )
                    .await
                    {
                        tracing::error!("Failed to start JetStream subscriber: {e}");
                    }
481
                }
482
            });
483
        }
484

485
        tracing::info!("KV Routing initialized");
486
        Ok(Self {
487
            indexer,
488
            scheduler,
489
            block_size,
490
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
491
            cancellation_token,
492
            client,
493
            worker_query_client: Some(worker_query_client),
494
        })
495
496
    }

497
498
499
500
501
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

502
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
503
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
504
505
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
506
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
507
        context_id: Option<&str>,
508
        tokens: &[u32],
509
        router_config_override: Option<&RouterConfigOverride>,
510
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
511
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
512
513
514
515
516
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

517
        let isl_tokens = tokens.len();
518

519
520
521
522
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
523

524
        // Determine who needs seq_hashes
525
        let needs_process_routing = !self.kv_router_config.use_kv_events;
526
527
528
529
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
530
            match (needs_process_routing, scheduler_needs_it) {
531
532
533
534
535
536
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
537
        let best_worker = self
538
            .scheduler
539
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
540
                context_id.map(|s| s.to_string()),
541
                isl_tokens,
542
                maybe_seq_hashes_2,
543
                overlap_scores.clone(),
544
                router_config_override,
545
                update_states,
546
            )
547
            .await?;
548

549
550
551
        // Process routing decision when not using KV events (approximate mode with TTL/pruning)
        if needs_process_routing {
            self.indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
552
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
553
554
                .await?;
        }
555

556
557
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
558
            .get(&best_worker)
559
560
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
561
        Ok((best_worker, overlap_amount))
562
563
    }

564
565
566
567
568
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
569
        worker: WorkerWithDpRank,
570
571
    ) {
        let isl_tokens = tokens.len();
572
573
574
575
576

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
577

578
579
        if let Err(e) = self
            .scheduler
580
            .add_request(
581
                request_id.clone(),
582
                maybe_seq_hashes,
583
584
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
585
                worker,
586
            )
587
588
589
590
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
591
592
    }

593
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
594
        self.scheduler.mark_prefill_completed(request_id).await
595
596
    }

597
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
598
        self.scheduler.free(request_id).await
599
    }
600

601
    pub fn block_size(&self) -> u32 {
602
603
        self.block_size
    }
604

605
606
607
608
609
610
611
612
613
    /// Get the disaggregated endpoint for a worker, if available.
    /// Used to look up bootstrap host/port for prefill workers.
    pub async fn get_disaggregated_endpoint(
        &self,
        worker_id: u64,
    ) -> Option<crate::local_model::runtime_config::DisaggregatedEndpoint> {
        self.scheduler.get_disaggregated_endpoint(worker_id).await
    }

614
615
616
617
618
619
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

620
621
622
623
624
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

625
626
        Ok(self
            .scheduler
627
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
628
629
630
            .await)
    }

631
632
633
634
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690

    /// Query a specific worker's local KV indexer for its events
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;

        query_client
            .query_worker(worker_id, start_event_id, end_event_id)
            .await
    }

    /// Recover missed KV events from a specific worker.
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

        subscriber::recover_from_worker(
            query_client,
            worker_id,
            start_event_id,
            end_event_id,
            &event_tx,
        )
        .await
    }
691
692
}

Michael Feil's avatar
Michael Feil committed
693
694
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
695
696
697
698
699
700
701
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
702
703
704
705
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
706
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
707
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
708
709
710
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
711
712
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
713
714
715
                    overlap_blocks,
                }
            }
716
717
718
719
720
721
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
722
        };
723
724
725
726
727
728

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
729
730

pub struct KvPushRouter {
731
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
732
    pub chooser: Arc<KvRouter>,
733
734
735
736
}

impl KvPushRouter {
    pub fn new(
737
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
738
739
740
741
742
743
744
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
745
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
746
747
    for KvPushRouter
{
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
767
768
    async fn generate(
        &self,
769
        request: SingleIn<PreprocessedRequest>,
770
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
771
772
773
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
        // Check if this is a query_instance_id request and parse its type
        // Format: "query_instance_id:type" where type is "prefill", "decode", or "" (empty for aggregated)
        // Empty value ("query_instance_id:") means GAIE Aggregated mode - return same worker as both prefill and decode
        let query_instance_annotation = request.get_annotation_value("query_instance_id");
        let is_gaie_agg_query = query_instance_annotation
            .as_ref()
            .is_some_and(|s| s.is_empty());
        let query_instance_type: Option<QueryInstanceType> =
            if let Some(type_str) = &query_instance_annotation {
                match type_str.parse::<QueryInstanceType>() {
                    Ok(t) => Some(t),
                    Err(_) if type_str.is_empty() => {
                        // Empty value is valid for aggregated mode, not a warning
                        None
                    }
                    Err(e) => {
                        tracing::warn!("Invalid query_instance_id type '{type_str}': {e}");
                        None
                    }
                }
            } else {
                None
            };
797
798
799
800

        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
            // If instance_id is set, use it and compute actual overlap
            let dp_rank = request.dp_rank.unwrap_or(0);
801
            if query_instance_type.is_some() {
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
                tracing::debug!(
                    "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                );
            }

            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            self.chooser
                .add_request(
                    context_id.clone(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
            (id, dp_rank, overlap_blocks)
        } else {
            // Otherwise, find the best match
825
826
            // Don't update states if this is a query-only request (any query_instance_id annotation)
            let should_update_states = query_instance_annotation.is_none();
827
828
829
830
831
832
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
833
                    should_update_states,
834
835
836
837
838
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };

839
840
841
        // If request has a query_instance_id annotation, return worker selection info
        // without routing to the actual worker. Returns LLMEngineOutput with disaggregated_params
        // containing worker_id info, same structure as normal execution for uniform extraction.
842
        let stream_context = request.context().clone();
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898

        // Handle query-only requests (GAIE Stage 1)
        if query_instance_type.is_some() || is_gaie_agg_query {
            let worker_id_info = if is_gaie_agg_query {
                // GAIE Aggregated mode: same worker serves both prefill and decode
                tracing::trace!(
                    query_type = "aggregated",
                    worker_id = instance_id,
                    "Returning aggregated worker selection (same worker for prefill and decode)"
                );
                WorkerIdInfo {
                    prefill_worker_id: Some(instance_id),
                    decode_worker_id: Some(instance_id),
                }
            } else {
                match query_instance_type.unwrap() {
                    QueryInstanceType::Prefill => {
                        tracing::trace!(
                            query_type = "prefill",
                            prefill_worker_id = instance_id,
                            "Returning prefill worker selection"
                        );
                        WorkerIdInfo {
                            prefill_worker_id: Some(instance_id),
                            decode_worker_id: None,
                        }
                    }
                    QueryInstanceType::Decode => {
                        // Get prefill_worker_id from annotation (set by caller after prefill selection)
                        let prefill_worker_id = request
                            .get_annotation_value("prefill_worker_id")
                            .and_then(|s| s.parse::<u64>().ok());
                        tracing::trace!(
                            query_type = "decode",
                            prefill_worker_id = ?prefill_worker_id,
                            decode_worker_id = instance_id,
                            "Returning decode worker selection"
                        );
                        WorkerIdInfo {
                            prefill_worker_id,
                            decode_worker_id: Some(instance_id),
                        }
                    }
                }
            };

            // Return as LLMEngineOutput with disaggregated_params (same structure as normal execution)
            let output = LLMEngineOutput {
                disaggregated_params: Some(json!({
                    "worker_id": worker_id_info,
                    "token_ids": request.token_ids
                })),
                ..Default::default()
            };
            let response = Annotated::from_data(output);
            let stream = stream::iter(vec![response]);
899
900
901
902
903
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
        let (mut backend_input, context) = request.into_parts();
        backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
        backend_input.dp_rank = Some(dp_rank);
904

905
906
        // Get prefill worker ID from prefill_result if available
        // In aggregated mode, prefill_result is None, so we use decode_worker_id for both
907
        let decode_worker_id = instance_id;
908
909
910
911
912
913
914
915
916
917
        let prefill_worker_id = backend_input
            .prefill_result
            .as_ref()
            .and_then(|prefill_result| {
                prefill_result
                    .disaggregated_params
                    .get("worker_id")
                    .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
                    .and_then(|info| info.prefill_worker_id)
            })
918
919
            .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker

920
921
922
923
924
925
926
927
928
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;
929
            let mut first_item = true;
930
931
932
933
934
935
936
937

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
938
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
939

940
                    item = response_stream.next() => {
941
                        let Some(mut item) = item else {
942
943
                            break;
                        };
944

945
946
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
947
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
948
                            }
949
                            prefill_marked = true;
950
                        }
951

952
953
954
955
956
                        // Always inject worker_id in first item's disaggregated_params
                        // This is needed for:
                        // 1. PrefillRouter to know which prefill worker was chosen
                        // 2. Client response when extra_fields contains "worker_id"
                        if first_item {
957
                            first_item = false;
958
959
960
961
962
963

                            let Some(ref mut data) = item.data else {
                                yield item;
                                continue;
                            };

964
                            // prefill_worker_id comes from prefill_result.disaggregated_params or falls back to instance_id
965
                            // decode_worker_id is always the current instance_id
966
967
968
969
970
971
                            let worker_id_info = WorkerIdInfo {
                                prefill_worker_id,
                                decode_worker_id: Some(decode_worker_id),
                            };
                            let worker_id_json = serde_json::to_value(&worker_id_info)
                                .expect("WorkerIdInfo serialization should not fail");
972
973
974
975
976
977

                            if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
                                obj.insert("worker_id".to_string(), worker_id_json);
                            } else {
                                data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
                            }
978
                        }
979
980

                        yield item;
981
                    }
982
983
                }
            }
984

985
            if let Err(e) = chooser.free(&context_id).await {
986
                tracing::warn!("Failed to free request {context_id}: {e}");
987
            }
988
989
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
990
991
    }
}
impl Drop for KvRouter {
    /// Signals the router's cancellation token when the router is dropped, so that the
    /// background tasks watching it can stop.
    fn drop(&mut self) {
        let token = &self.cancellation_token;
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        token.cancel();
    }
}