// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::sync::Arc;
5
use std::time::Instant;
6

7
use anyhow::Result;
8
use dynamo_kv_router::{
9
    PrefillLoadEstimator,
10
    config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
11
    indexer::KvRouterError,
12
13
    protocols::KV_EVENT_SUBJECT,
    protocols::{
14
15
16
        BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, PrefillLoadHint, RouterEvent,
        RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
        compute_block_hash_for_seq,
17
18
    },
};
19
use dynamo_runtime::{
20
    component::{Client, Endpoint},
21
    discovery::DiscoveryQuery,
22
    pipeline::{
23
24
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
25
    },
26
    protocols::EndpointId,
27
    protocols::annotated::Annotated,
28
    traits::DistributedRuntimeProvider,
29
};
30
use futures::stream;
31
use tracing::Instrument;
32
use validator::Validate;
33

34
pub mod indexer;
35
pub mod metrics;
36
pub mod prefill_router;
37
pub mod publisher;
38
pub mod push_router;
39
pub mod scheduler;
40
pub mod sequence;
41

42
pub use indexer::{Indexer, ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
43
pub use prefill_router::PrefillRouter;
44
pub use push_router::{DirectRoutingRouter, KvPushRouter};
45

46
use crate::{
47
    discovery::RuntimeConfigWatch,
48
    kv_router::{
49
        scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
50
        sequence::{SequenceError, SequenceRequest},
51
    },
52
    local_model::runtime_config::ModelRuntimeConfig,
53
54
};

55
56
use std::collections::HashSet;

57
58
// [gluo TODO] shouldn't need to be public; this should be discovered from the component.

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject used for metric publishing (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject used for inter-router prefill events.
pub const PREFILL_SUBJECT: &str = "prefill_events";

/// Subject used for inter-router active-sequence events.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name used for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";

/// File name used for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

/// How many of the most recent events the worker-local KV indexer buffer keeps.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

77
78
79
80
81
82
/// Builds the worker KV indexer query endpoint name for one `dp_rank`.
///
/// Every dp_rank gets its own `LocalKvIndexer` and query endpoint, which is what
/// guarantees per-dp_rank monotonicity.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    let mut endpoint_name = String::from("worker_kv_indexer_query_dp");
    endpoint_name.push_str(&dp_rank.to_string());
    endpoint_name
}

83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/// Logs, at DEBUG level, the local block hashes computed for a routing request.
///
/// Nothing is allocated or formatted unless DEBUG tracing is enabled.
fn log_routing_input_hashes(
    request_id: Option<&str>,
    block_size: u32,
    tokens: &[u32],
    local_hashes: &[LocalBlockHash],
) {
    if tracing::enabled!(tracing::Level::DEBUG) {
        let raw_hash_values = local_hashes
            .iter()
            .map(|block_hash| block_hash.0)
            .collect::<Vec<u64>>();

        tracing::debug!(
            request_id = request_id.unwrap_or(""),
            isl_tokens = tokens.len(),
            block_size,
            num_blocks = local_hashes.len(),
            local_hashes = ?raw_hash_values,
            "[ROUTING_INPUT] request local hashes"
        );
    }
}

105
/// Endpoint name used for KV router discovery registration.
pub const KV_ROUTER_ENDPOINT: &str = "router-discovery";
107
108

/// Creates an EndpointId for the KV router in the given namespace.
109
pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId {
110
111
    EndpointId {
        namespace,
112
        component,
113
114
115
116
117
        name: KV_ROUTER_ENDPOINT.to_string(),
    }
}

/// Creates a DiscoveryQuery for the KV router in the given namespace.
118
pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery {
119
120
    DiscoveryQuery::Endpoint {
        namespace,
121
        component,
122
123
124
125
        endpoint: KV_ROUTER_ENDPOINT.to_string(),
    }
}

126
127
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
128
129
130
131
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
132
    indexer: Indexer,
133
    scheduler: KvScheduler<Sel>,
134
    workers_with_configs: RuntimeConfigWatch,
135
    block_size: u32,
136
    kv_router_config: KvRouterConfig,
137
    prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
138
    cancellation_token: tokio_util::sync::CancellationToken,
139
    client: Client,
140
    is_eagle: bool,
141
    _served_indexer_handle: Option<ServedIndexerHandle>,
142
143
}

