kv_router.rs 31.6 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
21
};
use futures::stream::{self, StreamExt};
22
use serde::{Deserialize, Serialize};
23
use serde_json::json;
24

25
26
use crate::protocols::openai::nvext::WorkerIdInfo;

27
pub mod approx;
28
pub mod indexer;
29
pub mod prefill_router;
30
31
pub mod protocols;
pub mod publisher;
32
pub mod recorder;
33
34
pub mod scheduler;
pub mod scoring;
35
pub mod sequence;
36
pub mod subscriber;
37
pub mod worker_query;
38

39
use indexer::WorkerKvQueryResponse;
40
pub use prefill_router::PrefillRouter;
41
use worker_query::WorkerQueryClient;
42

43
44
use crate::{
    kv_router::{
45
        approx::PruneConfig,
46
        indexer::{
47
48
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
49
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
50
        protocols::{
51
52
            LocalBlockHash, RouterRequest, RouterResponse, WorkerId, WorkerSelectionResult,
            WorkerWithDpRank,
Yan Ru Pei's avatar
Yan Ru Pei committed
53
        },
54
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
55
        sequence::SequenceError,
56
        subscriber::{recover_from_all_workers, start_kv_router_background},
57
    },
58
    local_model::runtime_config::ModelRuntimeConfig,
59
    model_card::ModelDeploymentCard,
60
    preprocessor::PreprocessedRequest,
61
    protocols::common::llm_backend::LLMEngineOutput,
62
    tokens::SequenceHash,
63
64
};

65
66
// [gluo TODO] these shouldn't need to be public;
// they should be discovered from the component.

/// Endpoint name used for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject on which KV cache events are published (push-based).
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject on which KV hit-rate information is published (push-based).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject on which KV metrics are published (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for inter-router prefill events.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for inter-router active-sequence events.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File (key) name used for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Subject used to query a worker-local KV indexer.
pub const WORKER_KV_INDEXER_QUERY_SUBJECT: &str = "worker_kv_indexer_query";
/// Number of most recent events kept in a worker's local event buffer.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

/// Component name under which the KV router registers for discovery.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name under which the KV router registers for discovery.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] of the KV router inside `namespace`.
///
/// The component and endpoint names are fixed to [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`].
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds a [`DiscoveryQuery::Endpoint`] that targets the KV router inside `namespace`.
///
/// The component and endpoint names are fixed to [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`].
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = KV_ROUTER_COMPONENT.to_owned();
    let endpoint = KV_ROUTER_ENDPOINT.to_owned();
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

110
111
112
113
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
114
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
115
        request: &SchedulingRequest,
116
        block_size: u32,
117
118
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
119

120
121
122
123
124
125
126
127
128
129
/// Router settings that a single request may override.
///
/// A `None` field means no override is requested for that setting. Both fields default to
/// `None` in the generated builder.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request replacement for [`KvRouterConfig::overlap_score_weight`].
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request replacement for [`KvRouterConfig::router_temperature`].
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

130
/// KV Router configuration parameters
131
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
132
133
134
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

135
    pub router_temperature: f64,
136

137
138
    pub use_kv_events: bool,

139
140
    pub router_replica_sync: bool,

141
142
143
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

144
145
146
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

147
    /// Whether to reset the router state on startup (default: false)
148
    pub router_reset_states: bool,
149
150
151
152
153
154
155
156
157

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024)
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
158
159
160
161
162
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
163
            overlap_score_weight: 1.0,
164
            router_temperature: 0.0,
165
            use_kv_events: true,
166
            router_replica_sync: false,
167
            router_track_active_blocks: true,
168
            router_snapshot_threshold: Some(1000000),
169
            router_reset_states: false,
170
171
172
            router_ttl_secs: 120.0,
            router_max_tree_size: 1024,
            router_prune_target_ratio: 0.8,
173
174
175
176
177
178
179
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
180
    #[allow(clippy::too_many_arguments)]
181
182
    pub fn new(
        overlap_score_weight: Option<f64>,
183
        temperature: Option<f64>,
184
        use_kv_events: Option<bool>,
185
        replica_sync: Option<bool>,
186
        track_active_blocks: Option<bool>,
187
188
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
189
190
191
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
192
193
194
195
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
196
            router_temperature: temperature.unwrap_or(default.router_temperature),
197
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
198
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
199
200
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
201
202
203
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
204
205
206
207
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
208
209
210
211
        }
    }
}

