kv_router.rs 26.8 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
21
};
use futures::stream::{self, StreamExt};
22
use serde::{Deserialize, Serialize};
23
use serde_json::json;
24

25
pub mod approx;
26
pub mod indexer;
27
pub mod prefill_router;
28
29
pub mod protocols;
pub mod publisher;
30
pub mod recorder;
31
32
pub mod scheduler;
pub mod scoring;
33
pub mod sequence;
34
pub mod subscriber;
35

36
37
pub use prefill_router::PrefillRouter;

38
39
use crate::{
    kv_router::{
40
        approx::PruneConfig,
41
        indexer::{
42
43
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
44
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
45
46
47
        protocols::{
            LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
        },
48
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
49
        sequence::SequenceError,
50
        subscriber::start_kv_router_background,
51
    },
52
    local_model::runtime_config::ModelRuntimeConfig,
53
    model_card::ModelDeploymentCard,
54
    preprocessor::PreprocessedRequest,
55
    protocols::common::llm_backend::LLMEngineOutput,
56
    tokens::SequenceHash,
57
58
};

59
60
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

// for metric scraping (pull-based)
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

// for metric publishing (push-based)
pub const KV_EVENT_SUBJECT: &str = "kv_events";
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// for inter-router comms
pub const PREFILL_SUBJECT: &str = "prefill_events";
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

// for radix tree snapshot storage
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";

// for router discovery registration
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] of the KV router inside `namespace`.
///
/// The component and endpoint name are fixed to [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`]; only the namespace varies.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the [`DiscoveryQuery::Endpoint`] that targets the KV router inside `namespace`.
///
/// The component and endpoint are fixed to [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`]; only the namespace varies.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

100
101
102
103
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
104
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
105
        request: &SchedulingRequest,
106
        block_size: u32,
107
108
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
109

110
111
112
113
114
115
116
117
118
119
/// Override configuration for router settings that can be specified per-request
///
/// Every field is optional and defaults to `None` (both through `Default` and through the
/// generated builder).
// NOTE(review): presumably a `None` field means "use the value from `KvRouterConfig`";
// the fallback is implemented by the consumer of this struct — confirm there.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request value for the overlap score weight, if one was supplied.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request value for the router temperature, if one was supplied.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

120
/// KV Router configuration parameters
121
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
122
123
124
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

125
    pub router_temperature: f64,
126

127
128
    pub use_kv_events: bool,

129
130
    pub router_replica_sync: bool,

131
132
133
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

134
135
136
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

137
    /// Whether to reset the router state on startup (default: false)
138
    pub router_reset_states: bool,
139
140
141
142
143
144
145
146
147

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024)
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
148
149
150
151
152
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
153
            overlap_score_weight: 1.0,
154
            router_temperature: 0.0,
155
            use_kv_events: true,
156
            router_replica_sync: false,
157
            router_track_active_blocks: true,
158
            router_snapshot_threshold: Some(1000000),
159
            router_reset_states: false,
160
161
162
            router_ttl_secs: 120.0,
            router_max_tree_size: 1024,
            router_prune_target_ratio: 0.8,
163
164
165
166
167
168
169
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
170
    #[allow(clippy::too_many_arguments)]
171
172
    pub fn new(
        overlap_score_weight: Option<f64>,
173
        temperature: Option<f64>,
174
        use_kv_events: Option<bool>,
175
        replica_sync: Option<bool>,
176
        track_active_blocks: Option<bool>,
177
178
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
179
180
181
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
182
183
184
185
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
186
            router_temperature: temperature.unwrap_or(default.router_temperature),
187
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
188
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
189
190
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
191
192
193
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
194
195
196
197
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
198
199
200
201
        }
    }
}

202
pub enum Indexer {
203
204
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
205
    /// Has the ability to persist and snapshot states.
206
    KvIndexer(KvIndexer),
207
208
209
210

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
211
212
213
214
215
216
217
218
219
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
220
221
222
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
223
                tree_sizes: HashMap::new(),
224
            }),
225
226
        }
    }
227
228
229
230

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
231
232
233
234
235
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
236
237
        }
    }
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253

    async fn process_routing_decision(
        &self,
        worker: WorkerWithDpRank,
        local_hashes: Vec<LocalBlockHash>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
                    .process_routing_decision(worker, local_hashes, sequence_hashes)
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
254
255
}

256
257
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
258
pub struct KvRouter {
259
260
261
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
262
    scheduler: KvScheduler,
263

264
    block_size: u32,
265
266

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
267
268

