kv_router.rs 27.3 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
21
};
use futures::stream::{self, StreamExt};
22
use serde::{Deserialize, Serialize};
23
use serde_json::json;
24

25
26
use crate::protocols::openai::nvext::WorkerIdInfo;

27
pub mod approx;
28
pub mod indexer;
29
pub mod prefill_router;
30
31
pub mod protocols;
pub mod publisher;
32
pub mod recorder;
33
34
pub mod scheduler;
pub mod scoring;
35
pub mod sequence;
36
pub mod subscriber;
37

38
39
pub use prefill_router::PrefillRouter;

40
41
use crate::{
    kv_router::{
42
        approx::PruneConfig,
43
        indexer::{
44
45
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
46
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
47
48
49
        protocols::{
            LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
        },
50
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
51
        sequence::SequenceError,
52
        subscriber::start_kv_router_background,
53
    },
54
    local_model::runtime_config::ModelRuntimeConfig,
55
    model_card::ModelDeploymentCard,
56
    preprocessor::PreprocessedRequest,
57
    protocols::common::llm_backend::LLMEngineOutput,
58
    tokens::SequenceHash,
59
60
};

61
62
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
63
64
65
66
67

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject used for publishing KV events (push-based).
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject used for publishing KV hit-rate metrics (push-based).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject used for publishing KV metrics (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject used for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject used for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name used for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
/// Component name under which the KV router registers for discovery.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name under which the KV router registers for discovery.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] of the KV router inside `namespace`.
///
/// The component and endpoint names are always [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`]; only the namespace varies.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = KV_ROUTER_COMPONENT.to_owned();
    let name = KV_ROUTER_ENDPOINT.to_owned();
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the [`DiscoveryQuery`] that matches the KV router endpoint inside `namespace`.
///
/// The component and endpoint names are always [`KV_ROUTER_COMPONENT`] and
/// [`KV_ROUTER_ENDPOINT`]; only the namespace varies.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

102
103
104
105
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
106
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
107
        request: &SchedulingRequest,
108
        block_size: u32,
109
110
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
111

112
113
114
115
116
117
118
119
120
121
/// Router settings that a single request may override.
///
/// Every field is optional and defaults to `None`, both through [`Default`] and
/// through the generated builder.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request value for the overlap score weight
    /// (presumably `None` keeps the router-level setting — TODO confirm in the scheduler).
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request value for the router temperature
    /// (presumably `None` keeps the router-level setting — TODO confirm in the scheduler).
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

122
/// KV Router configuration parameters
123
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
124
125
126
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

127
    pub router_temperature: f64,
128

129
130
    pub use_kv_events: bool,

131
132
    pub router_replica_sync: bool,

133
134
135
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

136
137
138
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

139
    /// Whether to reset the router state on startup (default: false)
140
    pub router_reset_states: bool,
141
142
143
144
145
146
147
148
149

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024)
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
150
151
152
153
154
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
155
            overlap_score_weight: 1.0,
156
            router_temperature: 0.0,
157
            use_kv_events: true,
158
            router_replica_sync: false,
159
            router_track_active_blocks: true,
160
            router_snapshot_threshold: Some(1000000),
161
            router_reset_states: false,
162
163
164
            router_ttl_secs: 120.0,
            router_max_tree_size: 1024,
            router_prune_target_ratio: 0.8,
165
166
167
168
169
170
171
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
172
    #[allow(clippy::too_many_arguments)]
173
174
    pub fn new(
        overlap_score_weight: Option<f64>,
175
        temperature: Option<f64>,
176
        use_kv_events: Option<bool>,
177
        replica_sync: Option<bool>,
178
        track_active_blocks: Option<bool>,
179
180
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
181
182
183
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
184
185
186
187
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
188
            router_temperature: temperature.unwrap_or(default.router_temperature),
189
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
190
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
191
192
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
193
194
195
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
196
197
198
199
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
200
201
202
203
        }
    }
}

204
pub enum Indexer {
205
206
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
207
    /// Has the ability to persist and snapshot states.
208
    KvIndexer(KvIndexer),
209
210
211
212

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
213
214
215
216
217
218
219
220
221
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
222
223
224
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
225
                tree_sizes: HashMap::new(),
226
            }),
227
228
        }
    }
229
230
231
232

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
233
234
235
236
237
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
238
239
        }
    }
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255

    async fn process_routing_decision(
        &self,
        worker: WorkerWithDpRank,
        local_hashes: Vec<LocalBlockHash>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
                    .process_routing_decision(worker, local_hashes, sequence_hashes)
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
256
257
}

258
259
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
260
pub struct KvRouter {
261
262
263
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
264
    scheduler: KvScheduler,
265

266
    block_size: u32,
267
268

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
269
270