212
pub enum Indexer {
213
214
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
215
    /// Has the ability to persist and snapshot states.
216
    KvIndexer(KvIndexer),
217
218
219
220

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
221
222
223
224
225
226
227
228
229
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
230
231
232
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
233
                tree_sizes: HashMap::new(),
234
            }),
235
236
        }
    }
237
238
239
240

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
241
242
243
244
245
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
246
247
        }
    }
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263

    async fn process_routing_decision(
        &self,
        worker: WorkerWithDpRank,
        local_hashes: Vec<LocalBlockHash>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
                    .process_routing_decision(worker, local_hashes, sequence_hashes)
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
264
265
}

266
267
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
268
pub struct KvRouter {
269
270
271
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
272
    scheduler: KvScheduler,
273

274
    block_size: u32,
275
276

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
277
278

    cancellation_token: tokio_util::sync::CancellationToken,
279
280

    client: Client,
281
282

    worker_query_client: Option<WorkerQueryClient>,
283
284
285
286
}

impl KvRouter {
    pub async fn new(
287
288
        endpoint: Endpoint,
        client: Client,
289
        block_size: u32,
290
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
291
        kv_router_config: Option<KvRouterConfig>,
292
        consumer_id: String,
293
    ) -> Result<Self> {
294
        let kv_router_config = kv_router_config.unwrap_or_default();
295
        let component = endpoint.component();
296
        let cancellation_token = component.drt().primary_token();
297

298
        let instance_ids_rx = client.instance_avail_watcher();
299

300
301
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
302
        let endpoint_id = endpoint.id();
303
        let discovery_key = DiscoveryQuery::EndpointModels {
304
305
306
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
307
308
        };
        let discovery_stream = discovery
309
            .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
310
311
312
313
314
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
315

316
317
318
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
319
        } else {
320
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
321
322
323
324
325
326
327
328
329
330
331
332
333

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
334
                cancellation_token.clone(),
335
                None, // expiration_duration for frequency tracking
336
337
                block_size,
                kv_indexer_metrics,
338
                prune_config,
339
340
            ))
        };
341

342
        let scheduler = KvScheduler::start(
343
            component.clone(),
344
            block_size,
345
            instance_ids_rx,
346
            runtime_configs_rx.clone(),
347
            selector,
348
            kv_router_config.router_replica_sync,
349
            consumer_id.clone(),
350
351
        )
        .await?;
352

353
354
355
356
357
358
        // Initialize worker query client using namespace abstraction
        // (created before background task so we can use it for startup recovery)
        let worker_query_client =
            worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
        tracing::info!("Worker query client initialized");

359
360
361
362
        // Start KV event subscriber background process (only when use_kv_events is enabled)
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
363
364
            start_kv_router_background(
                component.clone(),
365
                consumer_id,
366
                kv_indexer.event_sender(),
367
                kv_indexer.remove_worker_sender(),
368
369
370
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.get_workers_sender()),
371
372
373
374
375
376
377
378
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419

            // Perform startup recovery from workers with local indexers
            // This catches up on any events missed while the router was offline
            let last_event_ids = kv_indexer
                .get_last_received_event_ids()
                .await
                .unwrap_or_default();
            let instances = client.instance_source.as_ref().borrow().clone();
            let worker_ids: Vec<WorkerId> = instances.iter().map(|i| i.instance_id).collect();

            if !worker_ids.is_empty() {
                tracing::info!(
                    worker_count = worker_ids.len(),
                    "Starting recovery from workers with local indexers"
                );

                // NOTE: recover_from_all_workers() is a no-op if
                // Worker with worker_id is not associated with a
                // local indexer instance.
                let recovered = recover_from_all_workers(
                    &worker_query_client,
                    &last_event_ids,
                    &worker_ids,
                    &kv_indexer.event_sender(),
                )
                .await;

                if recovered > 0 {
                    tracing::info!(
                        recovered_events = recovered,
                        "KV Router startup: Recovered {} KV events from workers {:?}",
                        recovered,
                        worker_ids
                    );
                } else {
                    tracing::info!(
                        "KV Router startup: No KV events recovered from workers {:?}",
                        worker_ids
                    );
                }
            }
