"examples/multimodal_v1/connect/__init__.py" did not exist on "e0a51940d105175d1105114b814933c7fc5dbd48"
kv_router.rs 17.2 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::time::{Duration, Instant};
6

7
use anyhow::Result;
8
use dynamo_runtime::{
9
    component::{Client, Endpoint},
10
    discovery::DiscoveryQuery,
11
    pipeline::{
12
13
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait,
14
    },
15
    protocols::EndpointId,
16
    protocols::annotated::Annotated,
17
    traits::DistributedRuntimeProvider,
18
};
19
use futures::stream;
20
use validator::Validate;
21

22
23
24
25
26
// Re-export from dynamo-kv-router crate
pub use dynamo_kv_router::approx;
pub use dynamo_kv_router::indexer;
pub use dynamo_kv_router::protocols;

27
pub mod config;
28
pub mod metrics;
29
pub mod prefill_router;
30
pub mod publisher;
31
pub mod push_router;
32
pub mod recorder;
33
pub mod scheduler;
34
pub mod sequence;
35
pub mod subscriber;
36
pub mod worker_query;
37

38
pub use config::{KvRouterConfig, RouterConfigOverride};
39
pub use prefill_router::PrefillRouter;
40
pub use push_router::KvPushRouter;
41

42
use crate::{
43
    discovery::RuntimeConfigWatch,
44
    kv_router::{
45
        approx::PruneConfig,
46
        indexer::{KvIndexer, KvIndexerInterface, KvRouterError},
Yan Ru Pei's avatar
Yan Ru Pei committed
47
        protocols::{
48
            DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
49
            TokensWithHashes, WorkerSelectionResult, WorkerWithDpRank, compute_block_hash_for_seq,
Yan Ru Pei's avatar
Yan Ru Pei committed
50
        },
51
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
52
        sequence::SequenceError,
53
    },
54
    local_model::runtime_config::ModelRuntimeConfig,
55
56
};

57
58
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

// for metric scraping (pull-based)
/// Endpoint name used for pull-based metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

// for metric publishing (push-based)
/// Subject name for push-based KV event publishing.
pub const KV_EVENT_SUBJECT: &str = "kv-events";
/// Subject name for push-based KV metric publishing.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// for inter-router comms
/// Subject name for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject name for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

// for radix tree snapshot storage
/// Bucket name for radix tree snapshot storage.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name for radix tree snapshot storage.
pub const RADIX_STATE_FILE: &str = "radix-state";

// for worker-local kvindexer query
/// Capacity of the worker-local event buffer: the 1024 most recent events are kept.
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024;

78
79
80
81
82
83
/// Builds the name of the worker KV indexer query endpoint for one `dp_rank`.
///
/// Every dp_rank has its own LocalKvIndexer and query endpoint, which ensures
/// per-dp_rank monotonicity.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
    let mut endpoint = String::from("worker_kv_indexer_query_dp");
    endpoint.push_str(&dp_rank.to_string());
    endpoint
}

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// for router discovery registration
/// Component name placed in the KV router's `EndpointId` and `DiscoveryQuery`.
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name placed in the KV router's `EndpointId` and `DiscoveryQuery`.
pub const KV_ROUTER_ENDPOINT: &str = "generate";

/// Builds the `EndpointId` that identifies the KV router inside `namespace`.
///
/// The component and endpoint names are always `KV_ROUTER_COMPONENT` and
/// `KV_ROUTER_ENDPOINT`; only the namespace varies.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}

/// Builds the `DiscoveryQuery::Endpoint` that matches the KV router inside `namespace`.
///
/// The component and endpoint names are always `KV_ROUTER_COMPONENT` and
/// `KV_ROUTER_ENDPOINT`; only the namespace varies.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}

106
107
108
109
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
110
        workers: &HashMap<protocols::WorkerId, ModelRuntimeConfig>,
111
        request: &SchedulingRequest,
112
        block_size: u32,
113
114
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
115

116
pub enum Indexer {
117
118
    /// Updates itself based on KV events emitted by backend workers or routing decisions.
    /// Supports TTL-based expiration and size-based pruning.
119
    /// Has the ability to persist and snapshot states.
120
    KvIndexer(KvIndexer),
121
122
123
124

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
125
126
127
}

