kv_router.rs 26.1 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
17
    },
    protocols::annotated::Annotated,
18
    traits::DistributedRuntimeProvider,
19
20
};
use futures::stream::{self, StreamExt};
21
use serde::{Deserialize, Serialize};
22
use serde_json::json;
23

24
pub mod approx;
25
pub mod indexer;
26
pub mod prefill_router;
27
28
pub mod protocols;
pub mod publisher;
29
pub mod recorder;
30
31
pub mod scheduler;
pub mod scoring;
32
pub mod sequence;
33
pub mod subscriber;
34

35
36
pub use prefill_router::PrefillRouter;

37
38
use crate::{
    kv_router::{
39
        approx::PruneConfig,
40
        indexer::{
41
42
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
43
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
44
45
46
        protocols::{
            LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
        },
47
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
48
        sequence::SequenceError,
49
        subscriber::start_kv_router_background,
50
    },
51
    local_model::runtime_config::ModelRuntimeConfig,
52
    model_card::ModelDeploymentCard,
53
    preprocessor::PreprocessedRequest,
54
    protocols::common::llm_backend::LLMEngineOutput,
55
    tokens::SequenceHash,
56
57
};

58
59
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Endpoint name workers expose for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject on which KV cache events are published (push-based).
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject on which KV hit-rate metrics are published (push-based).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject on which KV metrics are published (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Inter-router subject for prefill events.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Inter-router subject for active-sequence events.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// Object name of the radix tree snapshot inside [`RADIX_STATE_BUCKET`].
pub const RADIX_STATE_FILE: &str = "radix-state";

77
78
79
80
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
81
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
82
        request: &SchedulingRequest,
83
        block_size: u32,
84
85
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
86

87
88
89
90
91
92
93
94
95
96
/// Router settings that an individual request may override.
///
/// Every field is optional; `None` means the request leaves that setting alone.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Optional per-request value for [`KvRouterConfig::overlap_score_weight`].
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Optional per-request value for [`KvRouterConfig::router_temperature`].
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

97
/// KV Router configuration parameters
98
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
99
100
101
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

102
    pub router_temperature: f64,
103

104
105
    pub use_kv_events: bool,

106
107
    pub router_replica_sync: bool,

108
109
110
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

111
112
113
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

114
    /// Whether to reset the router state on startup (default: false)
115
    pub router_reset_states: bool,
116
117
118
119
120
121
122
123
124

    /// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
    pub router_ttl_secs: f64,

    /// Maximum tree size before pruning (only used when use_kv_events is false, default: 1024)
    pub router_max_tree_size: usize,

    /// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
    pub router_prune_target_ratio: f64,
125
126
127
128
129
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
130
            overlap_score_weight: 1.0,
131
            router_temperature: 0.0,
132
            use_kv_events: true,
133
            router_replica_sync: false,
134
            router_track_active_blocks: true,
135
            router_snapshot_threshold: Some(1000000),
136
            router_reset_states: false,
137
138
139
            router_ttl_secs: 120.0,
            router_max_tree_size: 1024,
            router_prune_target_ratio: 0.8,
140
141
142
143
144
145
146
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
147
    #[allow(clippy::too_many_arguments)]
148
149
    pub fn new(
        overlap_score_weight: Option<f64>,
150
        temperature: Option<f64>,
151
        use_kv_events: Option<bool>,
152
        replica_sync: Option<bool>,
153
        track_active_blocks: Option<bool>,
154
155
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
156
157
158
        router_ttl_secs: Option<f64>,
        router_max_tree_size: Option<usize>,
        router_prune_target_ratio: Option<f64>,
159
160
161
162
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
163
            router_temperature: temperature.unwrap_or(default.router_temperature),
164
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
165
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
166
167
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
168
169
170
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
171
172
173
174
            router_ttl_secs: router_ttl_secs.unwrap_or(default.router_ttl_secs),
            router_max_tree_size: router_max_tree_size.unwrap_or(default.router_max_tree_size),
            router_prune_target_ratio: router_prune_target_ratio
                .unwrap_or(default.router_prune_target_ratio),
175
176
177
178
        }
    }
}

179
pub enum Indexer {
180
181
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
182
    /// Has the ability to persist and snapshot states.
183
    KvIndexer(KvIndexer),
184
185
186
187

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
188
189
190
191
192
193
194
195
196
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
197
198
199
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
200
                tree_sizes: HashMap::new(),
201
            }),
202
203
        }
    }
