kv_router.rs 19.3 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
Yan Ru Pei's avatar
Yan Ru Pei committed
5
use std::sync::Arc;
6
use std::time::{Duration, Instant};
7

8
use anyhow::Result;
Yan Ru Pei's avatar
Yan Ru Pei committed
9
use dynamo_kv_router::{ConcurrentRadixTree, ThreadPoolIndexer};
10
use dynamo_runtime::{
11
    component::{Client, Endpoint},
12
    discovery::DiscoveryQuery,
13
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
16
    },
17
    protocols::EndpointId,
18
    protocols::annotated::Annotated,
19
    traits::DistributedRuntimeProvider,
20
};
21
use futures::stream;
Yan Ru Pei's avatar
Yan Ru Pei committed
22
use tokio::sync::oneshot;
23
use validator::Validate;
24

25
26
27
28
29
// Re-export from dynamo-kv-router crate
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::indexer;
pub use dynamo_kv_router::protocols;

30
pub mod config;
31
pub mod metrics;
32
pub mod prefill_router;
33
pub mod publisher;
34
pub mod push_router;
35
pub mod recorder;
36
pub mod scheduler;
37
pub mod sequence;
38
pub mod subscriber;
39
pub mod worker_query;
40

41
pub use config::{KvRouterConfig, RouterConfigOverride};
42
pub use prefill_router::PrefillRouter;
43
pub use push_router::KvPushRouter;
44

45
use crate::{
46
    discovery::RuntimeConfigWatch,
47
    kv_router::{
48
        approx::PruneConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
49
        indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvRouterError},
Yan Ru Pei's avatar
Yan Ru Pei committed
50
        protocols::{
51
            DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
Yan Ru Pei's avatar
Yan Ru Pei committed
52
53
            TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
            compute_block_hash_for_seq,
Yan Ru Pei's avatar
Yan Ru Pei committed
54
        },
55
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
56
        sequence::SequenceError,
57
    },
58
    local_model::runtime_config::ModelRuntimeConfig,
59
60
};

61
62
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Endpoint name used for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject name used for push-based publishing of KV events.
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Subject name used for push-based publishing of KV metrics.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for prefill events exchanged between router replicas.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between router replicas.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name under which radix tree snapshots are stored.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File (key) name of the radix tree snapshot within [`RADIX_STATE_BUCKET`].
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Number of most recent events kept in a worker-local KV indexer buffer
/// (used by the worker-local kvindexer query).
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

82
83
84
85
86
87
/// Builds the name of the worker KV-indexer query endpoint for a single `dp_rank`.
///
/// Every dp_rank owns its own `LocalKvIndexer` and query endpoint, which keeps events
/// monotonic per dp_rank. The name has the form `worker_kv_indexer_query_dp<dp_rank>`.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    const PREFIX: &str = "worker_kv_indexer_query_dp";
    format!("{PREFIX}{dp_rank}")
}

88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
/// Component name under which the KV router registers itself for discovery.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name under which the KV router registers itself for discovery.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the [`EndpointId`] of the KV router that lives in `namespace`.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the [`DiscoveryQuery`] that matches the KV router endpoint in `namespace`.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

110
111
112
113
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
114
        workers: &HashMap<protocols::WorkerId, ModelRuntimeConfig>,
115
        request: &SchedulingRequest,
116
        block_size: u32,
117
118
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
119

Yan Ru Pei's avatar
Yan Ru Pei committed
120
#[derive(Clone)]
121
pub enum Indexer {
Yan Ru Pei's avatar
Yan Ru Pei committed
122
    /// Single-threaded radix tree with channel-based event processing.
123
    /// Supports TTL-based expiration and size-based pruning.
124
    /// Has the ability to persist and snapshot states.
125
    KvIndexer(KvIndexer),
126

Yan Ru Pei's avatar
Yan Ru Pei committed
127
128
129
130
131
    /// Concurrent radix tree with a thread pool for event processing.
    /// Uses sticky worker routing for per-worker event serialization.
    /// Does not support TTL/pruning.
    Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),

132
133
134
    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
135
136
137
}

impl Indexer {
138
139
140
141
142
143
144
    pub fn new(
        component: &dynamo_runtime::component::Component,
        kv_router_config: &KvRouterConfig,
        block_size: u32,
        cancellation_token: tokio_util::sync::CancellationToken,
    ) -> Self {
        if kv_router_config.overlap_score_weight == 0.0 {
Yan Ru Pei's avatar
Yan Ru Pei committed
145
146
147
148
149
150
151
            return Indexer::None;
        }

        if kv_router_config.router_event_threads > 1 {
            return Indexer::Concurrent(Arc::new(ThreadPoolIndexer::new(
                ConcurrentRadixTree::new(),
                kv_router_config.router_event_threads as usize,
152
                block_size,
Yan Ru Pei's avatar
Yan Ru Pei committed
153
            )));
154
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175

        let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);

        // If use_kv_events is false, enable TTL and pruning for approximate behavior
        let prune_config = if !kv_router_config.use_kv_events {
            Some(PruneConfig {
                ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                max_tree_size: kv_router_config.router_max_tree_size,
                prune_target_ratio: kv_router_config.router_prune_target_ratio,
            })
        } else {
            None
        };

        Indexer::KvIndexer(KvIndexer::new_with_frequency(
            cancellation_token,
            None, // expiration_duration for frequency tracking
            block_size,
            kv_indexer_metrics,
            prune_config,
        ))
176
177
178
    }

