// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::time::Instant;
5

6
use anyhow::Result;
7
use dynamo_kv_router::{
8
    config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env},
9
    indexer::KvRouterError,
10
11
    protocols::KV_EVENT_SUBJECT,
    protocols::{
12
13
        BlockExtraInfo, BlockHashOptions, DpRank, RouterEvent, RouterRequest, RouterResponse,
        TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
14
15
    },
};
16
use dynamo_runtime::{
17
    component::{Client, Endpoint},
18
    discovery::DiscoveryQuery,
19
    pipeline::{
20
21
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
22
    },
23
    protocols::EndpointId,
24
    protocols::annotated::Annotated,
25
    traits::DistributedRuntimeProvider,
26
};
27
use futures::stream;
28
use tracing::Instrument;
29
use validator::Validate;
30

31
pub mod indexer;
32
mod jetstream;
33
pub mod metrics;
34
pub mod prefill_router;
35
pub mod publisher;
36
pub mod push_router;
37
pub mod scheduler;
38
pub mod sequence;
39
pub mod subscriber;
40
pub mod worker_query;
41

42
pub use indexer::Indexer;
43
pub use prefill_router::PrefillRouter;
44
pub use push_router::{DirectRoutingRouter, KvPushRouter};
45

46
use crate::{
47
    discovery::RuntimeConfigWatch,
48
    kv_router::{
49
        scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
50
        sequence::{SequenceError, SequenceRequest},
51
    },
52
    local_model::runtime_config::ModelRuntimeConfig,
53
54
};

55
56
use std::collections::HashSet;

// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject name used for metric publishing (push-based).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for prefill events exchanged between routers (inter-router comms).
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events exchanged between routers (inter-router comms).
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

/// Bucket name for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

/// Number of most recent events stored in the worker buffer for
/// worker-local kvindexer queries.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

/// Builds the endpoint name of the worker KV indexer query service for one dp_rank.
///
/// Every dp_rank owns its own LocalKvIndexer and query endpoint so that monotonicity is
/// preserved per dp_rank; the rank is appended to the fixed `worker_kv_indexer_query_dp`
/// prefix using its `Display` form.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    let rank_suffix = dp_rank.to_string();
    ["worker_kv_indexer_query_dp", rank_suffix.as_str()].concat()
}

/// Endpoint name used for KV router discovery registration.
pub const KV_ROUTER_ENDPOINT: &str = "router-discovery";

/// Creates an EndpointId for the KV router in the given namespace.
87
pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId {
88
89
    EndpointId {
        namespace,
90
        component,
91
92
93
94
95
        name: KV_ROUTER_ENDPOINT.to_string(),
    }
}

/// Creates a DiscoveryQuery for the KV router in the given namespace.
96
pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery {
97
98
    DiscoveryQuery::Endpoint {
        namespace,
99
        component,
100
101
102
103
        endpoint: KV_ROUTER_ENDPOINT.to_string(),
    }
}

104
105
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
106
107
108
109
pub struct KvRouter<Sel = DefaultWorkerSelector>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
110
    indexer: Indexer,
111
    scheduler: KvScheduler<Sel>,
112
    block_size: u32,
113
    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
114
    cancellation_token: tokio_util::sync::CancellationToken,
115
    client: Client,
116
    is_eagle: bool,
117
118
}

119
120
121
122
impl<Sel> KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
123
    #[allow(clippy::too_many_arguments)]
124
    pub async fn new(
125
126
        endpoint: Endpoint,
        client: Client,
127
        workers_with_configs: RuntimeConfigWatch,
128
        block_size: u32,
129
        selector: Sel,
130
        kv_router_config: Option<KvRouterConfig>,
131
        worker_type: &'static str,
132
        model_name: Option<String>,
133
        is_eagle: bool,
134
    ) -> Result<Self> {
135
        let kv_router_config = kv_router_config.unwrap_or_default();
136
        kv_router_config.validate()?;
137
        let component = endpoint.component();
138
        let cancellation_token = component.drt().primary_token();
139
        let min_initial_workers = min_initial_workers_from_env()?;
140

141
        let indexer = Indexer::new(component, &kv_router_config, block_size, model_name).await?;
142

143
144
145
146
        if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait {
            let mut startup_watch = workers_with_configs.clone();
            let _ = startup_watch
                .wait_for(|m| m.len() >= min_initial_workers)
147
148
                .await
                .map_err(|_| {
149
150
                    anyhow::anyhow!(
                        "runtime config watch closed before {} workers appeared",
151
                        min_initial_workers
152
                    )
153
154
                })?;
        }
155

156
        let scheduler = KvScheduler::start(
157
            component.clone(),
158
            block_size,
159
            workers_with_configs.clone(),
160
            selector,
161
            &kv_router_config,
162
            worker_type,
163
164
        )
        .await?;
165

166
167
168
169
170
        // Start KV event subscription if needed — skip when using a remote indexer
        // (the standalone indexer handles its own event subscription).
        if kv_router_config.remote_indexer_component.is_some() {
            tracing::info!("Skipping KV event subscription (using remote indexer)");
        } else if kv_router_config.should_subscribe_to_kv_events() {
171
172
            subscriber::start_subscriber(component.clone(), &kv_router_config, indexer.clone())
                .await?;
173
        } else {
174
            tracing::info!(
175
176
177
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
178
            );
179
        }
180

181
        tracing::info!("KV Routing initialized");
182
        Ok(Self {
183
            indexer,
184
            scheduler,
185
            block_size,
186
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
187
            cancellation_token,
188
            client,
189
            is_eagle,
190
        })
191
192
    }

