kv_router.rs 22.8 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::{DiscoveryQuery, watch_and_extract_field},
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
16
17
    },
    protocols::annotated::Annotated,
18
    traits::DistributedRuntimeProvider,
19
20
};
use futures::stream::{self, StreamExt};
21
use serde::{Deserialize, Serialize};
22

23
pub mod approx;
24
pub mod indexer;
25
pub mod prefill_router;
26
27
pub mod protocols;
pub mod publisher;
28
pub mod recorder;
29
30
pub mod scheduler;
pub mod scoring;
31
pub mod sequence;
32
pub mod subscriber;
33

34
35
pub use prefill_router::PrefillRouter;

36
37
use crate::{
    kv_router::{
38
        approx::ApproxKvIndexer,
39
        approx::PruneConfig,
40
        indexer::{
41
42
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
43
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
44
45
46
        protocols::{
            LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
        },
47
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
48
        subscriber::start_kv_router_background,
49
    },
50
    local_model::runtime_config::ModelRuntimeConfig,
51
    model_card::ModelDeploymentCard,
52
    preprocessor::PreprocessedRequest,
53
    protocols::common::llm_backend::LLMEngineOutput,
54
55
};

56
57
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

// for metric scraping (pull-based)
/// Endpoint name used for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

// for metric publishing (push-based)
/// Subject name for KV events (push-based publishing).
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject name for KV hit-rate reports (push-based publishing).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject name for KV metrics (push-based publishing).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// for inter-router comms
/// Subject name for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject name for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

// for radix tree snapshot storage
/// Bucket name for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

75
76
77
78
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
79
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
80
        request: &SchedulingRequest,
81
        block_size: u32,
82
83
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
84

85
86
87
88
89
90
91
92
93
94
/// Override configuration for router settings that can be specified per-request.
///
/// Both fields are optional and default to `None` (via `Default` and the
/// generated builder).
/// NOTE(review): `None` presumably means "use the router-wide value from
/// `KvRouterConfig`"; the fallback happens outside this file, so confirm.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request replacement for `KvRouterConfig::overlap_score_weight`.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request replacement for `KvRouterConfig::router_temperature`.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

95
/// KV Router configuration parameters
96
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
97
98
99
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

100
    pub router_temperature: f64,
101

102
103
    pub use_kv_events: bool,

104
105
    pub router_replica_sync: bool,

106
107
108
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

109
110
111
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

112
    /// Whether to reset the router state on startup (default: false)
113
    pub router_reset_states: bool,
114
115
116
117
118
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
119
            overlap_score_weight: 1.0,
120
            router_temperature: 0.0,
121
            use_kv_events: true,
122
            router_replica_sync: false,
123
            router_track_active_blocks: true,
124
            router_snapshot_threshold: Some(1000000),
125
            router_reset_states: false,
126
127
128
129
130
131
132
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
133
    #[allow(clippy::too_many_arguments)]
134
135
    pub fn new(
        overlap_score_weight: Option<f64>,
136
        temperature: Option<f64>,
137
        use_kv_events: Option<bool>,
138
        replica_sync: Option<bool>,
139
        track_active_blocks: Option<bool>,
140
141
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
142
143
144
145
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
146
            router_temperature: temperature.unwrap_or(default.router_temperature),
147
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
148
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
149
150
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
151
152
153
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
154
155
156
157
        }
    }
}

158
159
160
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
161
162
    /// Updates itself based on KV events emitted by backend workers.
    /// Has the ability to persist and snapshot states.
163
    KvIndexer(KvIndexer),
164
165
166

    /// Predicts the cached blocks based on requests on a TTL basis.
    /// Currently does not persist or snapshot states (WIP to enable that).
167
    ApproxKvIndexer(ApproxKvIndexer),
168
169
170
171

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
172
173
174
175
176
177
178
179
180
181
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
            Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
182
183
184
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
185
                tree_sizes: HashMap::new(),
186
            }),
187
188
        }
    }
189
190
191
192
193

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
            Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
194
195
196
197
198
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
199
200
        }
    }
201
202
}

203
204
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
205
pub struct KvRouter {
206
207
208
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
209
    scheduler: KvScheduler,
210

211
    block_size: u32,
212
213

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
214
215

