// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;
5
use std::time::Instant;
6

7
use anyhow::Result;
8
use dynamo_kv_router::{
9
    PrefillLoadEstimator,
10
    config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
11
    indexer::KvRouterError,
12
13
    protocols::KV_EVENT_SUBJECT,
    protocols::{
14
15
16
        BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, PrefillLoadHint, RouterEvent,
        RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank,
        compute_block_hash_for_seq,
17
18
    },
};
19
use dynamo_runtime::{
20
    component::{Client, Endpoint},
21
    discovery::DiscoveryQuery,
22
    pipeline::{
23
24
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
25
    },
26
    protocols::EndpointId,
27
    protocols::annotated::Annotated,
28
    traits::DistributedRuntimeProvider,
29
};
30
use futures::stream;
31
use tracing::Instrument;
32
use validator::Validate;
33

34
pub mod indexer;
35
pub mod metrics;
36
pub mod prefill_router;
37
pub mod publisher;
38
pub mod push_router;
39
pub mod scheduler;
40
pub mod sequence;
41

42
pub use indexer::{Indexer, ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service};
43
pub use prefill_router::PrefillRouter;
44
pub use push_router::{DirectRoutingRouter, KvPushRouter};
45

46
use crate::{
47
    discovery::RuntimeConfigWatch,
48
    kv_router::{
49
        scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
50
        sequence::{SequenceError, SequenceRequest},
51
    },
52
    local_model::runtime_config::ModelRuntimeConfig,
53
54
};

55
56
use std::collections::HashSet;

57
58
// [gluo TODO] shouldn't need to be public;
// this should be discovered from the component.

/// Endpoint name used for scraping KV load metrics (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject used for publishing KV metrics (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name used for radix-tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name used for radix-tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Worker-local KV indexer query support: how many of the most recent events
/// each worker keeps in its buffer.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1_024;

77
78
79
80
81
82
/// Generates a dp_rank-specific endpoint name for the worker KV indexer query service.
/// Each dp_rank has its own LocalKvIndexer and query endpoint to ensure per-dp_rank monotonicity.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    format!("worker_kv_indexer_query_dp{dp_rank}")
}

83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
fn log_routing_input_hashes(
    request_id: Option<&str>,
    block_size: u32,
    tokens: &[u32],
    local_hashes: &[LocalBlockHash],
) {
    if !tracing::enabled!(tracing::Level::DEBUG) {
        return;
    }

    let local_hash_ids: Vec<u64> = local_hashes.iter().map(|hash| hash.0).collect();

    tracing::debug!(
        request_id = request_id.unwrap_or(""),
        isl_tokens = tokens.len(),
        block_size,
        num_blocks = local_hashes.len(),
        local_hashes = ?local_hash_ids,
        "[ROUTING_INPUT] request local hashes"
    );
}

/// Endpoint name under which KV routers register themselves for discovery.
pub const KV_ROUTER_ENDPOINT: &str = "router-discovery";

/// Creates an EndpointId for the KV router in the given namespace.
109
pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId {
110
111
    EndpointId {
        namespace,
112
        component,
113
114
115
116
117
        name: KV_ROUTER_ENDPOINT.to_string(),
    }
}

/// Creates a DiscoveryQuery for the KV router in the given namespace.
118
pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery {
119
120
    DiscoveryQuery::Endpoint {
        namespace,
121
        component,
122
123
124
125
        endpoint: KV_ROUTER_ENDPOINT.to_string(),
    }
}

126
127
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
128
129
130
131
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
132
    indexer: Indexer,
133
    scheduler: KvScheduler<Sel>,
134
    block_size: u32,
135
    kv_router_config: KvRouterConfig,
136
    prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
137
    cancellation_token: tokio_util::sync::CancellationToken,
138
    client: Client,
139
    is_eagle: bool,
140
    _served_indexer_handle: Option<ServedIndexerHandle>,
141
142
}

143
144
145
146
impl<Sel> KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
147
    #[allow(clippy::too_many_arguments)]
148
    pub async fn new(
149
150
        endpoint: Endpoint,
        client: Client,
151
        workers_with_configs: RuntimeConfigWatch,
152
        block_size: u32,
153
        selector: Sel,
154
        kv_router_config: Option<KvRouterConfig>,
155
        prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
156
        worker_type: &'static str,
157
        model_name: Option<String>,
158
        is_eagle: bool,
159
    ) -> Result<Self> {
160
        let kv_router_config = kv_router_config.unwrap_or_default();
161
        kv_router_config.validate()?;
162
        let component = endpoint.component();
163
        let cancellation_token = component.drt().primary_token();
164
        let min_initial_workers = min_initial_workers_from_env()?;
165

166
167
168
169
170
171
172
        let indexer = Indexer::new(
            component,
            &kv_router_config,
            block_size,
            model_name.as_deref(),
        )
        .await?;
173

174
175
176
177
        if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
            let mut startup_watch = workers_with_configs.clone();
            let _ = startup_watch
                .wait_for(|m| m.len() >= min_initial_workers)
178
179
                .await
                .map_err(|_| {
180
181
                    anyhow::anyhow!(
                        "runtime config watch closed before {} workers appeared",
182
                        min_initial_workers
183
                    )
184
185
                })?;
        }
