kv_router.rs 24.6 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
17
    },
    protocols::annotated::Annotated,
18
    traits::DistributedRuntimeProvider,
19
20
};
use futures::stream::{self, StreamExt};
21
use serde::{Deserialize, Serialize};
22
use serde_json::json;
23

24
pub mod approx;
25
pub mod indexer;
26
pub mod prefill_router;
27
28
pub mod protocols;
pub mod publisher;
29
pub mod recorder;
30
31
pub mod scheduler;
pub mod scoring;
32
pub mod sequence;
33
pub mod subscriber;
34

35
36
pub use prefill_router::PrefillRouter;

37
38
use crate::{
    kv_router::{
39
        approx::ApproxKvIndexer,
40
        approx::PruneConfig,
41
        indexer::{
42
43
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
44
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
45
46
47
        protocols::{
            LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
        },
48
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
49
        subscriber::start_kv_router_background,
50
    },
51
    local_model::runtime_config::ModelRuntimeConfig,
52
    model_card::ModelDeploymentCard,
53
    preprocessor::PreprocessedRequest,
54
    protocols::common::llm_backend::LLMEngineOutput,
55
56
};

57
58
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
59
60
61
62
63

// for metric scraping (pull-based)
/// Endpoint name used for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

// for metric publishing (push-based)
/// Subject name for KV events.
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject name for KV hit-rate publications.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject name for KV metrics publications.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// for inter-router comms
/// Subject name for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject name for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

// for radix tree snapshot storage
/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File (object) name used for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

76
77
78
79
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
80
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
81
        request: &SchedulingRequest,
82
        block_size: u32,
83
84
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
85

86
87
88
89
90
91
92
93
94
95
/// Per-request overrides for router settings.
///
/// Every field is optional; `None` leaves the corresponding setting
/// un-overridden for that request.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Override for the overlap score weight, if any.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Override for the router temperature, if any.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

96
/// KV Router configuration parameters
97
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
98
99
100
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

101
    pub router_temperature: f64,
102

103
104
    pub use_kv_events: bool,

105
106
    pub router_replica_sync: bool,

107
108
109
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

110
111
112
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

113
    /// Whether to reset the router state on startup (default: false)
114
    pub router_reset_states: bool,
115
116
117
118
119
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
120
            overlap_score_weight: 1.0,
121
            router_temperature: 0.0,
122
            use_kv_events: true,
123
            router_replica_sync: false,
124
            router_track_active_blocks: true,
125
            router_snapshot_threshold: Some(1000000),
126
            router_reset_states: false,
127
128
129
130
131
132
133
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
134
    #[allow(clippy::too_many_arguments)]
135
136
    pub fn new(
        overlap_score_weight: Option<f64>,
137
        temperature: Option<f64>,
138
        use_kv_events: Option<bool>,
139
        replica_sync: Option<bool>,
140
        track_active_blocks: Option<bool>,
141
142
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
143
144
145
146
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
147
            router_temperature: temperature.unwrap_or(default.router_temperature),
148
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
149
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
150
151
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
152
153
154
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
155
156
157
158
        }
    }
}

159
160
161
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
162
163
    /// Updates itself based on KV events emitted by backend workers.
    /// Has the ability to persist and snapshot states.
164
    KvIndexer(KvIndexer),
165
166
167

    /// Predicts the cached blocks based on requests on a TTL basis.
    /// Currently does not persist or snapshot states (WIP to enable that).
168
    ApproxKvIndexer(ApproxKvIndexer),
169
170
171
172

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
173
174
175
176
177
178
179
180
181
182
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
            Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
183
184
185
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
186
                tree_sizes: HashMap::new(),
187
            }),
188
189
        }
    }
190
191
192
193
194

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
            Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
195
196
197
198
199
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
200
201
        }
    }
202
203
}

204
205
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
206
pub struct KvRouter {
207
208
209
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
210
    scheduler: KvScheduler,
211

212
    block_size: u32,
213
214

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
215
216

    cancellation_token: tokio_util::sync::CancellationToken,
217
218

    client: Client,
219
220
221
222
}