    cancellation_token: tokio_util::sync::CancellationToken,
216
217

    client: Client,
218
219
220
221
}

impl KvRouter {
    pub async fn new(
222
223
        endpoint: Endpoint,
        client: Client,
224
        block_size: u32,
225
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
226
        kv_router_config: Option<KvRouterConfig>,
227
        consumer_uuid: String,
228
    ) -> Result<Self> {
229
        let kv_router_config = kv_router_config.unwrap_or_default();
230
        let component = endpoint.component();
231
        let cancellation_token = component.drt().primary_token();
232

233
        let instance_ids_rx = client.instance_avail_watcher();
234

235
236
        // Watch for runtime config updates via discovery interface
        let discovery = component.drt().discovery();
237
        let endpoint_id = endpoint.id();
238
        let discovery_key = DiscoveryQuery::EndpointModels {
239
240
241
            namespace: endpoint_id.namespace.clone(),
            component: endpoint_id.component.clone(),
            endpoint: endpoint_id.name.clone(),
242
243
244
245
246
247
248
249
        };
        let discovery_stream = discovery
            .list_and_watch(discovery_key, Some(cancellation_token.clone()))
            .await?;
        let runtime_configs_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
250

251
252
253
254
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
        } else if kv_router_config.use_kv_events {
255
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);
256
257
258
259
260
            Indexer::KvIndexer(KvIndexer::new(
                cancellation_token.clone(),
                block_size,
                kv_indexer_metrics,
            ))
261
262
263
264
265
266
        } else {
            // hard code 120 seconds for now
            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
                cancellation_token.clone(),
                block_size,
                Duration::from_secs(120),
267
268
269
270
                Some(PruneConfig {
                    max_tree_size: 2usize.pow(14), // 2** 14 = 16384
                    prune_target_ratio: 0.8,
                }),
271
272
            ))
        };
273

274
        let scheduler = KvScheduler::start(
275
            component.clone(),
276
            block_size,
277
            instance_ids_rx,
278
            runtime_configs_rx,
279
            selector,
280
            kv_router_config.router_replica_sync,
281
            consumer_uuid.clone(),
282
283
        )
        .await?;
284

285
        // Start unified background process if using KvIndexer
286
        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
287
288
289
290
            start_kv_router_background(
                component.clone(),
                consumer_uuid,
                kv_indexer.event_sender(),
291
                kv_indexer.remove_worker_sender(),
292
293
294
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.get_workers_sender()),
295
296
297
298
299
300
301
302
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
303
        }
304

305
        tracing::info!("KV Routing initialized");
306
        Ok(Self {
307
            indexer,
308
            scheduler,
309
            block_size,
310
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
311
            cancellation_token,
312
            client,
313
        })
314
315
    }

316
317
318
319
320
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

321
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
322
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
323
324
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
325
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
326
        context_id: Option<&str>,
327
        tokens: &[u32],
328
        router_config_override: Option<&RouterConfigOverride>,
329
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
330
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
331
332
333
334
335
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

336
        let isl_tokens = tokens.len();
337

338
339
340
341
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
342

343
344
345
346
347
348
349
350
351
352
353
354
355
        // Determine who needs seq_hashes
        let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
            match (approx_indexer_needs_it, scheduler_needs_it) {
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
356
        let best_worker = self
357
            .scheduler
358
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
359
                context_id.map(|s| s.to_string()),
360
                isl_tokens,
361
                maybe_seq_hashes_2,
362
                overlap_scores.clone(),
363
                router_config_override,
364
                update_states,
365
            )
366
            .await?;
367

368
369
        if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
            indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
370
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
371
372
373
374
                .await
                .unwrap();
        };

375
376
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
377
            .get(&best_worker)
378
379
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
380
        Ok((best_worker, overlap_amount))
381
382
    }

383
384
385
386
387
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
388
        worker: WorkerWithDpRank,
389
390
    ) {
        let isl_tokens = tokens.len();
391
392
393
394
395

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
396
397
398
399

        self.scheduler
            .add_request(
                request_id,
400
                maybe_seq_hashes,
401
402
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
403
                worker,
404
405
406
407
            )
            .await;
    }

408
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
409
        self.scheduler.mark_prefill_completed(request_id).await
410
411
    }

