kv_router.rs 24.8 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
17
    },
    protocols::annotated::Annotated,
18
    traits::DistributedRuntimeProvider,
19
20
};
use futures::stream::{self, StreamExt};
21
use serde::{Deserialize, Serialize};
22
use serde_json::json;
23

24
pub mod approx;
25
pub mod indexer;
26
pub mod prefill_router;
27
28
pub mod protocols;
pub mod publisher;
29
pub mod recorder;
30
31
pub mod scheduler;
pub mod scoring;
32
pub mod sequence;
33
pub mod subscriber;
34

35
36
pub use prefill_router::PrefillRouter;

37
38
use crate::{
    kv_router::{
39
        approx::ApproxKvIndexer,
40
        approx::PruneConfig,
41
        indexer::{
42
43
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
44
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
45
46
47
        protocols::{
            LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
        },
48
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
49
        sequence::SequenceError,
50
        subscriber::start_kv_router_background,
51
    },
52
    local_model::runtime_config::ModelRuntimeConfig,
53
    model_card::ModelDeploymentCard,
54
    preprocessor::PreprocessedRequest,
55
    protocols::common::llm_backend::LLMEngineOutput,
56
57
};

58
59
// [gluo TODO] These names shouldn't need to be public;
// they should be discovered from the component.

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject for KV cache events (push-based metric publishing).
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject for KV hit-rate reports (push-based metric publishing).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject for KV metrics (push-based metric publishing).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name used for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

77
78
79
80
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
81
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
82
        request: &SchedulingRequest,
83
        block_size: u32,
84
85
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
86

87
88
89
90
91
92
93
94
95
96
/// Override configuration for router settings that can be specified per-request.
///
/// Both fields are optional and default to `None` (via `Default` and the
/// builder's `#[builder(default)]`).
/// NOTE(review): presumably a `None` field leaves the router-level
/// `KvRouterConfig` value in effect; confirm in the scheduler.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request value for the overlap score weight.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request value for the router temperature.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

97
/// KV Router configuration parameters
98
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
99
100
101
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

102
    pub router_temperature: f64,
103

104
105
    pub use_kv_events: bool,

106
107
    pub router_replica_sync: bool,

108
109
110
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

111
112
113
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

114
    /// Whether to reset the router state on startup (default: false)
115
    pub router_reset_states: bool,
116
117
118
119
120
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
121
            overlap_score_weight: 1.0,
122
            router_temperature: 0.0,
123
            use_kv_events: true,
124
            router_replica_sync: false,
125
            router_track_active_blocks: true,
126
            router_snapshot_threshold: Some(1000000),
127
            router_reset_states: false,
128
129
130
131
132
133
134
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
135
    #[allow(clippy::too_many_arguments)]
136
137
    pub fn new(
        overlap_score_weight: Option<f64>,
138
        temperature: Option<f64>,
139
        use_kv_events: Option<bool>,
140
        replica_sync: Option<bool>,
141
        track_active_blocks: Option<bool>,
142
143
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
144
145
146
147
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
148
            router_temperature: temperature.unwrap_or(default.router_temperature),
149
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
150
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
151
152
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
153
154
155
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
156
157
158
159
        }
    }
}

160
161
162
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
163
164
    /// Updates itself based on KV events emitted by backend workers.
    /// Has the ability to persist and snapshot states.
165
    KvIndexer(KvIndexer),
166
167
168

    /// Predicts the cached blocks based on requests on a TTL basis.
    /// Currently does not persist or snapshot states (WIP to enable that).
169
    ApproxKvIndexer(ApproxKvIndexer),
170
171
172
173

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
174
175
176
177
178
179
180
181
182
183
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
            Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
184
185
186
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
187
                tree_sizes: HashMap::new(),
188
            }),
189
190
        }
    }
191
192
193
194
195

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
            Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
196
197
198
199
200
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
201
202
        }
    }
203
204
}

205
206
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
207
pub struct KvRouter {
208
209
210
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
211
    scheduler: KvScheduler,
212

213
    block_size: u32,
214
215

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
216
217

    cancellation_token: tokio_util::sync::CancellationToken,
218
219

    client: Client,
220
221
222
223
}