204
205
206
207

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
208
209
210
211
212
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
213
214
        }
    }
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230

    async fn process_routing_decision(
        &self,
        worker: WorkerWithDpRank,
        local_hashes: Vec<LocalBlockHash>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
                    .process_routing_decision(worker, local_hashes, sequence_hashes)
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
231
232
}

233
234
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
235
pub struct KvRouter {
236
237
238
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
239
    scheduler: KvScheduler,
240

241
    block_size: u32,
242
243

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
244
245

    cancellation_token: tokio_util::sync::CancellationToken,
246
247

    client: Client,
248
249
250
251
}

impl KvRouter {
    pub async fn new(
252
253
        endpoint: Endpoint,
        client: Client,
254
        block_size: u32,
255
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
256
        kv_router_config: Option<KvRouterConfig>,
257
        consumer_uuid: String,
258
    ) -> Result<Self> {
259
        let kv_router_config = kv_router_config.unwrap_or_default();
260
        let component = endpoint.component();
261
        let cancellation_token = component.drt().primary_token();
262

263
        let instance_ids_rx = client.instance_avail_watcher();
264

265
266
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
267
        let endpoint_id = endpoint.id();
268
        let discovery_key = DiscoveryQuery::EndpointModels {
269
270
271
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
272
273
274
275
276
277
278
279
        };
        let discovery_stream = discovery
            .list_and_watch(discovery_key, Some(cancellation_token.clone()))
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
280

281
282
283
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
284
        } else {
285
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
286
287
288
289
290
291
292
293
294
295
296
297
298

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
299
                cancellation_token.clone(),
300
                None, // expiration_duration for frequency tracking
301
302
                block_size,
                kv_indexer_metrics,
303
                prune_config,
304
305
            ))
        };
306

307
        let scheduler = KvScheduler::start(
308
            component.clone(),
309
            block_size,
310
            instance_ids_rx,
311
            runtime_configs_rx,
312
            selector,
313
            kv_router_config.router_replica_sync,
314
            consumer_uuid.clone(),
315
316
        )
        .await?;
317

318
319
320
321
        // Start KV event subscriber background process (only when use_kv_events is enabled)
        if kv_router_config.use_kv_events
            && let Indexer::KvIndexer(ref kv_indexer) = indexer
        {
322
323
324
325
            start_kv_router_background(
                component.clone(),
                consumer_uuid,
                kv_indexer.event_sender(),
326
                kv_indexer.remove_worker_sender(),
327
328
329
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.get_workers_sender()),
330
331
332
333
334
335
336
337
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
338
        }
339

340
        tracing::info!("KV Routing initialized");
341
        Ok(Self {
342
            indexer,
343
            scheduler,
344
            block_size,
345
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
346
            cancellation_token,
347
            client,
348
        })
349
350
    }

351
352
353
354
355
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

356
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
357
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
358
359
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
360
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
361
        context_id: Option<&str>,
362
        tokens: &[u32],
363
        router_config_override: Option<&RouterConfigOverride>,
364
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
365
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
366
367
368
369
370
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

371
        let isl_tokens = tokens.len();
372

373
374
375
376
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
377

378
        // Determine who needs seq_hashes
379
        let needs_process_routing = !self.kv_router_config.use_kv_events;
380
381
382
383
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
384
            match (needs_process_routing, scheduler_needs_it) {
385
386
387
388
389
390
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
391
        let best_worker = self
392
            .scheduler
393
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
394
                context_id.map(|s| s.to_string()),
395
                isl_tokens,
396
                maybe_seq_hashes_2,
397
                overlap_scores.clone(),
398
                router_config_override,
399
                update_states,
400
            )
401
            .await?;
402

403
404
405
        // Process routing decision when not using KV events (approximate mode with TTL/pruning)
        if needs_process_routing {
            self.indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
406
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
407
408
                .await?;
        }
409

410
411
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
412
            .get(&best_worker)
413
414
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
415
        Ok((best_worker, overlap_amount))
416
417
    }

418
419
420
421
422
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
423
        worker: WorkerWithDpRank,
424
425
    ) {
        let isl_tokens = tokens.len();
426
427
428
429
430

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
431

432
433
        if let Err(e) = self
            .scheduler
434
            .add_request(
435
                request_id.clone(),
436
                maybe_seq_hashes,
437
438
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
439
                worker,
440
            )
441
442
443
444
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
445
446
    }

447
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
448
        self.scheduler.mark_prefill_completed(request_id).await