    cancellation_token: tokio_util::sync::CancellationToken,
271
272

    client: Client,
273
274
275
276
}

impl KvRouter {
    pub async fn new(
277
278
        endpoint: Endpoint,
        client: Client,
279
        block_size: u32,
280
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
281
        kv_router_config: Option<KvRouterConfig>,
282
        consumer_id: String,
283
    ) -> Result<Self> {
284
        let kv_router_config = kv_router_config.unwrap_or_default();
285
        let component = endpoint.component();
286
        let cancellation_token = component.drt().primary_token();
287

288
        let instance_ids_rx = client.instance_avail_watcher();
289

290
291
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
292
        let endpoint_id = endpoint.id();
293
        let discovery_key = DiscoveryQuery::EndpointModels {
294
295
296
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
297
298
299
300
301
302
303
304
        };
        let discovery_stream = discovery
            .list_and_watch(discovery_key, Some(cancellation_token.clone()))
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
305

306
307
308
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
309
        } else {
310
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
311
312
313
314
315
316
317
318
319
320
321
322
323

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
324
                cancellation_token.clone(),
325
                None, // expiration_duration for frequency tracking
326
327
                block_size,
                kv_indexer_metrics,
328
                prune_config,
329
330
            ))
        };
331

332
        let scheduler = KvScheduler::start(
333
            component.clone(),
334
            block_size,
335
            instance_ids_rx,
336
            runtime_configs_rx,
337
            selector,
338
            kv_router_config.router_replica_sync,
339
            consumer_id.clone(),
340
341
        )
        .await?;
342

343
344
345
346
        // Start KV event subscriber background process (only when use_kv_events is enabled)
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
347
348
            start_kv_router_background(
                component.clone(),
349
                consumer_id,
350
                kv_indexer.event_sender(),
351
                kv_indexer.remove_worker_sender(),
352
353
354
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.get_workers_sender()),
355
356
357
358
359
360
361
362
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
363
        }
364

365
        tracing::info!("KV Routing initialized");
366
        Ok(Self {
367
            indexer,
368
            scheduler,
369
            block_size,
370
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
371
            cancellation_token,
372
            client,
373
        })
374
375
    }

376
377
378
379
380
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

381
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
382
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
383
384
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
385
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
386
        context_id: Option<&str>,
387
        tokens: &[u32],
388
        router_config_override: Option<&RouterConfigOverride>,
389
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
390
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
391
392
393
394
395
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

396
        let isl_tokens = tokens.len();
397

398
399
400
401
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
402

403
        // Determine who needs seq_hashes
404
        let needs_process_routing = !self.kv_router_config.use_kv_events;
405
406
407
408
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
409
            match (needs_process_routing, scheduler_needs_it) {
410
411
412
413
414
415
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
416
        let best_worker = self
417
            .scheduler
418
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
419
                context_id.map(|s| s.to_string()),
420
                isl_tokens,
421
                maybe_seq_hashes_2,
422
                overlap_scores.clone(),
423
                router_config_override,
424
                update_states,
425
            )
426
            .await?;
427

428
429
430
        // Process routing decision when not using KV events (approximate mode with TTL/pruning)
        if needs_process_routing {
            self.indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
431
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
432
433
                .await?;
        }
434

435
436
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
437
            .get(&best_worker)
438
439
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
440
        Ok((best_worker, overlap_amount))
441
442
    }

443
444
445
446
447
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
448
        worker: WorkerWithDpRank,
449
450
    ) {
        let isl_tokens = tokens.len();
451
452
453
454
455

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
456

457
458
        if let Err(e) = self
            .scheduler
459
            .add_request(
460
                request_id.clone(),
461
                maybe_seq_hashes,
462
463
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
464
                worker,
465
            )
466
467
468
469
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
470
471
    }

472
    /// Tells the scheduler that prefill has completed for `request_id`.
    ///
    /// # Errors
    ///
    /// Returns the scheduler's [`SequenceError`] unchanged.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.mark_prefill_completed(request_id).await
    }

476
    /// Asks the scheduler to free `request_id`.
    ///
    /// # Errors
    ///
    /// Returns the scheduler's [`SequenceError`] unchanged.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.free(request_id).await
    }
479

480
    pub fn block_size(&self) -> u32 {
481
482
        self.block_size
    }
483

484
485
486
487
488
489
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

490
491
492
493
494
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

495
496
        Ok(self
            .scheduler
497
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
498
499
500
            .await)
    }

501
502
503
504
    /// Returns every event currently held by the router's indexer.
    ///
    /// # Panics
    ///
    /// Panics when the router was built without an indexer (`Indexer::None`).
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        let indexer = &self.indexer;
        indexer.dump_events().await
    }
