kv_router.rs 19.9 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::sync::Arc;
5
use std::time::Instant;
6

7
use anyhow::Result;
8
use dynamo_kv_router::{
9
    PrefillLoadEstimator,
10
    config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
11
    indexer::KvRouterError,
12
13
    protocols::KV_EVENT_SUBJECT,
    protocols::{
14
15
        BlockExtraInfo, BlockHashOptions, DpRank, PrefillLoadHint, RouterEvent, RouterRequest,
        RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
16
17
    },
};
18
use dynamo_runtime::{
19
    component::{Client, Endpoint},
20
    discovery::DiscoveryQuery,
21
    pipeline::{
22
23
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
24
    },
25
    protocols::EndpointId,
26
    protocols::annotated::Annotated,
27
    traits::DistributedRuntimeProvider,
28
};
29
use futures::stream;
30
use tracing::Instrument;
31
use validator::Validate;
32

33
pub mod indexer;
34
pub mod metrics;
35
pub mod prefill_router;
36
pub mod publisher;
37
pub mod push_router;
38
pub mod scheduler;
39
pub mod sequence;
40

41
pub use indexer::{Indexer, ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
42
pub use prefill_router::PrefillRouter;
43
pub use push_router::{DirectRoutingRouter, KvPushRouter};
44

45
use crate::{
46
    discovery::RuntimeConfigWatch,
47
    kv_router::{
48
        scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
49
        sequence::{SequenceError, SequenceRequest},
50
    },
51
    local_model::runtime_config::ModelRuntimeConfig,
52
53
};

54
55
use std::collections::HashSet;

56
57
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
58
59
60
61
62
63
64
65
66
67

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject name used for metric publishing (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject name for inter-router communication about prefill events.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject name for inter-router communication about active-sequence events.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
68

69
70
71
72
/// Bucket name for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

73
74
75
/// Buffer size for the worker-local KV indexer query path.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer

76
77
78
79
80
81
/// Builds the endpoint name of the worker KV indexer query service for one dp_rank.
///
/// The name is the fixed prefix `worker_kv_indexer_query_dp` followed by `dp_rank`.
/// Each dp_rank has its own LocalKvIndexer and query endpoint to ensure per-dp_rank
/// monotonicity.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    let mut endpoint_name = String::from("worker_kv_indexer_query_dp");
    endpoint_name.push_str(&dp_rank.to_string());
    endpoint_name
}

82
/// Endpoint name used for router discovery registration.
pub const KV_ROUTER_ENDPOINT: &str = "router-discovery";
84
85

/// Creates an EndpointId for the KV router in the given namespace.
86
pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId {
87
88
    EndpointId {
        namespace,
89
        component,
90
91
92
93
94
        name: KV_ROUTER_ENDPOINT.to_string(),
    }
}

/// Creates a DiscoveryQuery for the KV router in the given namespace.
95
pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery {
96
97
    DiscoveryQuery::Endpoint {
        namespace,
98
        component,
99
100
101
102
        endpoint: KV_ROUTER_ENDPOINT.to_string(),
    }
}

103
104
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
///
/// `Sel` is the worker-selection strategy handed to the scheduler; it defaults to
/// [`DefaultWorkerSelector`].
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
    /// Indexer queried for per-worker overlap scores and fed routing decisions.
    indexer: Indexer,
    /// Scheduler that picks the worker and tracks requests (add / mark prefill / free).
    scheduler: KvScheduler<Sel>,
    /// Block size passed to block-hash computation; also used to convert overlap
    /// blocks into a token count.
    block_size: u32,
    /// Router configuration, validated in `new`.
    kv_router_config: KvRouterConfig,
    /// Optional estimator used to predict prefill duration for load hints.
    prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
    /// Token cancelled when this router is dropped.
    cancellation_token: tokio_util::sync::CancellationToken,
    /// Client supplied at construction; exposed through `client()`.
    client: Client,
    /// Forwarded as `is_eagle` to every `BlockHashOptions` built by this router.
    is_eagle: bool,
    /// Handle returned by `ensure_served_indexer_service` when `serve_indexer` is set.
    /// Never read here — presumably held only to keep the served indexer alive; confirm.
    _served_indexer_handle: Option<ServedIndexerHandle>,
}