186

187
        let scheduler = KvScheduler::start(
188
            component.clone(),
189
            block_size,
190
            workers_with_configs.clone(),
191
            selector,
192
            &kv_router_config,
193
            prefill_load_estimator.clone(),
194
            worker_type,
195
196
        )
        .await?;
197

198
199
        // Start KV event subscription if needed — skip when using a remote indexer.
        if kv_router_config.use_remote_indexer {
200
201
            tracing::info!("Skipping KV event subscription (using remote indexer)");
        } else if kv_router_config.should_subscribe_to_kv_events() {
202
            indexer::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
203
                .await?;
204
        } else {
205
            tracing::info!(
206
207
208
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
209
            );
210
        }
211

212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
        let served_indexer_handle = if kv_router_config.serve_indexer {
            let model_name = model_name.clone().ok_or_else(|| {
                anyhow::anyhow!("model_name is required when serve_indexer is configured")
            })?;
            Some(
                ensure_served_indexer_service(
                    component.clone(),
                    ServedIndexerMode::from_use_kv_events(kv_router_config.use_kv_events),
                    model_name,
                    indexer.clone(),
                )
                .await?,
            )
        } else {
            None
        };

229
        tracing::info!("KV Routing initialized");
230
        Ok(Self {
231
            indexer,
232
            scheduler,
233
            block_size,
234
            kv_router_config,
235
            prefill_load_estimator,
Yan Ru Pei's avatar
Yan Ru Pei committed
236
            cancellation_token,
237
            client,
238
            is_eagle,
239
            _served_indexer_handle: served_indexer_handle,
240
        })
241
242
    }

243
244
245
246
247
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

248
249
250
251
252
253
254
255
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    pub fn kv_router_config(&self) -> &KvRouterConfig {
        &self.kv_router_config
    }

256
257
258
259
    pub fn is_eagle(&self) -> bool {
        self.is_eagle
    }

260
261
    pub async fn record_routing_decision(
        &self,
262
        mut tokens_with_hashes: TokensWithHashes,
263
264
265
266
267
268
269
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        self.indexer
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
            .await
    }

270
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
271
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
272
273
274
    /// Now also takes optional context_id for request tracking.
    ///
    /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
275
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
276
    pub async fn find_best_match(
277
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
278
        context_id: Option<&str>,
279
        tokens: &[u32],
280
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
281
        router_config_override: Option<&RouterConfigOverride>,
282
        update_states: bool,
283
        lora_name: Option<String>,
284
        priority_jump: f64,
285
        expected_output_tokens: Option<u32>,
286
        allowed_worker_ids: Option<HashSet<WorkerId>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
287
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
288
289
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
290
        if update_states && context_id.is_none() {
291
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
292
293
        }

294
        let isl_tokens = tokens.len();
295
296
297
298
299
300
301
302
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };

        let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
            .in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options));
303
        log_routing_input_hashes(context_id, self.block_size, tokens, &block_hashes);
304
305
306
307
        let hash_elapsed = start.elapsed();
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
            self.kv_router_config.compute_seq_hashes_for_tracking(
308
309
                tokens,
                self.block_size,
310
311
312
                router_config_override,
                hash_options,
                Some(&block_hashes),
313
314
            )
        });
315
        let seq_hash_elapsed = start.elapsed();
316

317
        let overlap_scores = self
318
319
320
321
            .indexer
            .find_matches(block_hashes)
            .instrument(tracing::info_span!("kv_router.find_matches"))
            .await?;
322
        let find_matches_elapsed = start.elapsed();
323

324
        let response = self
325
            .scheduler
326
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
327
                context_id.map(|s| s.to_string()),
328
                isl_tokens,
329
                maybe_seq_hashes,
330
                overlap_scores,
331
                router_config_override,
332
                update_states,
333
                lora_name,
334
                priority_jump,
335
                expected_output_tokens,
336
                allowed_worker_ids,
337
            )
338
            .instrument(tracing::info_span!("kv_router.schedule"))
339
            .await?;
340
341
        let total_elapsed = start.elapsed();

342
343
344
345
        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
            m.observe(
                hash_elapsed,
                seq_hash_elapsed,
346
                find_matches_elapsed,
347
348
349
                total_elapsed,
            );
        }
350

351
        #[cfg(feature = "bench")]
352
353
354
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
355
356
357
            seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64,
            find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
358
359
360
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
361

362
        Ok((response.best_worker, response.overlap_blocks))
363
364
    }

365
366
367
368
369
    /// Register externally-provided workers in the slot tracker.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        self.scheduler.register_workers(worker_ids);
    }

370
    #[allow(clippy::too_many_arguments)]
371
372
373
374
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
375
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
376
        overlap_blocks: u32,
377
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
378
        worker: WorkerWithDpRank,
379
        lora_name: Option<String>,
380
        router_config_override: Option<&RouterConfigOverride>,