    pub(crate) async fn find_matches(
179
180
181
182
183
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Yan Ru Pei's avatar
Yan Ru Pei committed
184
            Indexer::Concurrent(tpi) => tpi.find_matches(sequence).await,
185
186
187
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
188
                tree_sizes: HashMap::new(),
189
            }),
190
191
        }
    }
192

193
    pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
194
195
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Yan Ru Pei's avatar
Yan Ru Pei committed
196
            Indexer::Concurrent(tpi) => tpi.dump_events().await,
197
198
199
200
201
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
202
203
        }
    }
204

205
    pub(crate) async fn process_routing_decision_for_request(
206
        &self,
207
        tokens_with_hashes: &mut TokensWithHashes,
208
209
210
211
212
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
213
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
214
215
                    .await
            }
Yan Ru Pei's avatar
Yan Ru Pei committed
216
217
218
219
            Indexer::Concurrent(tpi) => {
                tpi.process_routing_decision_for_request(tokens_with_hashes, worker)
                    .await
            }
220
221
222
            Indexer::None => Ok(()),
        }
    }
Yan Ru Pei's avatar
Yan Ru Pei committed
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264

    pub(crate) async fn apply_event(&self, event: RouterEvent) {
        match self {
            Indexer::KvIndexer(indexer) => {
                if let Err(e) = indexer.event_sender().send(event).await {
                    tracing::warn!("Failed to send event to indexer: {e}");
                }
            }
            Indexer::Concurrent(tpi) => tpi.apply_event(event).await,
            Indexer::None => {}
        }
    }

    pub(crate) async fn remove_worker(&self, worker_id: WorkerId) {
        match self {
            Indexer::KvIndexer(indexer) => {
                if let Err(e) = indexer.remove_worker_sender().send(worker_id).await {
                    tracing::warn!("Failed to send worker removal for {worker_id}: {e}");
                }
            }
            Indexer::Concurrent(tpi) => {
                KvIndexerInterface::remove_worker(tpi.as_ref(), worker_id).await;
            }
            Indexer::None => {}
        }
    }

    pub(crate) async fn get_workers(&self) -> Vec<WorkerId> {
        match self {
            Indexer::KvIndexer(indexer) => {
                let (resp_tx, resp_rx) = oneshot::channel();
                let req = GetWorkersRequest { resp: resp_tx };
                if let Err(e) = indexer.get_workers_sender().send(req).await {
                    tracing::warn!("Failed to send get_workers request: {e}");
                    return Vec::new();
                }
                resp_rx.await.unwrap_or_default()
            }
            Indexer::Concurrent(tpi) => tpi.backend().get_workers(),
            Indexer::None => Vec::new(),
        }
    }
265
266
}

267
268
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
269
pub struct KvRouter {
270
    indexer: Indexer,
271
    scheduler: KvScheduler,
272
    block_size: u32,
273
    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
274
    cancellation_token: tokio_util::sync::CancellationToken,
275
    client: Client,
276
277
278
}