120
121
122
123
impl<Sel> KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
124
    #[allow(clippy::too_many_arguments)]
125
    pub async fn new(
126
127
        endpoint: Endpoint,
        client: Client,
128
        workers_with_configs: RuntimeConfigWatch,
129
        block_size: u32,
130
        selector: Sel,
131
        kv_router_config: Option<KvRouterConfig>,
132
        prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
133
        worker_type: &'static str,
134
        model_name: Option<String>,
135
        is_eagle: bool,
136
    ) -> Result<Self> {
137
        let kv_router_config = kv_router_config.unwrap_or_default();
138
        kv_router_config.validate()?;
139
        let component = endpoint.component();
140
        let cancellation_token = component.drt().primary_token();
141
        let min_initial_workers = min_initial_workers_from_env()?;
142

143
144
145
146
147
148
149
        let indexer = Indexer::new(
            component,
            &kv_router_config,
            block_size,
            model_name.as_deref(),
        )
        .await?;
150

151
152
153
154
        if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
            let mut startup_watch = workers_with_configs.clone();
            let _ = startup_watch
                .wait_for(|m| m.len() >= min_initial_workers)
155
156
                .await
                .map_err(|_| {
157
158
                    anyhow::anyhow!(
                        "runtime config watch closed before {} workers appeared",
159
                        min_initial_workers
160
                    )
161
162
                })?;
        }
163

164
        let scheduler = KvScheduler::start(
165
            component.clone(),
166
            block_size,
167
            workers_with_configs.clone(),
168
            selector,
169
            &kv_router_config,
170
            prefill_load_estimator.clone(),
171
            worker_type,
172
173
        )
        .await?;
174

175
176
        // Start KV event subscription if needed — skip when using a remote indexer.
        if kv_router_config.use_remote_indexer {
177
178
            tracing::info!("Skipping KV event subscription (using remote indexer)");
        } else if kv_router_config.should_subscribe_to_kv_events() {
179
            indexer::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
180
                .await?;
181
        } else {
182
            tracing::info!(
183
184
185
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
186
            );
187
        }
188

189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
        let served_indexer_handle = if kv_router_config.serve_indexer {
            let model_name = model_name.clone().ok_or_else(|| {
                anyhow::anyhow!("model_name is required when serve_indexer is configured")
            })?;
            Some(
                ensure_served_indexer_service(
                    component.clone(),
                    ServedIndexerMode::from_use_kv_events(kv_router_config.use_kv_events),
                    model_name,
                    indexer.clone(),
                )
                .await?,
            )
        } else {
            None
        };

206
        tracing::info!("KV Routing initialized");
207
        Ok(Self {
208
            indexer,
209
            scheduler,
210
            block_size,
211
            kv_router_config,
212
            prefill_load_estimator,
Yan Ru Pei's avatar
Yan Ru Pei committed
213
            cancellation_token,
214
            client,
215
            is_eagle,
216
            _served_indexer_handle: served_indexer_handle,
217
        })
218
219
    }

220
221
222
223
224
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

225
226
227
228
229
230
231
232
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    /// Borrows the router configuration this KvRouter was built with.
    pub fn kv_router_config(&self) -> &KvRouterConfig {
        let config: &KvRouterConfig = &self.kv_router_config;
        config
    }

233
234
235
236
    pub fn is_eagle(&self) -> bool {
        self.is_eagle
    }

237
238
    pub async fn record_routing_decision(
        &self,
239
        mut tokens_with_hashes: TokensWithHashes,
240
241
242
243
244
245
246
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        self.indexer
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
            .await
    }

247
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
248
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
249
250
251
    /// Now also takes optional context_id for request tracking.
    ///
    /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
252
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
253
    pub async fn find_best_match(
254
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
255
        context_id: Option<&str>,
256
        tokens: &[u32],
257
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
258
        router_config_override: Option<&RouterConfigOverride>,
259
        update_states: bool,
260
        lora_name: Option<String>,
261
        priority_jump: f64,
262
        expected_output_tokens: Option<u32>,
263
        allowed_worker_ids: Option<HashSet<WorkerId>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
264
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
265
266
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
267
        if update_states && context_id.is_none() {
268
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
269
270
        }

271
        let isl_tokens = tokens.len();
272
273
274
275
276
277
278
279
280
281
282
283
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };

        let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
            .in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options));
        let hash_elapsed = start.elapsed();
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
            self.kv_router_config.compute_seq_hashes_for_tracking(
284
285
                tokens,
                self.block_size,
286
287
288
                router_config_override,
                hash_options,
                Some(&block_hashes),
289
290
            )
        });