impl KvRouter {
    pub async fn new(
224
225
        endpoint: Endpoint,
        client: Client,
226
        block_size: u32,
227
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
228
        kv_router_config: Option<KvRouterConfig>,
229
        consumer_uuid: String,
230
    ) -> Result<Self> {
231
        let kv_router_config = kv_router_config.unwrap_or_default();
232
        let component = endpoint.component();
233
        let cancellation_token = component.drt().primary_token();
234

235
        let instance_ids_rx = client.instance_avail_watcher();
236

237
238
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
239
        let endpoint_id = endpoint.id();
240
        let discovery_key = DiscoveryQuery::EndpointModels {
241
242
243
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
244
245
246
247
248
249
250
251
        };
        let discovery_stream = discovery
            .list_and_watch(discovery_key, Some(cancellation_token.clone()))
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
252

253
254
255
256
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
        } else if kv_router_config.use_kv_events {
257
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
258
259
260
261
262
            Indexer::KvIndexer(KvIndexer::new(
                cancellation_token.clone(),
                block_size,
                kv_indexer_metrics,
            ))
263
264
265
266
267
268
        } else {
            // hard code 120 seconds for now
            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
                cancellation_token.clone(),
                block_size,
                Duration::from_secs(120),
269
                Some(PruneConfig {
270
                    max_tree_size: 2usize.pow(20), // 2 ** 20 = 1048576
271
272
                    prune_target_ratio: 0.8,
                }),
273
274
            ))
        };
275

276
        let scheduler = KvScheduler::start(
277
            component.clone(),
278
            block_size,
279
            instance_ids_rx,
280
            runtime_configs_rx,
281
            selector,
282
            kv_router_config.router_replica_sync,
283
            consumer_uuid.clone(),
284
285
        )
        .await?;
286

287
        // Start unified background process if using KvIndexer
288
        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
289
290
291
292
            start_kv_router_background(
                component.clone(),
                consumer_uuid,
                kv_indexer.event_sender(),
293
                kv_indexer.remove_worker_sender(),
294
295
296
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.get_workers_sender()),
297
298
299
300
301
302
303
304
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
305
        }
306

307
        tracing::info!("KV Routing initialized");
308
        Ok(Self {
309
            indexer,
310
            scheduler,
311
            block_size,
312
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
313
            cancellation_token,
314
            client,
315
        })
316
317
    }

318
319
320
321
322
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

323
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
324
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
325
326
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
327
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
328
        context_id: Option<&str>,
329
        tokens: &[u32],
330
        router_config_override: Option<&RouterConfigOverride>,
331
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
332
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
333
334
335
336
337
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

338
        let isl_tokens = tokens.len();
339

340
341
342
343
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
344

345
346
347
348
349
350
351
352
353
354
355
356
357
        // Determine who needs seq_hashes
        let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
            match (approx_indexer_needs_it, scheduler_needs_it) {
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
358
        let best_worker = self
359
            .scheduler
360
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
361
                context_id.map(|s| s.to_string()),
362
                isl_tokens,
363
                maybe_seq_hashes_2,
364
                overlap_scores.clone(),
365
                router_config_override,
366
                update_states,
367
            )
368
            .await?;
369

370
371
        if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
            indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
372
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
373
374
375
376
                .await
                .unwrap();
        };

377
378
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
379
            .get(&best_worker)
380
381
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
382
        Ok((best_worker, overlap_amount))
383
384
    }

385
386
387
388
389
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
390
        worker: WorkerWithDpRank,
391
392
    ) {
        let isl_tokens = tokens.len();
393
394
395
396
397

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
398

399
400
        if let Err(e) = self
            .scheduler
401
            .add_request(
402
                request_id.clone(),
403
                maybe_seq_hashes,
404
405
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
406
                worker,
407
            )
408
409
410
411
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
412
413
    }

414
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
415
        self.scheduler.mark_prefill_completed(request_id).await
416
417
    }

418
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
419
        self.scheduler.free(request_id).await
420
    }
421

422
    pub fn block_size(&self) -> u32 {
423
424
        self.block_size
    }
425

426
427
428
429
430
431
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

432
433
434
435
436
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