impl KvRouter {
    pub async fn new(
223
224
        endpoint: Endpoint,
        client: Client,
225
        block_size: u32,
226
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
227
        kv_router_config: Option<KvRouterConfig>,
228
        consumer_uuid: String,
229
    ) -> Result<Self> {
230
        let kv_router_config = kv_router_config.unwrap_or_default();
231
        let component = endpoint.component();
232
        let cancellation_token = component.drt().primary_token();
233

234
        let instance_ids_rx = client.instance_avail_watcher();
235

236
237
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
238
        let endpoint_id = endpoint.id();
239
        let discovery_key = DiscoveryQuery::EndpointModels {
240
241
242
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
243
244
245
246
247
248
249
250
        };
        let discovery_stream = discovery
            .list_and_watch(discovery_key, Some(cancellation_token.clone()))
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
251

252
253
254
255
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
        } else if kv_router_config.use_kv_events {
256
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
257
258
259
260
261
            Indexer::KvIndexer(KvIndexer::new(
                cancellation_token.clone(),
                block_size,
                kv_indexer_metrics,
            ))
262
263
264
265
266
267
        } else {
            // hard code 120 seconds for now
            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
                cancellation_token.clone(),
                block_size,
                Duration::from_secs(120),
268
                Some(PruneConfig {
269
                    max_tree_size: 2usize.pow(20), // 2 ** 20 = 1048576
270
271
                    prune_target_ratio: 0.8,
                }),
272
273
            ))
        };
274

275
        let scheduler = KvScheduler::start(
276
            component.clone(),
277
            block_size,
278
            instance_ids_rx,
279
            runtime_configs_rx,
280
            selector,
281
            kv_router_config.router_replica_sync,
282
            consumer_uuid.clone(),
283
284
        )
        .await?;
285

286
        // Start unified background process if using KvIndexer
287
        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
288
289
290
291
            start_kv_router_background(
                component.clone(),
                consumer_uuid,
                kv_indexer.event_sender(),
292
                kv_indexer.remove_worker_sender(),
293
294
295
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.get_workers_sender()),
296
297
298
299
300
301
302
303
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
304
        }
305

306
        tracing::info!("KV Routing initialized");
307
        Ok(Self {
308
            indexer,
309
            scheduler,
310
            block_size,
311
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
312
            cancellation_token,
313
            client,
314
        })
315
316
    }

317
318
319
320
321
    /// Borrow the [`Client`] this router was constructed with.
    pub fn client(&self) -> &Client {
        &self.client
    }

322
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
323
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
324
325
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
326
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
327
        context_id: Option<&str>,
328
        tokens: &[u32],
329
        router_config_override: Option<&RouterConfigOverride>,
330
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
331
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
332
333
334
335
336
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

337
        let isl_tokens = tokens.len();
338

339
340
341
342
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
343

344
345
346
347
348
349
350
351
352
353
354
355
356
        // Determine who needs seq_hashes
        let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
            match (approx_indexer_needs_it, scheduler_needs_it) {
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
357
        let best_worker = self
358
            .scheduler
359
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
360
                context_id.map(|s| s.to_string()),
361
                isl_tokens,
362
                maybe_seq_hashes_2,
363
                overlap_scores.clone(),
364
                router_config_override,
365
                update_states,
366
            )
367
            .await?;
368

369
370
        if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
            indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
371
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
372
373
374
375
                .await
                .unwrap();
        };

376
377
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
378
            .get(&best_worker)
379
380
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
381
        Ok((best_worker, overlap_amount))
382
383
    }

384
385
386
387
388
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
389
        worker: WorkerWithDpRank,
390
391
    ) {
        let isl_tokens = tokens.len();
392
393
394
395
396

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
397
398
399
400

        self.scheduler
            .add_request(
                request_id,
401
                maybe_seq_hashes,
402
403
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
404
                worker,
405
406
407
408
            )
            .await;
    }

409
    /// Tell the scheduler that prefill has completed for `request_id`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the scheduler reports.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
        let scheduler = &self.scheduler;
        scheduler.mark_prefill_completed(request_id).await
    }

413
    /// Ask the scheduler to free the state it tracks for `request_id`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the scheduler reports.
    pub async fn free(&self, request_id: &str) -> Result<()> {
        let scheduler = &self.scheduler;
        scheduler.free(request_id).await
    }
416

417
    pub fn block_size(&self) -> u32 {
418
419
        self.block_size
    }
420

421
422
423
424
425
426
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

427
428
429
430
431
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

432
433
        Ok(self
            .scheduler
434
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
435
436
437
            .await)
    }