impl Indexer {
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
    pub fn new(
        component: &dynamo_runtime::component::Component,
        kv_router_config: &KvRouterConfig,
        block_size: u32,
        cancellation_token: tokio_util::sync::CancellationToken,
    ) -> Self {
        if kv_router_config.overlap_score_weight == 0.0 {
            // When overlap_score_weight is zero, we don't need to track prefixes
            Indexer::None
        } else {
            let kv_indexer_metrics = indexer::KvIndexerMetrics::from_component(component);

            // If use_kv_events is false, enable TTL and pruning for approximate behavior
            let prune_config = if !kv_router_config.use_kv_events {
                Some(PruneConfig {
                    ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
                    max_tree_size: kv_router_config.router_max_tree_size,
                    prune_target_ratio: kv_router_config.router_prune_target_ratio,
                })
            } else {
                None
            };

            Indexer::KvIndexer(KvIndexer::new_with_frequency(
                cancellation_token,
                None, // expiration_duration for frequency tracking
                block_size,
                kv_indexer_metrics,
                prune_config,
            ))
        }
    }

    pub(crate) async fn find_matches(
162
163
164
165
166
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
167
168
169
            Indexer::None => Ok(OverlapScores {
                scores: HashMap::new(),
                frequencies: Vec::new(),
170
                tree_sizes: HashMap::new(),
171
            }),
172
173
        }
    }
174

175
    /// Dumps all events held by the wrapped [`KvIndexer`].
    ///
    /// # Panics
    ///
    /// Panics when called on [`Indexer::None`], because there is no indexer to
    /// dump from. Callers must only invoke this when an indexer was configured
    /// (i.e. `overlap_score_weight` is non-zero).
    pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
            Indexer::None => {
                panic!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
            }
        }
    }
185

186
    pub(crate) async fn process_routing_decision_for_request(
187
        &self,
188
        tokens_with_hashes: &mut TokensWithHashes,
189
190
191
192
193
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
194
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
195
196
197
198
199
                    .await
            }
            Indexer::None => Ok(()),
        }
    }
200
201
}

202
203
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
204
pub struct KvRouter {
205
    indexer: Indexer,
206
    scheduler: KvScheduler,
207
    block_size: u32,
208
    kv_router_config: KvRouterConfig,
Yan Ru Pei's avatar
Yan Ru Pei committed
209
    cancellation_token: tokio_util::sync::CancellationToken,
210
    client: Client,
211
212
213
}

impl KvRouter {
214
    #[allow(clippy::too_many_arguments)]
215
    pub async fn new(
216
217
        endpoint: Endpoint,
        client: Client,
218
        mut workers_with_configs: RuntimeConfigWatch,
219
        block_size: u32,
220
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
221
        kv_router_config: Option<KvRouterConfig>,
222
        router_id: u64,
223
        worker_type: &'static str,
224
    ) -> Result<Self> {
225
        let kv_router_config = kv_router_config.unwrap_or_default();
226
        kv_router_config.validate()?;
227
        let component = endpoint.component();
228
        let cancellation_token = component.drt().primary_token();
229

230
231
232
233
234
235
        let indexer = Indexer::new(
            component,
            &kv_router_config,
            block_size,
            cancellation_token.clone(),
        );
236

237
        // Wait for at least one worker with a known runtime config before starting scheduler
238
239
240
241
242
243
        let _ = workers_with_configs
            .wait_for(|m| !m.is_empty())
            .await
            .map_err(|_| {
                anyhow::anyhow!("runtime config watch closed before any workers appeared")
            })?;
244

245
        let scheduler = KvScheduler::start(
246
            component.clone(),
247
            block_size,
248
            workers_with_configs.clone(),
249
            selector,
250
            kv_router_config.router_replica_sync,
251
            router_id,
252
            worker_type,
253
254
        )
        .await?;
255

256
257
258
259
260
261
262
263
        // Start KV event subscription if needed (use_kv_events=true and overlap_score_weight>0)
        if kv_router_config.should_subscribe_to_kv_events() {
            // Guaranteed to be KvIndexer since overlap_score_weight > 0.0
            let Indexer::KvIndexer(kv_indexer) = &indexer else {
                unreachable!(
                    "should_subscribe_to_kv_events implies overlap_score_weight > 0 implies KvIndexer"
                )
            };
264

265
266
267
268
269
270
271
272
273
            subscriber::start_subscriber(
                component.clone(),
                &kv_router_config,
                router_id,
                kv_indexer,
                cancellation_token.clone(),
            )
            .await?;
        } else {
274
            tracing::info!(
275
276
277
                "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})",
                kv_router_config.use_kv_events,
                kv_router_config.overlap_score_weight,
278
            );
279
        }
280

281
        tracing::info!("KV Routing initialized");
282
        Ok(Self {
283
            indexer,
284
            scheduler,
285
            block_size,
286
            kv_router_config,
Yan Ru Pei's avatar
Yan Ru Pei committed
287
            cancellation_token,
288
            client,
289
        })
290
291
    }

292
293
294
295
296
    /// Returns a reference to the [`Client`] this router was constructed with.
    pub fn client(&self) -> &Client {
        &self.client
    }

297
298
299
300
301
302
303
304
    /// Returns a reference to the router's [`Indexer`], which may be [`Indexer::None`].
    pub fn indexer(&self) -> &Indexer {
        &self.indexer
    }