420
        }
421

422
        tracing::info!("KV Routing initialized");
423
        Ok(Self {
424
            indexer,
425
            scheduler,
426
            block_size,
427
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
428
            cancellation_token,
429
            client,
430
            worker_query_client: Some(worker_query_client),
431
        })
432
433
    }

434
435
436
437
438
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

439
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
440
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
441
442
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
443
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
444
        context_id: Option<&str>,
445
        tokens: &[u32],
446
        router_config_override: Option<&RouterConfigOverride>,
447
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
448
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
449
450
451
452
453
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

454
        let isl_tokens = tokens.len();
455

456
457
458
459
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
460

461
        // Determine who needs seq_hashes
462
        let needs_process_routing = !self.kv_router_config.use_kv_events;
463
464
465
466
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
467
            match (needs_process_routing, scheduler_needs_it) {
468
469
470
471
472
473
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
474
        let best_worker = self
475
            .scheduler
476
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
477
                context_id.map(|s| s.to_string()),
478
                isl_tokens,
479
                maybe_seq_hashes_2,
480
                overlap_scores.clone(),
481
                router_config_override,
482
                update_states,
483
            )
484
            .await?;
485

486
487
488
        // Process routing decision when not using KV events (approximate mode with TTL/pruning)
        if needs_process_routing {
            self.indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
489
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
490
491
                .await?;
        }
492

493
494
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
495
            .get(&best_worker)
496
497
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
498
        Ok((best_worker, overlap_amount))
499
500
    }

501
502
503
504
505
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
506
        worker: WorkerWithDpRank,
507
508
    ) {
        let isl_tokens = tokens.len();
509
510
511
512
513

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
514

515
516
        if let Err(e) = self
            .scheduler
517
            .add_request(
518
                request_id.clone(),
519
                maybe_seq_hashes,
520
521
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
522
                worker,
523
            )
524
525
526
527
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
528
529
    }

530
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
531
        self.scheduler.mark_prefill_completed(request_id).await
532
533
    }

534
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
535
        self.scheduler.free(request_id).await
536
    }
537

538
    pub fn block_size(&self) -> u32 {
539
540
        self.block_size
    }
541

542
543
544
545
546
547
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

548
549
550
551
552
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

553
554
        Ok(self
            .scheduler
555
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
556
557
558
            .await)
    }

559
560
561
562
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618

    /// Query a specific worker's local KV indexer for its events
    /// (See docstring for `WorkerQueryClient.query_worker()`)
    pub async fn query_worker_local_kv(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;

        query_client
            .query_worker(worker_id, start_event_id, end_event_id)
            .await
    }

    /// Recover missed KV events from a specific worker.
    ///
    /// Queries the worker's local KV indexer for events starting from
    /// `start_event_id` and applies them to the router's indexer.
    ///
    /// # Arguments
    ///
    /// * `worker_id` - The worker to recover from
    /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
    /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
    pub async fn recover_from_worker(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<usize> {
        let query_client = self
            .worker_query_client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;

        let event_tx = match &self.indexer {
            Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
            Indexer::None => {
                anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
            }
        };

        subscriber::recover_from_worker(
            query_client,
            worker_id,
            start_event_id,
            end_event_id,
            &event_tx,
        )
        .await
    }
619
620
}