381
382
    ) {
        let isl_tokens = tokens.len();
383
384
385
386
387
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };
388

389
390
391
392
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
393
394
            hash_options,
            None,
395
        );
396
397
398
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
399
400
        let prefill_load_hint =
            self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens);
401

402
403
        if let Err(e) = self
            .scheduler
404
405
406
407
408
            .add_request(SequenceRequest {
                request_id: request_id.clone(),
                token_sequence: maybe_seq_hashes,
                isl: isl_tokens,
                overlap: overlap_blocks,
409
                track_prefill_tokens,
410
                expected_output_tokens,
411
                prefill_load_hint,
Yan Ru Pei's avatar
Yan Ru Pei committed
412
                worker,
413
                lora_name,
414
            })
415
416
417
418
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
419
420
    }

421
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
422
        self.scheduler.mark_prefill_completed(request_id).await
423
424
    }

425
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
426
        self.scheduler.free(request_id).await
427
    }
428

429
430
431
432
433
    /// Number of requests currently parked in the scheduler queue.
    pub fn pending_count(&self) -> usize {
        self.scheduler.pending_count()
    }

434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
    fn prefill_load_hint_for(
        &self,
        isl_tokens: usize,
        overlap_blocks: u32,
        track_prefill_tokens: bool,
    ) -> Option<PrefillLoadHint> {
        if !track_prefill_tokens {
            return None;
        }

        let prefix = (overlap_blocks as usize) * (self.block_size as usize);
        let effective_isl = isl_tokens.saturating_sub(prefix);
        if effective_isl == 0 {
            return None;
        }

        let Some(estimator) = &self.prefill_load_estimator else {
            return None;
        };

        match estimator.predict_prefill_duration(1, effective_isl, prefix) {
            Ok(expected_prefill_duration) => Some(PrefillLoadHint {
                initial_effective_prefill_tokens: effective_isl,
                expected_prefill_duration: Some(expected_prefill_duration),
            }),
            Err(error) => {
                tracing::warn!(
                    effective_isl,
                    prefix,
                    "failed to predict prefill duration for direct add_request path: {error}"
                );
                None
            }
        }
    }

470
471
472
473
474
475
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

476
    pub fn add_output_block(
477
478
479
480
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
481
        self.scheduler.add_output_block(request_id, decay_fraction)
482
483
    }

484
    pub fn block_size(&self) -> u32 {
485
486
        self.block_size
    }
487

488
489
490
491
492
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
493
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
494
        worker: WorkerWithDpRank,
495
        lora_name: Option<&str>,
496
    ) -> Result<u32, KvRouterError> {
497
498
499
500
501
502
503
504
505
        let block_hashes = compute_block_hash_for_seq(
            tokens,
            self.block_size,
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
        );
506
        log_routing_input_hashes(None, self.block_size, tokens, &block_hashes);
507
508
509
510
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

511
    /// Get potential prefill and decode loads for all workers
512
513
514
515
    pub async fn get_potential_loads(
        &self,
        tokens: &[u32],
        router_config_override: Option<&RouterConfigOverride>,
516
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
517
        lora_name: Option<&str>,
518
    ) -> Result<Vec<PotentialLoad>> {
519
        let isl_tokens = tokens.len();
520
521
522
523
524
525
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name,
            is_eagle: Some(self.is_eagle),
        };
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options);
526

527
528
529
530
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
531
532
            hash_options,
            Some(&block_hashes),
533
        );
534
535
536
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
537
538
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

539
540
541
542
543
544
        Ok(self.scheduler.get_potential_loads(
            maybe_seq_hashes,
            isl_tokens,
            overlap_scores,
            track_prefill_tokens,
        ))
545
546
    }

547
548
549
550
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
551
552
}

Michael Feil's avatar
Michael Feil committed
553
554
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
555
#[async_trait]
556
557
558
559
560
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
    for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
561
562
563
564
565
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
566
567
568
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
569
570
571
572
            RouterRequest::New {
                tokens,
                block_mm_infos,
            } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
573
                let (best_worker, overlap_blocks) = self
574
575
576
577
578
579
580
581
                    .find_best_match(
                        Some(&context_id),
                        &tokens,
                        block_mm_infos.as_deref(),
                        None,
                        true,
                        None,
                        0.0,
582
                        None,
583
                        None,
584
                    )
Michael Feil's avatar
Michael Feil committed
585
586
587
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
588
589
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
590
591
592
                    overlap_blocks,
                }
            }
593
594
595
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
596
597
598
599
600
601
602
603
604
            RouterRequest::MarkFree { request_id } => {
                let request_id = match request_id.as_deref() {
                    Some(request_id) if !request_id.trim().is_empty() => request_id,
                    _ => &context_id,
                };
                RouterResponse::FreeMarked {
                    success: self.free(request_id).await.is_ok(),
                }
            }
Michael Feil's avatar
Michael Feil committed
605
        };
606
607
608
609
610
611

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
612

613
614
615
616
impl<Sel> Drop for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
Yan Ru Pei's avatar
Yan Ru Pei committed
617
618
619
620
621
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}