    /// Returns the configuration this router was built with (the default one
    /// if none was supplied to `new`).
    pub fn kv_router_config(&self) -> &KvRouterConfig {
        &self.kv_router_config
    }

305
    /// Give these tokens, find the worker with the best match in it's KV cache.
Yan Ru Pei's avatar
Yan Ru Pei committed
306
    /// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
Yan Ru Pei's avatar
Yan Ru Pei committed
307
    /// Now also takes optional context_id for request tracking
308
    #[allow(clippy::too_many_arguments)]
Yan Ru Pei's avatar
Yan Ru Pei committed
309
    pub async fn find_best_match(
310
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
311
        context_id: Option<&str>,
312
        tokens: &[u32],
313
        router_config_override: Option<&RouterConfigOverride>,
314
        update_states: bool,
315
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
316
    ) -> anyhow::Result<(WorkerWithDpRank, u32)> {
317
318
        let start = Instant::now();

Yan Ru Pei's avatar
Yan Ru Pei committed
319
        if update_states && context_id.is_none() {
320
            anyhow::bail!("context_id must be provided when update_states is true");
Yan Ru Pei's avatar
Yan Ru Pei committed
321
322
        }

323
        let isl_tokens = tokens.len();
324

325
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
326
        let hash_elapsed = start.elapsed();
327

328
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
329
        let find_matches_elapsed = start.elapsed();
330

331
332
333
        // Compute seq_hashes only if scheduler needs it for active blocks tracking
        let maybe_seq_hashes = self
            .kv_router_config
334
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
335
        let seq_hash_elapsed = start.elapsed();
336

Yan Ru Pei's avatar
Yan Ru Pei committed
337
        let best_worker = self
338
            .scheduler
339
            .schedule(
Yan Ru Pei's avatar
Yan Ru Pei committed
340
                context_id.map(|s| s.to_string()),
341
                isl_tokens,
342
                maybe_seq_hashes,
343
                overlap_scores.clone(),
344
                router_config_override,
345
                update_states,
346
                lora_name,
347
            )
348
            .await?;
349
350
351
352
353
354
355
356
        let total_elapsed = start.elapsed();

        metrics::ROUTING_OVERHEAD_METRICS.observe(
            hash_elapsed,
            find_matches_elapsed,
            seq_hash_elapsed,
            total_elapsed,
        );
357

358
        #[cfg(feature = "bench")]
359
360
361
362
363
364
365
366
367
        tracing::info!(
            isl_tokens,
            hash_us = hash_elapsed.as_micros() as u64,
            find_matches_us = (find_matches_elapsed - hash_elapsed).as_micros() as u64,
            seq_hash_us = (seq_hash_elapsed - find_matches_elapsed).as_micros() as u64,
            schedule_us = (total_elapsed - seq_hash_elapsed).as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "find_best_match completed"
        );
368

369
370
        // Note: Routing decision recording (for approximate mode) is now handled
        // by KvPushRouter::generate after select_worker returns.
371

372
373
        let overlap_amount = overlap_scores
            .scores
Yan Ru Pei's avatar
Yan Ru Pei committed
374
            .get(&best_worker)
375
376
            .copied()
            .unwrap_or(0);
Yan Ru Pei's avatar
Yan Ru Pei committed
377
        Ok((best_worker, overlap_amount))
378
379
    }

380
    #[allow(clippy::too_many_arguments)]
381
382
383
384
385
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
386
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
387
        worker: WorkerWithDpRank,
388
        lora_name: Option<String>,
389
390
    ) {
        let isl_tokens = tokens.len();
391

392
393
394
        let maybe_seq_hashes = self
            .kv_router_config
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
395

396
397
        if let Err(e) = self
            .scheduler
398
            .add_request(
399
                request_id.clone(),
400
                maybe_seq_hashes,
401
402
                isl_tokens,
                overlap_blocks,
403
                expected_output_tokens,
Yan Ru Pei's avatar
Yan Ru Pei committed
404
                worker,
405
                lora_name,
406
            )
407
408
409
410
            .await
        {
            tracing::warn!("Failed to add request {request_id}: {e}");
        }
411
412
    }

