kv_router.rs 23.8 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Component, InstanceSource},
12
    pipeline::{
13
14
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
15
16
17
    },
    prelude::*,
    protocols::annotated::Annotated,
18
    utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction},
19
20
};
use futures::stream::{self, StreamExt};
21
use serde::{Deserialize, Serialize};
22

23
pub mod approx;
24
pub mod indexer;
25
pub mod prefill_router;
26
27
pub mod protocols;
pub mod publisher;
28
pub mod recorder;
29
30
pub mod scheduler;
pub mod scoring;
31
pub mod sequence;
32
pub mod subscriber;
33

34
35
pub use prefill_router::PrefillRouter;

36
37
use crate::{
    kv_router::{
38
39
        approx::ApproxKvIndexer,
        indexer::{
40
41
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
42
        },
Yan Ru Pei's avatar
Yan Ru Pei committed
43
44
45
        protocols::{
            LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
        },
46
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
47
        subscriber::start_kv_router_background,
48
    },
49
    local_model::runtime_config::ModelRuntimeConfig,
50
    model_card::{self, ModelDeploymentCard},
51
    preprocessor::PreprocessedRequest,
52
    protocols::common::llm_backend::LLMEngineOutput,
53
54
};

55
56
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject used for publishing KV events (push-based).
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject used for publishing KV hit-rate data (push-based).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject used for publishing KV metrics (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject used for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject used for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name used for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";
/// Lock name associated with router snapshots.
pub const ROUTER_SNAPSHOT_LOCK: &str = "router-snapshot-lock";
/// Lock name associated with router cleanup.
pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";

76
77
78
79
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
80
        workers: &HashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>,
81
        request: &SchedulingRequest,
82
        block_size: u32,
83
84
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
85

86
87
88
89
90
91
92
93
94
95
/// Override configuration for router settings that can be specified per-request
///
/// Every field is optional; the builder defaults each one to `None`.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request value for the overlap score weight.
    // NOTE(review): presumably `None` falls back to `KvRouterConfig::overlap_score_weight`;
    // confirm in the scheduler.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request value for the router temperature.
    // NOTE(review): presumably `None` falls back to `KvRouterConfig::router_temperature`;
    // confirm in the scheduler.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

96
/// KV Router configuration parameters
97
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
98
99
100
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

101
    pub router_temperature: f64,
102

103
104
    pub use_kv_events: bool,

105
106
    pub router_replica_sync: bool,

107
108
109
    /// Whether to track active blocks in the router (default: true)
    pub router_track_active_blocks: bool,

110
111
112
    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

113
    /// Whether to reset the router state on startup (default: false)
114
    pub router_reset_states: bool,
115
116
117
118
119
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
120
            overlap_score_weight: 1.0,
121
            router_temperature: 0.0,
122
            use_kv_events: true,
123
            router_replica_sync: false,
124
            router_track_active_blocks: true,
125
            router_snapshot_threshold: Some(1000000),
126
            router_reset_states: false,
127
128
129
130
131
132
133
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
134
    #[allow(clippy::too_many_arguments)]
135
136
    pub fn new(
        overlap_score_weight: Option<f64>,
137
        temperature: Option<f64>,
138
        use_kv_events: Option<bool>,
139
        replica_sync: Option<bool>,
140
        track_active_blocks: Option<bool>,
141
142
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
143
144
145
146
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
147
            router_temperature: temperature.unwrap_or(default.router_temperature),
148
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
149
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
150
151
            router_track_active_blocks: track_active_blocks
                .unwrap_or(default.router_track_active_blocks),
152
153
154
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
155
156
157
158
        }
    }
}

159
160
161
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
162
163
    /// Updates itself based on KV events emitted by backend workers.
    /// Has the ability to persist and snapshot states.
164
    KvIndexer(KvIndexer),
165
166
167

    /// Predicts the cached blocks based on requests on a TTL basis.
    /// Currently does not persist or snapshot states (WIP to enable that).