193
194
195
196
197
    /// Get a reference to the client used by this KvRouter
    pub fn client(&self) -> &Client {
        &self.client
    }

198
199
200
201
202
203
204
205
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    pub fn kv_router_config(&self) -> &KvRouterConfig {
        &self.kv_router_config
    }

206
207
208
209
    pub fn is_eagle(&self) -> bool {
        self.is_eagle
    }

210
211
    pub async fn record_routing_decision(
        &self,
212
        mut tokens_with_hashes: TokensWithHashes,
213
214
215
216
217
218
219
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        self.indexer
            .process_routing_decision_for_request(&mut tokens_with_hashes, worker)
            .await
    }

220
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
221
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
222
223
224
    /// Now also takes optional context_id for request tracking.
    ///
    /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
225
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
226
    pub async fn find_best_match(
227
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
228
        context_id: Option<&str>,
229
        tokens: &[u32],
230
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
231
        router_config_override: Option<&RouterConfigOverride>,
232
        update_states: bool,
233
        lora_name: Option<String>,
234
        priority_jump: f64,
235
        expected_output_tokens: Option<u32>,
236
        allowed_worker_ids: Option<HashSet<WorkerId>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
237
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
238
239
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
240
        if update_states && context_id.is_none() {
241
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
242
243
        }

244
        let isl_tokens = tokens.len();
245
246
247
248
249
250
251
252
253
254
255
256
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };

        let block_hashes = tracing::info_span!("kv_router.compute_block_hashes")
            .in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options));
        let hash_elapsed = start.elapsed();
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
            self.kv_router_config.compute_seq_hashes_for_tracking(
257
258
                tokens,
                self.block_size,
259
260
261
                router_config_override,
                hash_options,
                Some(&block_hashes),
262
263
            )
        });
264
        let seq_hash_elapsed = start.elapsed();
265

266
        let overlap_scores = self
267
268
269
270
            .indexer
            .find_matches(block_hashes)
            .instrument(tracing::info_span!("kv_router.find_matches"))
            .await?;
271
        let find_matches_elapsed = start.elapsed();
272

273
        let response = self
274
            .scheduler
275
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
276
                context_id.map(|s| s.to_string()),
277
                isl_tokens,
278
                maybe_seq_hashes,
279
                overlap_scores,
280
                router_config_override,
281
                update_states,
282
                lora_name,
283
                priority_jump,
284
                expected_output_tokens,
285
                allowed_worker_ids,
286
            )
287
            .instrument(tracing::info_span!("kv_router.schedule"))
288
            .await?;
289
290
        let total_elapsed = start.elapsed();

291
292
293
294
        if let Some(m) = metrics::RoutingOverheadMetrics::get() {
            m.observe(
                hash_elapsed,
                seq_hash_elapsed,
295
                find_matches_elapsed,
296
297
298
                total_elapsed,
            );
        }
299

300
        #[cfg(feature = "bench")]
301
302
303
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
304
305
306
            seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64,
            find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64,
307
308
309
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
310

311
        Ok((response.best_worker, response.overlap_blocks))
312
313
    }

314
315
316
317
318
    /// Register externally-provided workers in the slot tracker.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        self.scheduler.register_workers(worker_ids);
    }

319
    #[allow(clippy::too_many_arguments)]
320
321
322
323
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
324
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
325
        overlap_blocks: u32,
326
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
327
        worker: WorkerWithDpRank,
328
        lora_name: Option<String>,
329
        router_config_override: Option<&RouterConfigOverride>,
330
331
    ) {
        let isl_tokens = tokens.len();
332
333
334
335
336
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name: lora_name.as_deref(),
            is_eagle: Some(self.is_eagle),
        };