413
    /// Tells the scheduler that prefill has completed for `request_id`.
    ///
    /// Any [`SequenceError`] from the scheduler is returned unchanged.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.mark_prefill_completed(request_id).await
    }

417
    /// Asks the scheduler to free the state it holds for `request_id`.
    ///
    /// Any [`SequenceError`] from the scheduler is returned unchanged.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.scheduler.free(request_id).await
    }
420

421
422
423
424
425
426
    /// Get the worker type for this router ("prefill" or "decode"), as reported
    /// by the scheduler.
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.scheduler.worker_type()
    }

427
428
429
430
431
432
433
434
435
436
    /// Forwards an output-block notification for `request_id` to the scheduler.
    ///
    /// `decay_fraction` is passed through unchanged.
    /// NOTE(review): its meaning is defined by the scheduler and is not visible here.
    /// Any [`SequenceError`] from the scheduler is returned unchanged.
    pub async fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.scheduler
            .add_output_block(request_id, decay_fraction)
            .await
    }

437
    pub fn block_size(&self) -> u32 {
438
439
        self.block_size
    }
440

441
442
443
444
445
446
447
448
449
450
451
452
    /// Reports how many blocks of `tokens` the indexer attributes to `worker`.
    ///
    /// The tokens are hashed into blocks, the indexer is queried, and the score
    /// recorded for `worker` is returned, or 0 when the worker has no entry.
    pub async fn get_overlap_blocks(
        &self,
        tokens: &[u32],
        worker: WorkerWithDpRank,
    ) -> Result<u32, KvRouterError> {
        let hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
        let matches = self.indexer.find_matches(hashes).await?;
        let cached_blocks = matches.scores.get(&worker).map_or(0, |blocks| *blocks);
        Ok(cached_blocks)
    }

453
454
455
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
456
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
457
        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
458

459
460
        let maybe_seq_hashes = self
            .kv_router_config
461
            .compute_seq_hashes_for_tracking(tokens, self.block_size);
462

463
464
        Ok(self
            .scheduler
465
            .get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)
466
467
468
            .await)
    }

469
470
471
472
    /// Dump all events from the indexer
    ///
    /// # Panics
    ///
    /// Panics if the router has no indexer (`Indexer::None`), because
    /// `Indexer::dump_events` panics in that case.
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
473
474
}

Michael Feil's avatar
Michael Feil committed
475
476
// NOTE: KVRouter works like a PushRouter,
// but without the reverse proxy functionality, but based on contract of 3 request types
477
478
479
480
481
482
483
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
Michael Feil's avatar
Michael Feil committed
484
485
486
        let context_id = ctx.context().id().to_string();
        // Handle different request types
        let response = match request {
487
            RouterRequest::New { tokens } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
488
                let (best_worker, overlap_blocks) = self
489
                    .find_best_match(Some(&context_id), &tokens, None, true, None)
Michael Feil's avatar
Michael Feil committed
490
491
492
                    .await?;

                RouterResponse::New {
Yan Ru Pei's avatar
Yan Ru Pei committed
493
494
                    worker_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
Michael Feil's avatar
Michael Feil committed
495
496
497
                    overlap_blocks,
                }
            }
498
499
500
501
502
503
            RouterRequest::MarkPrefill => RouterResponse::PrefillMarked {
                success: self.mark_prefill_completed(&context_id).await.is_ok(),
            },
            RouterRequest::MarkFree => RouterResponse::FreeMarked {
                success: self.free(&context_id).await.is_ok(),
            },
Michael Feil's avatar
Michael Feil committed
504
        };
505
506
507
508
509
510

        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
511

Yan Ru Pei's avatar
Yan Ru Pei committed
512
513
514
515
516
517
// Cancels the router's cancellation token when the router is dropped, so that
// tasks holding a clone of that token observe the cancellation.
impl Drop for KvRouter {
    fn drop(&mut self) {
        tracing::info!("Dropping KvRouter - cancelling background tasks");
        // NOTE(review): `KvRouter::new` obtains this token from
        // `component.drt().primary_token()`. If that token is shared beyond
        // this router, cancelling it here stops more than the router's own
        // background tasks; confirm the intended scope.
        self.cancellation_token.cancel();
    }
}
}