291
        let seq_hash_elapsed = start.elapsed();
292

293
        let overlap_scores = self
294
295
296
297
            .indexer
            .find_matches(block_hashes)
            .instrument(tracing::info_span!("kv_router.find_matches"))
            .await?;
298
        let find_matches_elapsed = start.elapsed();
299

300
        let response = self
301
            .scheduler
302
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
303
                context_id.map(|s| s.to_string()),
304
                isl_tokens,
305
                maybe_seq_hashes,
306
                overlap_scores,
307
                router_config_override,
308
                update_states,
309
                lora_name,
310
                priority_jump,
311
                expected_output_tokens,
312
                allowed_worker_ids,
313
            )
314
            .instrument(tracing::info_span!("kv_router.schedule"))
315
            .await?;
316
317
        let total_elapsed = start.elapsed();

318
319
320
321
        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
            m.observe(
                hash_elapsed,
                seq_hash_elapsed,
322
                find_matches_elapsed,
323
324
325
                total_elapsed,
            );
        }
326

327
        #[cfg(feature = "bench")]
328
329
330
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
331
332
333
            seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64,
            find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
334
335
336
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
337

338
        Ok((response.best_worker, response.overlap_blocks))
339
340
    }

341
342
343
344
345
    /// Registers externally-provided workers with the scheduler's slot tracker.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        let scheduler = &self.scheduler;
        scheduler.register_workers(worker_ids);
    }

346
    #[allow(clippy::too_many_arguments)]
347
348
349
350
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
351
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
352
        overlap_blocks: u32,
353
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
354
        worker: WorkerWithDpRank,
355
        lora_name: Option<String>,
356
        router_config_override: Option<&RouterConfigOverride>,
357
358
    ) {
        let isl_tokens = tokens.len();
359
360
361
362
363
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };
364

365
366
367
368
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
369
370
            hash_options,
            None,
371
        );
372
373
374
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
375
376
        let prefill_load_hint =
            self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
377

378
379
        if let Err(e) = self
            .scheduler
380
381
382
383
384
            .add_request(SequenceRequest {
                request_id: request_id.clone(),
                token_sequence: maybe_seq_hashes,
                isl: isl_tokens,
                overlap: overlap_blocks,
385
                track_prefill_tokens,
386
                expected_output_tokens,
387
                prefill_load_hint,
Yan Ru Pei's avatar
Yan Ru Pei committed
388
                worker,
389
                lora_name,
390
            })
391
392
393
394
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
395
396
    }

397
    /// Tells the scheduler that prefill has completed for `request_id`.
    ///
    /// # Errors
    /// Propagates the `SequenceError` returned by the scheduler.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.mark_prefill_completed(request_id).await
    }

401
    /// Asks the scheduler to free `request_id`.
    ///
    /// # Errors
    /// Propagates the `SequenceError` returned by the scheduler.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.scheduler;
        scheduler.free(request_id).await
    }
404

405
406
407
408
409
    /// Number of requests currently parked in the scheduler queue, as reported by the
    /// scheduler.
    pub fn pending_count(&self) -> usize {
        let scheduler = &self.scheduler;
        scheduler.pending_count()
    }

410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
    /// Builds the optional prefill load hint for a request added through `add_request`.
    ///
    /// Returns `None` when prefill token tracking is off, when the cached prefix
    /// (`overlap_blocks * block_size`) covers the whole input, when no estimator is
    /// configured, or when the estimator fails (the failure is logged as a warning).
    fn prefill_load_hint_for(
        &self,
        isl_tokens: usize,
        overlap_blocks: u32,
        track_prefill_tokens: bool,
    ) -> Option<PrefillLoadHint> {
        if !track_prefill_tokens {
            return None;
        }

        // Tokens already covered by cached blocks vs. tokens that still need prefill.
        let prefix = (overlap_blocks as usize) * (self.block_size as usize);
        let effective_isl = isl_tokens.saturating_sub(prefix);
        if effective_isl == 0 {
            return None;
        }

        let estimator = self.prefill_load_estimator.as_ref()?;

        estimator
            .predict_prefill_duration(1, effective_isl, prefix)
            .map(|expected_prefill_duration| PrefillLoadHint {
                initial_effective_prefill_tokens: effective_isl,
                expected_prefill_duration: Some(expected_prefill_duration),
            })
            .map_err(|error| {
                tracing::warn!(
                    effective_isl,
                    prefix,
                    "failed to predict prefill duration for direct add_request path: {error}"
                );
            })
            .ok()
    }