337

338
339
340
341
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
342
343
            hash_options,
            None,
344
        );
345
346
347
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
348

349
350
        if let Err(e) = self
            .scheduler
351
352
353
354
355
            .add_request(SequenceRequest {
                request_id: request_id.clone(),
                token_sequence: maybe_seq_hashes,
                isl: isl_tokens,
                overlap: overlap_blocks,
356
                track_prefill_tokens,
357
                expected_output_tokens,
Yan Ru Pei's avatar
Yan Ru Pei committed
358
                worker,
359
                lora_name,
360
            })
361
362
363
364
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
365
366
    }

367
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
368
        self.scheduler.mark_prefill_completed(request_id).await
369
370
    }

371
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
372
        self.scheduler.free(request_id).await
373
    }
374

375
376
377
378
379
    /// Number of requests currently parked in the scheduler queue.
    pub fn pending_count(&self) -> usize {
        self.scheduler.pending_count()
    }

380
381
382
383
384
385
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

386
    pub fn add_output_block(
387
388
389
390
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
391
        self.scheduler.add_output_block(request_id, decay_fraction)
392
393
    }

394
    pub fn block_size(&self) -> u32 {
395
396
        self.block_size
    }
397

398
399
400
401
402
    /// Compute the overlap blocks for a given token sequence and worker.
    /// This queries the indexer to find how many blocks are already cached.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
403
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
404
        worker: WorkerWithDpRank,
405
        lora_name: Option<&str>,
406
    ) -> Result<u32, KvRouterError> {
407
408
409
410
411
412
413
414
415
        let block_hashes = compute_block_hash_for_seq(
            tokens,
            self.block_size,
            BlockHashOptions {
                block_mm_infos,
                lora_name,
                is_eagle: Some(self.is_eagle),
            },
        );
416
417
418
419
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0))
    }

420
    /// Get potential prefill and decode loads for all workers
421
422
423
424
    pub async fn get_potential_loads(
        &self,
        tokens: &[u32],
        router_config_override: Option<&RouterConfigOverride>,
425
        block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
426
        lora_name: Option<&str>,
427
    ) -> Result<Vec<PotentialLoad>> {
428
        let isl_tokens = tokens.len();
429
430
431
432
433
434
        let hash_options = BlockHashOptions {
            block_mm_infos,
            lora_name,
            is_eagle: Some(self.is_eagle),
        };
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options);
435

436
437
438
439
        let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking(
            tokens,
            self.block_size,
            router_config_override,
440
441
            hash_options,
            Some(&block_hashes),
442
        );
443
444
445
        let track_prefill_tokens = self
            .kv_router_config
            .track_prefill_tokens(router_config_override);
446
447
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

448
449
450
451
452
453
        Ok(self.scheduler.get_potential_loads(
            maybe_seq_hashes,
            isl_tokens,
            overlap_scores,
            track_prefill_tokens,
        ))
454
455
    }

456
457
458
459
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
460
461
}

Michael Feil's avatar
Michael Feil committed
462
463
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
464
#[async_trait]
465
466
467
468
469
impl<Sel> AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error>
    for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig> + Send + Sync + 'static,
{
470
471
472
473
474
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
475
476
477
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
478
479
480
481
            RouterRequest::New {
                tokens,
                block_mm_infos,
            } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
482
                let (best_worker, overlap_blocks) = self
483
484
485
486
487
488
489
490
                    .find_best_match(
                        Some(&context_id),
                        &tokens,
                        block_mm_infos.as_deref(),
                        None,
                        true,
                        None,
                        0.0,
491
                        None,
492
                        None,
493
                    )
Michael Feil's avatar
Michael Feil committed
494
495
496
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
497
498
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
499
500
501
                    overlap_blocks,
                }
            }
502
503
504
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
505
506
507
508
509
510
511
512
513
            RouterRequest::MarkFree { request_id } => {
                let request_id = match request_id.as_deref() {
                    Some(request_id) if !request_id.trim().is_empty() => request_id,
                    _ => &context_id,
                };
                RouterResponse::FreeMarked {
                    success: self.free(request_id).await.is_ok(),
                }
            }
Michael Feil's avatar
Michael Feil committed
514
        };
515
516
517
518
519
520

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
521

522
523
524
525
impl<Sel> Drop for KvRouter<Sel>
where
    Sel: dynamo_kv_router::selector::WorkerSelector<ModelRuntimeConfig>,
{
Yan Ru Pei's avatar
Yan Ru Pei committed
526
527
528
529
530
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        self.cancellation_token.cancel();
    }
}