168
    ApproxKvIndexer(ApproxKvIndexer),
169
170
171
172

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
173
174
175
176
177
178
179
180
181
182
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
            Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
183
184
185
186
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
            }),
187
188
        }
    }
189
190
191
192
193

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
            Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
194
195
196
197
198
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
199
200
        }
    }
201
202
}

203
204
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
205
pub struct KvRouter {
206
207
208
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
209
    scheduler: KvScheduler,
210

211
    block_size: u32,
212
213

    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
214
215

    cancellation_token: tokio_util::sync::CancellationToken,
216
217
218
219
}

impl KvRouter {
    pub async fn new(
220
        component: Component,
221
        block_size: u32,
222
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
223
        kv_router_config: Option<KvRouterConfig>,
224
        consumer_uuid: String,
225
    ) -> Result<Self> {
226
        let kv_router_config = kv_router_config.unwrap_or_default();
227
        let cancellation_token = component.drt().primary_token();
228
229
230
231
232
233
234
235
236
        let generate_endpoint = component.endpoint("generate");
        let client = generate_endpoint.client().await?;

        let instances_rx = match client.instance_source.as_ref() {
            InstanceSource::Dynamic(rx) => rx.clone(),
            InstanceSource::Static => {
                panic!("Expected dynamic instance source for KV routing");
            }
        };
237

238
        // Create runtime config watcher using the generic etcd watcher
239
240
241
242
243
        // TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality
        let etcd_client = component
            .drt()
            .etcd_client()
            .expect("Cannot KV route without etcd client");
244
245
246

        let runtime_configs_watcher = watch_prefix_with_extraction(
            etcd_client,
247
            &format!("{}/{}", model_card::ROOT_PATH, component.path()),
248
            key_extractors::lease_id,
249
            |card: ModelDeploymentCard| Some(card.runtime_config),
250
251
252
253
            cancellation_token.clone(),
        )
        .await?;
        let runtime_configs_rx = runtime_configs_watcher.receiver();
254

255
256
257
258
        let indexer = if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
        } else if kv_router_config.use_kv_events {
259
260
261
262
263
264
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(&component);
            Indexer::KvIndexer(KvIndexer::new(
                cancellation_token.clone(),
                block_size,
                kv_indexer_metrics,
            ))
265
266
267
268
269
270
271
272
        } else {
            // hard code 120 seconds for now
            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
                cancellation_token.clone(),
                block_size,
                Duration::from_secs(120),
            ))
        };
273

274
        let scheduler = KvScheduler::start(
275
            component.clone(),
276
            block_size,
277
            instances_rx,
278
            runtime_configs_rx,
279
            selector,
280
            kv_router_config.router_replica_sync,
281
            consumer_uuid.clone(),
282
283
        )
        .await?;
284

285
        // Start unified background process if using KvIndexer
286
        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
287
288
289
290
            start_kv_router_background(
                component.clone(),
                consumer_uuid,
                kv_indexer.event_sender(),
291
                kv_indexer.remove_worker_sender(),
292
293
294
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.get_workers_sender()),
295
296
297
298
299
300
301
302
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
303
        }
304

305
        tracing::info!("KV Routing initialized");
306
        Ok(Self {
307
            indexer,
308
            scheduler,
309
            block_size,
310
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
311
            cancellation_token,
312
        })
313
314
    }