impl KvRouter {
279
    #[allow(clippy::too_many_arguments)]
280
    pub async fn new(
281
282
        endpoint: Endpoint,
        client: Client,
283
        mut workers_with_configs: RuntimeConfigWatch,
284
        block_size: u32,
285
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
286
        kv_router_config: Option<KvRouterConfig>,
287
        router_id: u64,
288
        worker_type: &'static str,
289
    ) -> Result<Self> {
290
        let kv_router_config = kv_router_config.unwrap_or_default();
291
        kv_router_config.validate()?;
292
        let component = endpoint.component();
293
        let cancellation_token = component.drt().primary_token();
294

295
296
297
298
299
300
        let indexer = Indexer::new(
            component,
            &kv_router_config,
            block_size,
            cancellation_token.clone(),
        );
301

302
        // Wait for at least one worker with a known runtime config before starting scheduler
303
304
305
306
307
308
        let _ = workers_with_configs
            .wait_for(|m| !m.is_empty())
            .await
            .map_err(|_| {
                anyhow::anyhow!("runtime config watch closed before any workers appeared")
            })?;
309

310
        let scheduler = KvScheduler::start(
311
            component.clone(),
312
            block_size,
313
            workers_with_configs.clone(),
314
            selector,
315
            kv_router_config.router_replica_sync,
316
            router_id,
317
            worker_type,
318
319
        )
        .await?;
320

321
322
323
324
325
326
        // Start KV event subscription if needed (use_kv_events=true and overlap_score_weight>0)
        if kv_router_config.should_subscribe_to_kv_events() {
            subscriber::start_subscriber(
                component.clone(),
                &kv_router_config,
                router_id,
Yan Ru Pei's avatar
Yan Ru Pei committed
327
                indexer.clone(),
328
329
330
331
                cancellation_token.clone(),
            )
            .await?;
        } else {
332
            tracing::info!(
333
334
335
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
336
            );
337
        }
338

339
        tracing::info!("KV Routing initialized");
340
        Ok(Self {
341
            indexer,
342
            scheduler,
343
            block_size,
344
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
345
            cancellation_token,
346
            client,
347
        })
348
349
    }

350
351
352
353
354
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

355
356
357
358
359
360
361
362
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    pub fn kv_router_config(&self) -> &KvRouterConfig {
        &self.kv_router_config
    }

363
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
364
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
365
    /// Now also takes optional context_id for request tracking
366
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
367
    pub async fn find_best_match(
368
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
369
        context_id: Option<&str>,
370
        tokens: &[u32],
371
        router_config_override: Option<&RouterConfigOverride>,
372
        update_states: bool,
373
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
374
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
375
376
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
377
        if update_states && context_id.is_none() {
378
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
379
380
        }

381
        let isl_tokens = tokens.len();
382

383
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
384
        let hash_elapsed = start.elapsed();
385

386
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
387
        let find_matches_elapsed = start.elapsed();
388

389
390
391
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = self
            .kv_router_config
392
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
393
        let seq_hash_elapsed = start.elapsed();
394

Yan Ru Pei's avatar
Yan Ru Pei committed
395
        let best_worker = self
396
            .scheduler
397
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
398
                context_id.map(|s| s.to_string()),
399
                isl_tokens,
400
                maybe_seq_hashes,
401
                overlap_scores.clone(),
402
                router_config_override,
403
                update_states,
404
                lora_name,
405
            )
406
            .await?;
407
408
409
410
411
412
413
414
        let total_elapsed = start.elapsed();

        metrics::ROUTING_OVERHEAD_METRICS.observe(
            hash_elapsed,
            find_matches_elapsed,
            seq_hash_elapsed,
            total_elapsed,
        );
415

416
        #[cfg(feature = "bench")]
417
418
419
420
421
422
423
424
425
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
            find_matches_us = (find_matches_elapsed - hash_elapsed).as_micros() as u64,
            seq_hash_us = (seq_hash_elapsed - find_matches_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - seq_hash_elapsed).as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
426

427
428
        // Note: Routing decision recording (for approximate mode) is now handled
        // by KvPushRouter::generate after select_worker returns.
429

430
431
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
432
            .get(&best_worker)
433
434
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
435
        Ok((best_worker, overlap_amount))
436
437
    }

438
    #[allow(clippy::too_many_arguments)]
439
440
441
442
443
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
444
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
445
        worker: WorkerWithDpRank,
446
        lora_name: Option<String>,
447
448
    ) {
        let isl_tokens = tokens.len();
449

450
451
452
        let maybe_seq_hashes = self
            .kv_router_config
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
453

454
455
        if let Err(e) = self
            .scheduler
456
            .add_request(
457
                request_id.clone(),
458
                maybe_seq_hashes,
459
460
                isl_tokens,
                overlap_blocks,
461
                expected_output_tokens,
Yan Ru Pei's avatar
Yan Ru Pei committed
462
                worker,
463
                lora_name,
464
            )
465
466
467
468
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
469
470
    }

471
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
472
        self.scheduler.mark_prefill_completed(request_id).await
473
474
    }

475
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
476
        self.scheduler.free(request_id).await
477
    }
478

479
480
481
482
483
484
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

485
486
487
488
489
490
491
492
493
494
    pub async fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.scheduler
            .add_output_block(request_id, decay_fraction)
            .await
    }

495
    pub fn block_size(&self) -> u32 {
496
497
        self.block_size
    }
498

499
500
501
502
503
504
505
506
507
508
509
510
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
        worker: WorkerWithDpRank,
    ) -> Result<u32, KvRouterError> {
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

511
512
513
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
514
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
515
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
516

517
518
        let maybe_seq_hashes = self
            .kv_router_config
519
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
520

521
522
        Ok(self
            .scheduler
523
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
524
525
526
            .await)
    }

527
528
529
530
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
531
532
}

Michael Feil's avatar
Michael Feil committed
533
534
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
535
536
537
538
539
540
541
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
542
543
544
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
545
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
546
                let (best_worker, overlap_blocks) = self
547
                    .find_best_match(Some(&context_id), &tokens, None, true, None)
Michael Feil's avatar
Michael Feil committed
548
549
550
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
551
552
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
553
554
555
                    overlap_blocks,
                }
            }
556
557
558
559
560
561
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
562
        };
563
564
565
566
567
568

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
569

Yan Ru Pei's avatar
Yan Ru Pei committed
570
571
572
573
574
575
impl Drop for KvRouter {
    /// Cancels `self.cancellation_token`. `KvRouter::new` also hands this token to the
    /// indexer and the KV-event subscriber, so they observe the cancellation too.
    fn drop(&mut self) {
        let token = &self.cancellation_token;
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        token.cancel();
    }
}