    cancellation_token: tokio_util::sync::CancellationToken,
269
270

    client: Client,
271
272
273
274
}

impl KvRouter {
    pub async fn new(
275
276
        endpoint: Endpoint,
        client: Client,
277
        block_size: u32,
278
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
279
        kv_router_config: Option<KvRouterConfig>,
280
        consumer_id: String,
281
    ) -> Result<Self> {
282
        let kv_router_config = kv_router_config.unwrap_or_default();
283
        let component = endpoint.component();
284
        let cancellation_token = component.drt().primary_token();
285

286
        let instance_ids_rx = client.instance_avail_watcher();
287

288
289
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
290
        let endpoint_id = endpoint.id();
291
        let discovery_key = DiscoveryQuery::EndpointModels {
292
293
294
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
295
296
297
298
299
300
301
302
        };
        let discovery_stream = discovery
            .list_and_watch(discovery_key, Some(cancellation_token.clone()))
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
303

304
305
306
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
307
        } else {
308
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
309
310
311
312
313
314
315
316
317
318
319
320
321

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
322
                cancellation_token.clone(),
323
                None, // expiration_duration for frequency tracking
324
325
                block_size,
                kv_indexer_metrics,
326
                prune_config,
327
328
            ))
        };
329

330
        let scheduler = KvScheduler::start(
331
            component.clone(),
332
            block_size,
333
            instance_ids_rx,
334
            runtime_configs_rx,
335
            selector,
336
            kv_router_config.router_replica_sync,
337
            consumer_id.clone(),
338
339
        )
        .await?;
340

341
342
343
344
        // Start KV event subscriber background process (only when use_kv_events is enabled)
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
345
346
            start_kv_router_background(
                component.clone(),
347
                consumer_id,
348
                kv_indexer.event_sender(),
349
                kv_indexer.remove_worker_sender(),
350
351
352
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.get_workers_sender()),
353
354
355
356
357
358
359
360
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
361
        }
362

363
        tracing::info!("KV Routing initialized");
364
        Ok(Self {
365
            indexer,
366
            scheduler,
367
            block_size,
368
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
369
            cancellation_token,
370
            client,
371
        })
372
373
    }

374
375
376
377
378
    /// Get a reference to the client used by this KvRouter
    ///
    /// This is the same `Client` that was passed to `KvRouter::new`.
    pub fn client(&self) -> &Client {
        &self.client
    }

379
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
380
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
381
382
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
383
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
384
        context_id: Option<&str>,
385
        tokens: &[u32],
386
        router_config_override: Option<&RouterConfigOverride>,
387
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
388
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
389
390
391
392
393
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

394
        let isl_tokens = tokens.len();
395

396
397
398
399
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
400

401
        // Determine who needs seq_hashes
402
        let needs_process_routing = !self.kv_router_config.use_kv_events;
403
404
405
406
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
407
            match (needs_process_routing, scheduler_needs_it) {
408
409
410
411
412
413
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
414
        let best_worker = self
415
            .scheduler
416
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
417
                context_id.map(|s| s.to_string()),
418
                isl_tokens,
419
                maybe_seq_hashes_2,
420
                overlap_scores.clone(),
421
                router_config_override,
422
                update_states,
423
            )
424
            .await?;
425

426
427
428
        // Process routing decision when not using KV events (approximate mode with TTL/pruning)
        if needs_process_routing {
            self.indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
429
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
430
431
                .await?;
        }
432

433
434
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
435
            .get(&best_worker)
436
437
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
438
        Ok((best_worker, overlap_amount))
439
440
    }

441
442
443
444
445
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
446
        worker: WorkerWithDpRank,
447
448
    ) {
        let isl_tokens = tokens.len();
449
450
451
452
453

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
454

455
456
        if let Err(e) = self
            .scheduler
457
            .add_request(
458
                request_id.clone(),
459
                maybe_seq_hashes,
460
461
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
462
                worker,
463
            )
464
465
466
467
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
468
469
    }

470
    /// Tells the scheduler that prefill has completed for `request_id`.
    ///
    /// # Errors
    /// Propagates the [`SequenceError`] returned by the scheduler.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.mark_prefill_completed(request_id).await
    }

474
    /// Tells the scheduler to free the state tracked for `request_id`.
    ///
    /// # Errors
    /// Propagates the [`SequenceError`] returned by the scheduler.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.free(request_id).await
    }
477