437
438
        Ok(self
            .scheduler
439
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
440
441
442
            .await)
    }

443
444
445
446
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
447
448
}

Michael Feil's avatar
Michael Feil committed
449
450
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
451
452
453
454
455
456
457
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
458
459
460
461
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
462
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
463
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
464
465
466
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
467
468
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
469
470
471
                    overlap_blocks,
                }
            }
472
473
474
475
476
477
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
478
        };
479
480
481
482
483
484

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
485
486

pub struct KvPushRouter {
487
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
488
    pub chooser: Arc<KvRouter>,
489
490
491
492
}

impl KvPushRouter {
    pub fn new(
493
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
494
495
496
497
498
499
500
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
501
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
502
503
    for KvPushRouter
{
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
523
524
    async fn generate(
        &self,
525
        request: SingleIn<PreprocessedRequest>,
526
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

        // Check if this is a query_instance_id request first
        let query_instance_id = request.has_annotation("query_instance_id");

        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
            // If instance_id is set, use it and compute actual overlap
            let dp_rank = request.dp_rank.unwrap_or(0);
            if query_instance_id {
                tracing::debug!(
                    "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                );
            }

            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            self.chooser
                .add_request(
                    context_id.clone(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
            (id, dp_rank, overlap_blocks)
        } else {
            // Otherwise, find the best match
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !query_instance_id, // Don't update states if query_instance_id
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };

        // if request has the annotation "query_instance_id",
        // then the request will not be routed to the worker,
        // and instead the worker_instance_id will be returned.
        let stream_context = request.context().clone();
        if query_instance_id {
            let instance_id_str = instance_id.to_string();
            let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;

            // Return the tokens in nvext.token_data format
            let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
            tracing::trace!(
                "Tokens requested in the response through the query_instance_id annotation: {:?}",
                response_tokens
            );
            let stream = stream::iter(vec![response, response_tokens]);
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
        let (mut backend_input, context) = request.into_parts();
        backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
        backend_input.dp_rank = Some(dp_rank);
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609

        // Check if worker_id is requested in extra_fields
        let should_populate_worker_id = backend_input
            .extra_fields
            .as_deref()
            .unwrap_or(&[])
            .iter()
            .any(|s| s == "worker_id");

        // Get prefill worker ID if available (stored by PrefillRouter)
        // In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
        let decode_worker_id = instance_id;
        let prefill_worker_id = context
            .get::<u64>("prefill_worker_id")
            .ok()
            .map(|arc| *arc)
            .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker

610
611
612
613
614
615
616
617
618
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;
619
            let mut first_item = true;
620
621
622
623
624
625
626
627

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
628
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
629

630
                    item = response_stream.next() => {
631
                        let Some(mut item) = item else {
632
633
                            break;
                        };
634

635
636
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
637
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
638
                            }
639
                            prefill_marked = true;
640
                        }
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662

                        yield item.clone();

                        // Inject worker_id in first item's disaggregated_params if requested
                        if first_item && should_populate_worker_id {
                            if let Some(ref mut data) = item.data {
                                // Add worker_id to disaggregated_params
                                let worker_id_json = json!({
                                    "prefill_worker_id": prefill_worker_id,
                                    "decode_worker_id": decode_worker_id,
                                });

                                if let Some(ref mut params) = data.disaggregated_params {
                                    if let Some(obj) = params.as_object_mut() {
                                        obj.insert("worker_id".to_string(), worker_id_json);
                                    }
                                } else {
                                    data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
                                }
                            }
                            first_item = false;
                        }
663
                    }
664
665
                }
            }
666

667
            if let Err(e) = chooser.free(&context_id).await {
668
                tracing::warn!("Failed to free request {context_id}: {e}");
669
            }
670
671
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
672
673
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
674
675
676
677
678
679
680

impl Drop for KvRouter {
    /// Logs the teardown and cancels the router's cancellation token.
    // NOTE(review): the token comes from `component.drt().primary_token()` in
    // `KvRouter::new`; confirm that cancelling it here only stops this
    // router's background tasks.
    fn drop(&mut self) {
        let Self {
            cancellation_token, ..
        } = self;
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        cancellation_token.cancel();
    }
}