438
439
440
441
    /// Dump all events held by the router's indexer.
    ///
    /// # Panics
    ///
    /// Panics if the router runs without an indexer (`Indexer::None`, e.g.
    /// when `overlap_score_weight` is 0).
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        let indexer = &self.indexer;
        indexer.dump_events().await
    }
442
443
}

Michael Feil's avatar
Michael Feil committed
444
445
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
446
447
448
449
450
451
452
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
453
454
455
456
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
457
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
458
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
459
460
461
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
462
463
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
464
465
466
                    overlap_blocks,
                }
            }
467
468
469
470
471
472
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
473
        };
474
475
476
477
478
479

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
480
481

pub struct KvPushRouter {
482
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
483
    pub chooser: Arc<KvRouter>,
484
485
486
487
}

impl KvPushRouter {
    pub fn new(
488
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
489
490
491
492
493
494
495
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
496
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
497
498
    for KvPushRouter
{
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
518
519
    async fn generate(
        &self,
520
        request: SingleIn<PreprocessedRequest>,
521
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

        // Check if this is a query_instance_id request first
        let query_instance_id = request.has_annotation("query_instance_id");

        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
            // If instance_id is set, use it and compute actual overlap
            let dp_rank = request.dp_rank.unwrap_or(0);
            if query_instance_id {
                tracing::debug!(
                    "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                );
            }

            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            self.chooser
                .add_request(
                    context_id.clone(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
            (id, dp_rank, overlap_blocks)
        } else {
            // Otherwise, find the best match
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !query_instance_id, // Don't update states if query_instance_id
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };

        // if request has the annotation "query_instance_id",
        // then the request will not be routed to the worker,
        // and instead the worker_instance_id will be returned.
        let stream_context = request.context().clone();
        if query_instance_id {
            let instance_id_str = instance_id.to_string();
            let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;

            // Return the tokens in nvext.token_data format
            let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
            tracing::trace!(
                "Tokens requested in the response through the query_instance_id annotation: {:?}",
                response_tokens
            );
            let stream = stream::iter(vec![response, response_tokens]);
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
        let (mut backend_input, context) = request.into_parts();
        backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
        backend_input.dp_rank = Some(dp_rank);
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604

        // Check if worker_id is requested in extra_fields
        let should_populate_worker_id = backend_input
            .extra_fields
            .as_deref()
            .unwrap_or(&[])
            .iter()
            .any(|s| s == "worker_id");

        // Get prefill worker ID if available (stored by PrefillRouter)
        // In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
        let decode_worker_id = instance_id;
        let prefill_worker_id = context
            .get::<u64>("prefill_worker_id")
            .ok()
            .map(|arc| *arc)
            .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker

605
606
607
608
609
610
611
612
613
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;
614
            let mut first_item = true;
615
616
617
618
619
620
621
622

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
623
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
624

625
                    item = response_stream.next() => {
626
                        let Some(mut item) = item else {
627
628
                            break;
                        };
629

630
631
632
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
633
                            }
634
                            prefill_marked = true;
635
                        }
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657

                        yield item.clone();

                        // Inject worker_id in first item's disaggregated_params if requested
                        if first_item && should_populate_worker_id {
                            if let Some(ref mut data) = item.data {
                                // Add worker_id to disaggregated_params
                                let worker_id_json = json!({
                                    "prefill_worker_id": prefill_worker_id,
                                    "decode_worker_id": decode_worker_id,
                                });

                                if let Some(ref mut params) = data.disaggregated_params {
                                    if let Some(obj) = params.as_object_mut() {
                                        obj.insert("worker_id".to_string(), worker_id_json);
                                    }
                                } else {
                                    data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
                                }
                            }
                            first_item = false;
                        }
658
                    }
659
660
                }
            }
661

662
663
            if let Err(e) = chooser.free(&context_id).await {
                tracing::warn!("Failed to free request {context_id}: {e:?}");
664
            }
665
666
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
667
668
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
669
670
671
672
673
674
675

impl Drop for KvRouter {
    /// Cancel the router's cancellation token so background tasks stop.
    // NOTE(review): the token comes from `component.drt().primary_token()`
    // in `KvRouter::new`. If that is the runtime-wide token rather than a
    // child token, dropping a router cancels more than this router's own
    // tasks - confirm.
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        let token = &self.cancellation_token;
        token.cancel();
    }
}