144
145
146
147
impl<Sel> KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
148
    #[allow(clippy::too_many_arguments)]
149
    pub async fn new(
150
151
        endpoint: Endpoint,
        client: Client,
152
        workers_with_configs: RuntimeConfigWatch,
153
        block_size: u32,
154
        selector: Sel,
155
        kv_router_config: Option<KvRouterConfig>,
156
        prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
157
        worker_type: &'static str,
158
        model_name: Option<String>,
159
        is_eagle: bool,
160
    ) -> Result<Self> {
161
        let kv_router_config = kv_router_config.unwrap_or_default();
162
        kv_router_config.validate()?;
163
        let component = endpoint.component();
164
        let cancellation_token = component.drt().primary_token();
165
        let min_initial_workers = min_initial_workers_from_env()?;
166

167
168
169
170
171
172
173
        let indexer = Indexer::new(
            component,
            &kv_router_config,
            block_size,
            model_name.as_deref(),
        )
        .await?;
174

175
176
177
178
        if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
            let mut startup_watch = workers_with_configs.clone();
            let _ = startup_watch
                .wait_for(|m| m.len() >= min_initial_workers)
179
180
                .await
                .map_err(|_| {
181
182
                    anyhow::anyhow!(
                        "runtime config watch closed before {} workers appeared",
183
                        min_initial_workers
184
                    )
185
186
                })?;
        }
187

188
        let scheduler = KvScheduler::start(
189
            component.clone(),
190
            block_size,
191
            workers_with_configs.clone(),
192
            selector,
193
            &kv_router_config,
194
            prefill_load_estimator.clone(),
195
            worker_type,
196
197
        )
        .await?;
198

199
200
        // Start KV event subscription if needed — skip when using a remote indexer.
        if kv_router_config.use_remote_indexer {
201
202
            tracing::info!("Skipping KV event subscription (using remote indexer)");
        } else if kv_router_config.should_subscribe_to_kv_events() {
203
            indexer::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
204
                .await?;
205
        } else {
206
            tracing::info!(
207
208
209
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
210
            );
211
        }
212

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
        let served_indexer_handle = if kv_router_config.serve_indexer {
            let model_name = model_name.clone().ok_or_else(|| {
                anyhow::anyhow!("model_name is required when serve_indexer is configured")
            })?;
            Some(
                ensure_served_indexer_service(
                    component.clone(),
                    ServedIndexerMode::from_use_kv_events(kv_router_config.use_kv_events),
                    model_name,
                    indexer.clone(),
                )
                .await?,
            )
        } else {
            None
        };

230
        tracing::info!("KV Routing initialized");
231
        Ok(Self {
232
            indexer,
233
            scheduler,
234
            workers_with_configs,
235
            block_size,
236
            kv_router_config,
237
            prefill_load_estimator,
Yan Ru Pei's avatar
Yan Ru Pei committed
238
            cancellation_token,
239
            client,
240
            is_eagle,
241
            _served_indexer_handle: served_indexer_handle,
242
        })
243
244
    }

245
246
247
248
249
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

250
251
252
253
254
255
256
257
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    pub fn kv_router_config(&self) -> &KvRouterConfig {
        &self.kv_router_config
    }

258
259
260
261
    pub fn is_eagle(&self) -> bool {
        self.is_eagle
    }

262
263
    pub async fn record_routing_decision(
        &self,
264
        mut tokens_with_hashes: TokensWithHashes,
265
266
267
268
269
270
271
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        self.indexer
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
            .await
    }

272
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
273
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
274
275
276
    /// Now also takes optional context_id for request tracking.
    ///
    /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