315
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
316
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
317
318
    /// Now also takes optional context_id for request tracking
    pub async fn find_best_match(
319
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
320
        context_id: Option<&str>,
321
        tokens: &[u32],
322
        router_config_override: Option<&RouterConfigOverride>,
323
        update_states: bool,
Yan Ru Pei's avatar
Yan Ru Pei committed
324
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
Yan Ru Pei's avatar
Yan Ru Pei committed
325
326
327
328
329
        // Validate that context_id is provided when update_states is true
        if update_states && context_id.is_none() {
            panic!("context_id must be provided if update_states is true");
        }

330
        let isl_tokens = tokens.len();
331

332
333
334
335
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
336

337
338
339
340
341
342
343
344
345
346
347
348
349
        // Determine who needs seq_hashes
        let approx_indexer_needs_it = matches!(self.indexer, Indexer::ApproxKvIndexer(_));
        let scheduler_needs_it = self.kv_router_config.router_track_active_blocks;

        // Optimize cloning: only clone if both need it, otherwise move
        let (maybe_seq_hashes_1, maybe_seq_hashes_2) =
            match (approx_indexer_needs_it, scheduler_needs_it) {
                (true, true) => (Some(seq_hashes.clone()), Some(seq_hashes)),
                (true, false) => (Some(seq_hashes), None),
                (false, true) => (None, Some(seq_hashes)),
                (false, false) => (None, None),
            };

Yan Ru Pei's avatar
Yan Ru Pei committed
350
        let best_worker = self
351
            .scheduler
352
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
353
                context_id.map(|s| s.to_string()),
354
                isl_tokens,
355
                maybe_seq_hashes_2,
356
                overlap_scores.clone(),
357
                router_config_override,
358
                update_states,
359
            )
360
            .await?;
361

362
363
        if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
            indexer
Yan Ru Pei's avatar
Yan Ru Pei committed
364
                .process_routing_decision(best_worker, block_hashes, maybe_seq_hashes_1.unwrap())
365
366
367
368
                .await
                .unwrap();
        };

369
370
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
371
            .get(&best_worker)
372
373
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
374
        Ok((best_worker, overlap_amount))
375
376
    }

377
378
379
380
381
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
Yan Ru Pei's avatar
Yan Ru Pei committed
382
        worker: WorkerWithDpRank,
383
384
    ) {
        let isl_tokens = tokens.len();
385
386
387
388
389

        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });
390
391
392
393

        self.scheduler
            .add_request(
                request_id,
394
                maybe_seq_hashes,
395
396
                isl_tokens,
                overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
397
                worker,
398
399
400
401
            )
            .await;
    }

402
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<()> {
403
        self.scheduler.mark_prefill_completed(request_id).await
404
405
    }

406
    pub async fn free(&self, request_id: &str) -> Result<()> {
407
        self.scheduler.free(request_id).await
408
    }
409

410
    pub fn block_size(&self) -> u32 {
411
412
        self.block_size
    }
413

414
415
416
417
418
419
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

420
421
422
423
424
        let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
            let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
            compute_seq_hash_for_block(&block_hashes)
        });

425
426
        Ok(self
            .scheduler
427
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
428
429
430
            .await)
    }

431
432
433
434
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
435
436
}

Michael Feil's avatar
Michael Feil committed
437
438
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
439
440
441
442
443
444
445
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
446
447
448
449
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
450
                let (best_worker, overlap_blocks) = self
Yan Ru Pei's avatar
Yan Ru Pei committed
451
                    .find_best_match(Some(&context_id), &tokens, None, true)
Michael Feil's avatar
Michael Feil committed
452
453
454
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
455
456
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
457
458
459
                    overlap_blocks,
                }
            }
460
461
462
463
464
465
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
466
        };
467
468
469
470
471
472

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
473
474

pub struct KvPushRouter {
475
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
476
    pub chooser: Arc<KvRouter>,
477
478
479
480
}