412
    pub async fn free(&self, request_id: &str) -> Result<()> {
413
        self.scheduler.free(request_id).await
414
    }
415

416
    pub fn block_size(&self) -> u32 {
417
418
        self.block_size
    }
419

420
421
422
423
424
425
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

426
427
428
429
430
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

431
432
        Ok(self
            .scheduler
433
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
434
435
436
            .await)
    }

437
438
439
440
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
441
442
}

Michael Feil's avatar
Michael Feil committed
443
444
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
445
446
447
448
449
450
451
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
452
453
454
455
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
456
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
457
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
458
459
460
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
461
462
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
463
464
465
                    overlap_blocks,
                }
            }
466
467
468
469
470
471
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
472
        };
473
474
475
476
477
478

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
479
480

pub struct KvPushRouter {
481
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
482
    pub chooser: Arc<KvRouter>,
483
484
485
486
}

impl KvPushRouter {
    pub fn new(
487
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
488
489
490
491
492
493
494
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
495
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
496
497
    for KvPushRouter
{
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
517
518
    async fn generate(
        &self,
519
        request: SingleIn<PreprocessedRequest>,
520
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
        // Extract context ID for request tracking
        let context_id = request.context().id().to_string();

        // Check if this is a query_instance_id request first
        let query_instance_id = request.has_annotation("query_instance_id");

        let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
            // If instance_id is set, use it and compute actual overlap
            let dp_rank = request.dp_rank.unwrap_or(0);
            if query_instance_id {
                tracing::debug!(
                    "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                );
            }

            // Compute actual overlap blocks by querying the indexer
            let block_hashes =
                compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
            let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
            let worker = WorkerWithDpRank::new(id, dp_rank);
            let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

            self.chooser
                .add_request(
                    context_id.clone(),
                    &request.token_ids,
                    overlap_blocks,
                    worker,
                )
                .await;
            (id, dp_rank, overlap_blocks)
        } else {
            // Otherwise, find the best match
            let (best_worker, overlap_amount) = self
                .chooser
                .find_best_match(
                    Some(&context_id),
                    &request.token_ids,
                    request.router_config_override.as_ref(),
                    !query_instance_id, // Don't update states if query_instance_id
                )
                .await?;
            (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
        };

        // if request has the annotation "query_instance_id",
        // then the request will not be routed to the worker,
        // and instead the worker_instance_id will be returned.
        let stream_context = request.context().clone();
        if query_instance_id {
            let instance_id_str = instance_id.to_string();
            let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;

            // Return the tokens in nvext.token_data format
            let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
            tracing::trace!(
                "Tokens requested in the response through the query_instance_id annotation: {:?}",
                response_tokens
            );
            let stream = stream::iter(vec![response, response_tokens]);
            return Ok(ResponseStream::new(Box::pin(stream), stream_context));
        }
        let (mut backend_input, context) = request.into_parts();
        backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
        backend_input.dp_rank = Some(dp_rank);
        let updated_request = context.map(|_| backend_input);

        let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
        let stream_context = response_stream.context();
        let chooser = self.chooser.clone();
        let context_for_monitoring = stream_context.clone();

        let wrapped_stream = Box::pin(async_stream::stream! {
            let mut prefill_marked = false;

            loop {
                tokio::select! {
                    biased;

                    _ = context_for_monitoring.stopped() => {
                        tracing::debug!("Request {context_id} cancelled, ending stream");
                        break;
603
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
604

605
606
607
608
                    item = response_stream.next() => {
                        let Some(item) = item else {
                            break;
                        };
609

610
611
612
                        if !prefill_marked {
                            if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
                                tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
613
                            }
614
                            prefill_marked = true;
615
                        }
616
                        yield item;
617
                    }
618
619
                }
            }
620

621
622
            if let Err(e) = chooser.free(&context_id).await {
                tracing::warn!("Failed to free request {context_id}: {e:?}");
623
            }
624
625
        });
        Ok(ResponseStream::new(wrapped_stream, stream_context))
626
627
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
628
629
630
631
632
633
634

/// Teardown hook for the router.
impl Drop for KvRouter {
    /// Logs the teardown and cancels the router's `cancellation_token`.
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        let token = &self.cancellation_token;
        token.cancel();
    }
}