505
506
}

Michael Feil's avatar
Michael Feil committed
507
508
// NOTE: KvRouter works like a PushRouter without the reverse-proxy functionality:
// it never forwards the request itself, it only answers the three request types
// of `RouterRequest` (`New`, `MarkPrefill`, `MarkFree`).
509
510
511
512
513
514
515
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
516
517
518
519
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
520
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
521
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
522
523
524
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
525
526
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
527
528
529
                    overlap_blocks,
                }
            }
530
531
532
533
534
535
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
536
        };
537
538
539
540
541
542

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
543
544

pub struct KvPushRouter {
545
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
546
    pub chooser: Arc<KvRouter>,
547
548
549
550
}

impl KvPushRouter {
    pub fn new(
551
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
552
553
554
555
556
557
558
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
559
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
560
561
    for KvPushRouter
{
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
581
582
    async fn generate(
        &self,
583
        request: SingleIn<PreprocessedRequest>,
584
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

        // Check if this is a query_instance_id request first
        let query_instance_id = request.has_annotation("query_instance_id");

        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
            // If instance_id is set, use it and compute actual overlap
            let dp_rank = request.dp_rank.unwrap_or(0);
            if query_instance_id {
                tracing::debug!(
                    "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                );
            }

            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            self.chooser
                .add_request(
                    context_id.clone(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
            (id, dp_rank, overlap_blocks)
        } else {
            // Otherwise, find the best match
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !query_instance_id, // Don't update states if query_instance_id
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };

        // if request has the annotation "query_instance_id",
        // then the request will not be routed to the worker,
        // and instead the worker_instance_id will be returned.
        let stream_context = request.context().clone();
        if query_instance_id {
            let instance_id_str = instance_id.to_string();
            let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;

            // Return the tokens in nvext.token_data format
            let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
            tracing::trace!(
                "Tokens requested in the response through the query_instance_id annotation: {:?}",
                response_tokens
            );
            let stream = stream::iter(vec![response, response_tokens]);
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
        let (mut backend_input, context) = request.into_parts();
        backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
        backend_input.dp_rank = Some(dp_rank);
650

651
652
        // Get prefill worker ID from prefill_result if available
        // In aggregated mode, prefill_result is None, so we use decode_worker_id for both
653
        let decode_worker_id = instance_id;
654
655
656
657
658
659
660
661
662
663
        let prefill_worker_id = backend_input
            .prefill_result
            .as_ref()
            .and_then(|prefill_result| {
                prefill_result
                    .disaggregated_params
                    .get("worker_id")
                    .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
                    .and_then(|info| info.prefill_worker_id)
            })
664
665
            .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker

666
667
668
669
670
671
672
673
674
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;
675
            let mut first_item = true;
676
677
678
679
680
681
682
683

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
684
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
685

686
                    item = response_stream.next() => {
687
                        let Some(mut item) = item else {
688
689
                            break;
                        };
690

691
692
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
693
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
694
                            }
695
                            prefill_marked = true;
696
                        }
697

698
699
700
701
702
                        // Always inject worker_id in first item's disaggregated_params
                        // This is needed for:
                        // 1. PrefillRouter to know which prefill worker was chosen
                        // 2. Client response when extra_fields contains "worker_id"
                        if first_item {
703
                            first_item = false;
704
705
706
707
708
709

                            let Some(ref mut data) = item.data else {
                                yield item;
                                continue;
                            };

710
                            // prefill_worker_id comes from prefill_result.disaggregated_params or falls back to instance_id
711
                            // decode_worker_id is always the current instance_id
712
713
714
715
716
717
                            let worker_id_info = WorkerIdInfo {
                                prefill_worker_id,
                                decode_worker_id: Some(decode_worker_id),
                            };
                            let worker_id_json = serde_json::to_value(&worker_id_info)
                                .expect("WorkerIdInfo serialization should not fail");
718
719
720
721
722
723

                            if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
                                obj.insert("worker_id".to_string(), worker_id_json);
                            } else {
                                data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
                            }
724
                        }
725
726

                        yield item;
727
                    }
728
729
                }
            }
730

731
            if let Err(e) = chooser.free(&context_id).await {
732
                tracing::warn!("Failed to free request {context_id}: {e}");
733
            }
734
735
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
736
737
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
738
739
740
741
742
743
744

impl Drop for KvRouter {
    /// Cancels the router's cancellation token — the one handed to its background
    /// tasks in `KvRouter::new` — so that they stop once the router is gone.
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        let token = &self.cancellation_token;
        token.cancel();
    }
}