Michael Feil's avatar
Michael Feil committed
621
622
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
623
624
625
626
627
628
629
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
630
631
632
633
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
634
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
635
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
636
637
638
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
639
640
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
641
642
643
                    overlap_blocks,
                }
            }
644
645
646
647
648
649
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
650
        };
651
652
653
654
655
656

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
657
658

pub struct KvPushRouter {
659
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
660
    pub chooser: Arc<KvRouter>,
661
662
663
664
}

impl KvPushRouter {
    pub fn new(
665
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
666
667
668
669
670
671
672
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
673
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
674
675
    for KvPushRouter
{
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
695
696
    async fn generate(
        &self,
697
        request: SingleIn<PreprocessedRequest>,
698
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

        // Check if this is a query_instance_id request first
        let query_instance_id = request.has_annotation("query_instance_id");

        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
            // If instance_id is set, use it and compute actual overlap
            let dp_rank = request.dp_rank.unwrap_or(0);
            if query_instance_id {
                tracing::debug!(
                    "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                );
            }

            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            self.chooser
                .add_request(
                    context_id.clone(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
            (id, dp_rank, overlap_blocks)
        } else {
            // Otherwise, find the best match
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !query_instance_id, // Don't update states if query_instance_id
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };

        // if request has the annotation "query_instance_id",
        // then the request will not be routed to the worker,
        // and instead the worker_instance_id will be returned.
        let stream_context = request.context().clone();
        if query_instance_id {
            let instance_id_str = instance_id.to_string();
            let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;

            // Return the tokens in nvext.token_data format
            let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
            tracing::trace!(
                "Tokens requested in the response through the query_instance_id annotation: {:?}",
                response_tokens
            );
            let stream = stream::iter(vec![response, response_tokens]);
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
        let (mut backend_input, context) = request.into_parts();
        backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
        backend_input.dp_rank = Some(dp_rank);
764

765
766
        // Get prefill worker ID from prefill_result if available
        // In aggregated mode, prefill_result is None, so we use decode_worker_id for both
767
        let decode_worker_id = instance_id;
768
769
770
771
772
773
774
775
776
777
        let prefill_worker_id = backend_input
            .prefill_result
            .as_ref()
            .and_then(|prefill_result| {
                prefill_result
                    .disaggregated_params
                    .get("worker_id")
                    .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
                    .and_then(|info| info.prefill_worker_id)
            })
778
779
            .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker

780
781
782
783
784
785
786
787
788
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;
789
            let mut first_item = true;
790
791
792
793
794
795
796
797

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
798
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
799

800
                    item = response_stream.next() => {
801
                        let Some(mut item) = item else {
802
803
                            break;
                        };
804

805
806
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
807
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
808
                            }
809
                            prefill_marked = true;
810
                        }
811

812
813
814
815
816
                        // Always inject worker_id in first item's disaggregated_params
                        // This is needed for:
                        // 1. PrefillRouter to know which prefill worker was chosen
                        // 2. Client response when extra_fields contains "worker_id"
                        if first_item {
817
                            first_item = false;
818
819
820
821
822
823

                            let Some(ref mut data) = item.data else {
                                yield item;
                                continue;
                            };

824
                            // prefill_worker_id comes from prefill_result.disaggregated_params or falls back to instance_id
825
                            // decode_worker_id is always the current instance_id
826
827
828
829
830
831
                            let worker_id_info = WorkerIdInfo {
                                prefill_worker_id,
                                decode_worker_id: Some(decode_worker_id),
                            };
                            let worker_id_json = serde_json::to_value(&worker_id_info)
                                .expect("WorkerIdInfo serialization should not fail");
832
833
834
835
836
837

                            if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
                                obj.insert("worker_id".to_string(), worker_id_json);
                            } else {
                                data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
                            }
838
                        }
839
840

                        yield item;
841
                    }
842
843
                }
            }
844

845
            if let Err(e) = chooser.free(&context_id).await {
846
                tracing::warn!("Failed to free request {context_id}: {e}");
847
            }
848
849
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
850
851
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
852
853
854
855
856
857
858

impl Drop for KvRouter {
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}