449
450
    }

451
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
452
        self.scheduler.free(request_id).await
453
    }
454

455
    pub fn block_size(&self) -> u32 {
456
457
        self.block_size
    }
458

459
460
461
462
463
464
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

465
466
467
468
469
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

470
471
        Ok(self
            .scheduler
472
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
473
474
475
            .await)
    }

476
477
478
479
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
480
481
}

Michael Feil's avatar
Michael Feil committed
482
483
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
484
485
486
487
488
489
490
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
491
492
493
494
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
495
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
496
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
497
498
499
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
500
501
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
502
503
504
                    overlap_blocks,
                }
            }
505
506
507
508
509
510
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
511
        };
512
513
514
515
516
517

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
518
519

pub struct KvPushRouter {
520
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
521
    pub chooser: Arc<KvRouter>,
522
523
524
525
}

impl KvPushRouter {
    pub fn new(
526
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
527
528
529
530
531
532
533
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
534
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
535
536
    for KvPushRouter
{
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
556
557
    async fn generate(
        &self,
558
        request: SingleIn<PreprocessedRequest>,
559
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

        // Check if this is a query_instance_id request first
        let query_instance_id = request.has_annotation("query_instance_id");

        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
            // If instance_id is set, use it and compute actual overlap
            let dp_rank = request.dp_rank.unwrap_or(0);
            if query_instance_id {
                tracing::debug!(
                    "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                );
            }

            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            self.chooser
                .add_request(
                    context_id.clone(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
            (id, dp_rank, overlap_blocks)
        } else {
            // Otherwise, find the best match
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !query_instance_id, // Don't update states if query_instance_id
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };

        // if request has the annotation "query_instance_id",
        // then the request will not be routed to the worker,
        // and instead the worker_instance_id will be returned.
        let stream_context = request.context().clone();
        if query_instance_id {
            let instance_id_str = instance_id.to_string();
            let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;

            // Return the tokens in nvext.token_data format
            let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
            tracing::trace!(
                "Tokens requested in the response through the query_instance_id annotation: {:?}",
                response_tokens
            );
            let stream = stream::iter(vec![response, response_tokens]);
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
        let (mut backend_input, context) = request.into_parts();
        backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
        backend_input.dp_rank = Some(dp_rank);
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642

        // Check if worker_id is requested in extra_fields
        let should_populate_worker_id = backend_input
            .extra_fields
            .as_deref()
            .unwrap_or(&[])
            .iter()
            .any(|s| s == "worker_id");

        // Get prefill worker ID if available (stored by PrefillRouter)
        // In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
        let decode_worker_id = instance_id;
        let prefill_worker_id = context
            .get::<u64>("prefill_worker_id")
            .ok()
            .map(|arc| *arc)
            .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker

643
644
645
646
647
648
649
650
651
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;
652
            let mut first_item = true;
653
654
655
656
657
658
659
660

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
661
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
662

663
                    item = response_stream.next() => {
664
                        let Some(mut item) = item else {
665
666
                            break;
                        };
667

668
669
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
670
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
671
                            }
672
                            prefill_marked = true;
673
                        }
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693

                        // Inject worker_id in first item's disaggregated_params if requested
                        if first_item && should_populate_worker_id {
                            if let Some(ref mut data) = item.data {
                                // Add worker_id to disaggregated_params
                                let worker_id_json = json!({
                                    "prefill_worker_id": prefill_worker_id,
                                    "decode_worker_id": decode_worker_id,
                                });

                                if let Some(ref mut params) = data.disaggregated_params {
                                    if let Some(obj) = params.as_object_mut() {
                                        obj.insert("worker_id".to_string(), worker_id_json);
                                    }
                                } else {
                                    data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
                                }
                            }
                            first_item = false;
                        }
694
695

                        yield item;
696
                    }
697
698
                }
            }
699

700
            if let Err(e) = chooser.free(&context_id).await {
701
                tracing::warn!("Failed to free request {context_id}: {e}");
702
            }
703
704
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
705
706
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
707
708
709
710
711
712
713

impl Drop for KvRouter {
    /// Cancels the router's token. This stops the work that was handed that token in
    /// `KvRouter::new` (discovery watch, indexer, KV-event subscriber).
    ///
    /// NOTE(review): anything else holding the same token is cancelled too; check how
    /// the token is obtained in `KvRouter::new`.
    fn drop(&mut self) {
        let token = &self.cancellation_token;
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        token.cancel();
    }
}