277
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
278
    pub async fn find_best_match(
279
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
280
        context_id: Option<&str>,
281
        tokens: &[u32],
282
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
283
        router_config_override: Option<&RouterConfigOverride>,
284
        update_states: bool,
285
        lora_name: Option<String>,
286
        priority_jump: f64,
287
        expected_output_tokens: Option<u32>,
288
        allowed_worker_ids: Option<HashSet<WorkerId>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
289
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
290
291
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
292
        if update_states && context_id.is_none() {
293
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
294
295
        }

296
        let isl_tokens = tokens.len();
297
298
299
300
301
302
303
304
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };

        let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
            .in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options));
305
        log_routing_input_hashes(context_id, self.block_size, tokens, &block_hashes);
306
307
308
309
        let hash_elapsed = start.elapsed();
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
            self.kv_router_config.compute_seq_hashes_for_tracking(
310
311
                tokens,
                self.block_size,
312
313
314
                router_config_override,
                hash_options,
                Some(&block_hashes),
315
316
            )
        });
317
        let seq_hash_elapsed = start.elapsed();
318

319
        let overlap_scores = self
320
321
322
323
            .indexer
            .find_matches(block_hashes)
            .instrument(tracing::info_span!("kv_router.find_matches"))
            .await?;
324
        let find_matches_elapsed = start.elapsed();
325

326
        let response = self
327
            .scheduler
328
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
329
                context_id.map(|s| s.to_string()),
330
                isl_tokens,
331
                maybe_seq_hashes,
332
                overlap_scores,
333
                router_config_override,
334
                update_states,
335
                lora_name,
336
                priority_jump,
337
                expected_output_tokens,
338
                allowed_worker_ids,
339
            )
340
            .instrument(tracing::info_span!("kv_router.schedule"))
341
            .await?;
342
343
        let total_elapsed = start.elapsed();

344
345
346
347
        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
            m.observe(
                hash_elapsed,
                seq_hash_elapsed,
348
                find_matches_elapsed,
349
350
351
                total_elapsed,
            );
        }
352

353
        #[cfg(feature = "bench")]
354
355
356
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
357
358
359
            seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64,
            find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
360
361
362
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
363

364
        Ok((response.best_worker, response.overlap_blocks))
365
366
    }

367
368
369
370
371
    /// Register externally-provided workers in the slot tracker.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        self.scheduler.register_workers(worker_ids);
    }

372
    #[allow(clippy::too_many_arguments)]
373
374
375
376
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
377
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
378
        overlap_blocks: u32,
379
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
380
        worker: WorkerWithDpRank,
381
        lora_name: Option<String>,
382
        router_config_override: Option<&RouterConfigOverride>,
383
384
    ) {
        let isl_tokens = tokens.len();
385
386
387
388
389
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };
390

391
392
393
394
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
395
396
            hash_options,
            None,
397
        );
398
399
400
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
401
402
        let prefill_load_hint =
            self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
403

404
405
        if let Err(e) = self
            .scheduler
406
407
408
409
410
            .add_request(SequenceRequest {
                request_id: request_id.clone(),
                token_sequence: maybe_seq_hashes,
                isl: isl_tokens,
                overlap: overlap_blocks,
411
                track_prefill_tokens,
412
                expected_output_tokens,
413
                prefill_load_hint,
Yan Ru Pei's avatar
Yan Ru Pei committed
414
                worker,
415
                lora_name,
416
            })
417
418
419
420
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
421
422
    }

423
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
424
        self.scheduler.mark_prefill_completed(request_id).await
425
426
    }

427
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
428
        self.scheduler.free(request_id).await
429
    }
430

431
432
433
434
435
    /// Number of requests currently parked in the scheduler queue.
    pub fn pending_count(&self) -> usize {
        self.scheduler.pending_count()
    }