478
    pub fn block_size(&self) -> u32 {
479
480
        self.block_size
    }
481

482
483
484
485
486
487
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

488
489
490
491
492
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

493
494
        Ok(self
            .scheduler
495
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
496
497
498
            .await)
    }

499
500
501
502
    /// Dump all events from the indexer
    ///
    /// # Panics
    /// Panics if this router has no indexer (`Indexer::None`), which is the case when it
    /// was created with `overlap_score_weight` equal to `0.0`.
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
503
504
}

Michael Feil's avatar
Michael Feil committed
505
506
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
507
508
509
510
511
512
513
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
514
515
516
517
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
518
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
519
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
520
521
522
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
523
524
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
525
526
527
                    overlap_blocks,
                }
            }
528
529
530
531
532
533
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
534
        };
535
536
537
538
539
540

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
541
542

pub struct KvPushRouter {
543
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
544
    pub chooser: Arc<KvRouter>,
545
546
547
548
}

impl KvPushRouter {
    pub fn new(
549
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
550
551
552
553
554
555
556
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
557
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
558
559
    for KvPushRouter
{
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
579
580
    async fn generate(
        &self,
581
        request: SingleIn<PreprocessedRequest>,
582
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

        // Check if this is a query_instance_id request first
        let query_instance_id = request.has_annotation("query_instance_id");

        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
            // If instance_id is set, use it and compute actual overlap
            let dp_rank = request.dp_rank.unwrap_or(0);
            if query_instance_id {
                tracing::debug!(
                    "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                );
            }

            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            self.chooser
                .add_request(
                    context_id.clone(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
            (id, dp_rank, overlap_blocks)
        } else {
            // Otherwise, find the best match
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !query_instance_id, // Don't update states if query_instance_id
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };

        // if request has the annotation "query_instance_id",
        // then the request will not be routed to the worker,
        // and instead the worker_instance_id will be returned.
        let stream_context = request.context().clone();
        if query_instance_id {
            let instance_id_str = instance_id.to_string();
            let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;

            // Return the tokens in nvext.token_data format
            let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
            tracing::trace!(
                "Tokens requested in the response through the query_instance_id annotation: {:?}",
                response_tokens
            );
            let stream = stream::iter(vec![response, response_tokens]);
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
        let (mut backend_input, context) = request.into_parts();
        backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
        backend_input.dp_rank = Some(dp_rank);
648
649
650
651
652
653
654
655
656
657

        // Get prefill worker ID if available (stored by PrefillRouter)
        // In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
        let decode_worker_id = instance_id;
        let prefill_worker_id = context
            .get::<u64>("prefill_worker_id")
            .ok()
            .map(|arc| *arc)
            .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker

658
659
660
661
662
663
664
665
666
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;
667
            let mut first_item = true;
668
669
670
671
672
673
674
675

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
676
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
677

678
                    item = response_stream.next() => {
679
                        let Some(mut item) = item else {
680
681
                            break;
                        };
682

683
684
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
685
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
686
                            }
687
                            prefill_marked = true;
688
                        }
689

690
691
692
693
694
                        // Always inject worker_id in first item's disaggregated_params
                        // This is needed for:
                        // 1. PrefillRouter to know which prefill worker was chosen
                        // 2. Client response when extra_fields contains "worker_id"
                        if first_item {
695
                            first_item = false;
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713

                            let Some(ref mut data) = item.data else {
                                yield item;
                                continue;
                            };

                            // prefill_worker_id comes from context (set by PrefillRouter) or falls back to instance_id
                            // decode_worker_id is always the current instance_id
                            let worker_id_json = json!({
                                "prefill_worker_id": prefill_worker_id,
                                "decode_worker_id": decode_worker_id,
                            });

                            if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
                                obj.insert("worker_id".to_string(), worker_id_json);
                            } else {
                                data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
                            }
714
                        }
715
716

                        yield item;
717
                    }
718
719
                }
            }
720

721
            if let Err(e) = chooser.free(&context_id).await {
722
                tracing::warn!("Failed to free request {context_id}: {e}");
723
            }
724
725
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
726
727
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
728
729
730
731
732
733
734

/// Cancels the router's cancellation token when the router is dropped, so that tasks
/// started with clones of that token stop.
// NOTE(review): `KvRouter::new` obtains this token from `component.drt().primary_token()`.
// If that is the runtime-wide token rather than a child token, dropping one router cancels
// more than this router's own background tasks — confirm.
impl Drop for KvRouter {
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}