impl KvPushRouter {
    pub fn new(
481
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
482
483
484
485
486
487
488
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
489
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
490
491
    for KvPushRouter
{
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
511
512
    async fn generate(
        &self,
513
        request: SingleIn<PreprocessedRequest>,
514
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
515
        match self.inner.client.instance_source.as_ref() {
516
517
            InstanceSource::Static => self.inner.r#static(request).await,
            InstanceSource::Dynamic(_) => {
518
519
                // Extract context ID for request tracking
                let context_id = request.context().id().to_string();
520
521
522
523

                // Check if this is a query_instance_id request first
                let query_instance_id = request.has_annotation("query_instance_id");

Yan Ru Pei's avatar
Yan Ru Pei committed
524
525
526
527
528
529
530
531
532
                let (instance_id, dp_rank, overlap_amount) = if let Some(id) =
                    request.backend_instance_id
                {
                    // If instance_id is set, use it and compute actual overlap
                    let dp_rank = request.dp_rank.unwrap_or(0);
                    if query_instance_id {
                        tracing::debug!(
                            "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
                        );
533
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550

                    // Compute actual overlap blocks by querying the indexer
                    let block_hashes =
                        compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size());
                    let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
                    let worker = WorkerWithDpRank::new(id, dp_rank);
                    let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);

                    self.chooser
                        .add_request(
                            context_id.clone(),
                            &request.token_ids,
                            overlap_blocks,
                            worker,
                        )
                        .await;
                    (id, dp_rank, overlap_blocks)
551
552
                } else {
                    // Otherwise, find the best match
Yan Ru Pei's avatar
Yan Ru Pei committed
553
554
                    let (best_worker, overlap_amount) = self
                        .chooser
555
                        .find_best_match(
Yan Ru Pei's avatar
Yan Ru Pei committed
556
                            Some(&context_id),
557
558
                            &request.token_ids,
                            request.router_config_override.as_ref(),
559
                            !query_instance_id, // Don't update states if query_instance_id
560
                        )
Yan Ru Pei's avatar
Yan Ru Pei committed
561
562
                        .await?;
                    (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
563
564
                };

565
566
567
                // if request has the annotation "query_instance_id",
                // then the request will not be routed to the worker,
                // and instead the worker_instance_id will be returned.
568
569
570
571
572
                let stream_context = request.context().clone();
                if query_instance_id {
                    let instance_id_str = instance_id.to_string();
                    let response =
                        Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
573
574
575
576
577
578
579
580
581

                    // Return the tokens in nvext.token_data format
                    let response_tokens =
                        Annotated::from_annotation("token_data", &request.token_ids)?;
                    tracing::trace!(
                        "Tokens requested in the response through the query_instance_id annotation: {:?}",
                        response_tokens
                    );
                    let stream = stream::iter(vec![response, response_tokens]);
582
583
                    return Ok(ResponseStream::new(Box::pin(stream), stream_context));
                }
584
585
                let (mut backend_input, context) = request.into_parts();
                backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
Yan Ru Pei's avatar
Yan Ru Pei committed
586
                backend_input.dp_rank = Some(dp_rank);
587
                let updated_request = context.map(|_| backend_input);
588

589
                let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
590
591
                let stream_context = response_stream.context();
                let chooser = self.chooser.clone();
592
                let context_for_monitoring = stream_context.clone();
593
594

                let wrapped_stream = Box::pin(async_stream::stream! {
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
                    let mut prefill_marked = false;

                    loop {
                        tokio::select! {
                            biased;

                            _ = context_for_monitoring.stopped() => {
                                tracing::debug!("Request {context_id} cancelled, ending stream");
                                break;
                            }

                            item = response_stream.next() => {
                                let Some(item) = item else {
                                    break;
                                };

                                if !prefill_marked {
                                    if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
                                        tracing::warn!("Failed to mark prefill completed for request {context_id}: {e:?}");
                                    }
                                    prefill_marked = true;
                                }
                                yield item;
                            }
619
                        }
620
621
                    }

622
623
624
                    if let Err(e) = chooser.free(&context_id).await {
                        tracing::warn!("Failed to free request {context_id}: {e:?}");
                    }
625
626
                });
                Ok(ResponseStream::new(wrapped_stream, stream_context))
627
628
629
630
            }
        }
    }
}
Yan Ru Pei's avatar
Yan Ru Pei committed
631
632
633
634
635
636
637

impl Drop for KvRouter {
    /// Logs the teardown and cancels the router's cancellation token.
    fn drop(&mut self) {
        let Self {
            cancellation_token, ..
        } = self;

        tracing::info!("Dropping KvRouter - cancelling background tasks");
        cancellation_token.cancel();
    }
}