446
447
448
449
450
451
    /// Returns the worker type reported by the scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        let scheduler = &self.scheduler;
        scheduler.worker_type()
    }

452
    pub fn add_output_block(
453
454
455
456
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
457
        self.scheduler.add_output_block(request_id, decay_fraction)
458
459
    }

460
    pub fn block_size(&self) -> u32 {
461
462
        self.block_size
    }
463

464
465
466
467
468
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
469
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
470
        worker: WorkerWithDpRank,
471
        lora_name: Option<&str>,
472
    ) -> Result<u32, KvRouterError> {
473
474
475
476
477
478
479
480
481
        let block_hashes = compute_block_hash_for_seq(
            tokens,
            self.block_size,
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
        );
482
483
484
485
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

486
    /// Get potential prefill and decode loads for all workers
487
488
489
490
    pub async fn get_potential_loads(
        &self,
        tokens: &[u32],
        router_config_override: Option<&RouterConfigOverride>,
491
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
492
        lora_name: Option<&str>,
493
    ) -> Result<Vec<PotentialLoad>> {
494
        let isl_tokens = tokens.len();
495
496
497
498
499
500
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name,
            is_eagle: Some(self.is_eagle),
        };
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options);
501

502
503
504
505
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
506
507
            hash_options,
            Some(&block_hashes),
508
        );
509
510
511
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
512
513
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

514
515
516
517
518
519
        Ok(self.scheduler.get_potential_loads(
            maybe_seq_hashes,
            isl_tokens,
            overlap_scores,
            track_prefill_tokens,
        ))
520
521
    }

522
523
524
525
    /// Dump all events from the indexer.
    ///
    /// # Errors
    /// Propagates the `KvRouterError` returned by the indexer.
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        let indexer = &self.indexer;
        indexer.dump_events().await
    }
526
527
}

Michael Feil's avatar
Michael Feil committed
528
529
// NOTE: KvRouter works like a PushRouter, but without the reverse proxy functionality;
// instead it serves a contract of 3 request types (New, MarkPrefill, MarkFree).
530
#[async_trait]
531
532
533
534
535
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
    for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
536
537
538
539
540
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
541
542
543
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
544
545
546
547
            RouterRequest::New {
                tokens,
                block_mm_infos,
            } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
548
                let (best_worker, overlap_blocks) = self
549
550
551
552
553
554
555
556
                    .find_best_match(
                        Some(&context_id),
                        &tokens,
                        block_mm_infos.as_deref(),
                        None,
                        true,
                        None,
                        0.0,
557
                        None,
558
                        None,
559
                    )
Michael Feil's avatar
Michael Feil committed
560
561
562
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
563
564
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
565
566
567
                    overlap_blocks,
                }
            }
568
569
570
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
571
572
573
574
575
576
577
578
579
            RouterRequest::MarkFree { request_id } => {
                let request_id = match request_id.as_deref() {
                    Some(request_id) if !request_id.trim().is_empty() => request_id,
                    _ => &context_id,
                };
                RouterResponse::FreeMarked {
                    success: self.free(request_id).await.is_ok(),
                }
            }
Michael Feil's avatar
Michael Feil committed
580
        };
581
582
583
584
585
586

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
587

588
589
590
591
/// Cancels the router's cancellation token when the router is dropped.
impl<Sel> Drop for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
    /// Logs the drop and cancels `cancellation_token`.
    ///
    /// NOTE(review): `new` obtains this token from `component.drt().primary_token()`.
    /// If that is the runtime-wide token rather than a per-router child, cancelling it
    /// here may stop more than this router's background tasks — confirm.
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}