436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
    fn prefill_load_hint_for(
        &self,
        isl_tokens: usize,
        overlap_blocks: u32,
        track_prefill_tokens: bool,
    ) -> Option<PrefillLoadHint> {
        if !track_prefill_tokens {
            return None;
        }

        let prefix = (overlap_blocks as usize) * (self.block_size as usize);
        let effective_isl = isl_tokens.saturating_sub(prefix);
        if effective_isl == 0 {
            return None;
        }

        let Some(estimator) = &self.prefill_load_estimator else {
            return None;
        };

        match estimator.predict_prefill_duration(1, effective_isl, prefix) {
            Ok(expected_prefill_duration) => Some(PrefillLoadHint {
                initial_effective_prefill_tokens: effective_isl,
                expected_prefill_duration: Some(expected_prefill_duration),
            }),
            Err(error) => {
                tracing::warn!(
                    effective_isl,
                    prefix,
                    "failed to predict prefill duration for direct add_request path: {error}"
                );
                None
            }
        }
    }

472
473
474
475
476
477
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

478
479
480
481
482
483
484
    /// Return the worker's unique global DP rank when it owns exactly one rank.
    pub fn unique_dp_rank_for_worker(&self, worker_id: WorkerId) -> Option<u32> {
        let configs = self.workers_with_configs.borrow();
        let config = configs.get(&worker_id)?;
        (config.data_parallel_size == 1).then_some(config.data_parallel_start_rank)
    }

485
    pub fn add_output_block(
486
487
488
489
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
490
        self.scheduler.add_output_block(request_id, decay_fraction)
491
492
    }

493
    pub fn block_size(&self) -> u32 {
494
495
        self.block_size
    }
496

497
498
499
500
501
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
502
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
503
        worker: WorkerWithDpRank,
504
        lora_name: Option<&str>,
505
    ) -> Result<u32, KvRouterError> {
506
507
508
509
510
511
512
513
514
        let block_hashes = compute_block_hash_for_seq(
            tokens,
            self.block_size,
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
        );
515
        log_routing_input_hashes(None, self.block_size, tokens, &block_hashes);
516
517
518
519
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

520
    /// Get potential prefill and decode loads for all workers
521
522
523
524
    pub async fn get_potential_loads(
        &self,
        tokens: &[u32],
        router_config_override: Option<&RouterConfigOverride>,
525
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
526
        lora_name: Option<&str>,
527
    ) -> Result<Vec<PotentialLoad>> {
528
        let isl_tokens = tokens.len();
529
530
531
532
533
534
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name,
            is_eagle: Some(self.is_eagle),
        };
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options);
535

536
537
538
539
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
540
541
            hash_options,
            Some(&block_hashes),
542
        );
543
544
545
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
546
547
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

548
549
550
551
552
553
        Ok(self.scheduler.get_potential_loads(
            maybe_seq_hashes,
            isl_tokens,
            overlap_scores,
            track_prefill_tokens,
        ))
554
555
    }

556
557
558
559
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
560
561
}

Michael Feil's avatar
Michael Feil committed
562
563
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
564
#[async_trait]
565
566
567
568
569
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
    for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
570
571
572
573
574
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
575
576
577
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
578
579
580
581
            RouterRequest::New {
                tokens,
                block_mm_infos,
            } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
582
                let (best_worker, overlap_blocks) = self
583
584
585
586
587
588
589
590
                    .find_best_match(
                        Some(&context_id),
                        &tokens,
                        block_mm_infos.as_deref(),
                        None,
                        true,
                        None,
                        0.0,
591
                        None,
592
                        None,
593
                    )
Michael Feil's avatar
Michael Feil committed
594
595
596
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
597
598
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
599
600
601
                    overlap_blocks,
                }
            }
602
603
604
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
605
606
607
608
609
610
611
612
613
            RouterRequest::MarkFree { request_id } => {
                let request_id = match request_id.as_deref() {
                    Some(request_id) if !request_id.trim().is_empty() => request_id,
                    _ => &context_id,
                };
                RouterResponse::FreeMarked {
                    success: self.free(request_id).await.is_ok(),
                }
            }
Michael Feil's avatar
Michael Feil committed
614
        };
615
616
617
618
619
620

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
621

622
623
624
625
impl<Sel> Drop for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
Yan Ru Pei's avatar
Yan Ru Pei committed
626
627
628
629
630
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}