kv_router.rs 18.8 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Component, InstanceSource},
12
    pipeline::{
13
14
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
15
16
17
    },
    prelude::*,
    protocols::annotated::Annotated,
18
    utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction},
19
20
};
use futures::stream::{self, StreamExt};
21
use serde::{Deserialize, Serialize};
22

23
pub mod approx;
24
pub mod indexer;
25
pub mod metrics_aggregator;
26
pub mod prefill_counter;
27
28
pub mod protocols;
pub mod publisher;
29
pub mod recorder;
30
31
pub mod scheduler;
pub mod scoring;
32
pub mod sequence;
33
pub mod subscriber;
34

35
use crate::{
36
    discovery::{MODEL_ROOT_PATH, ModelEntry},
37
    kv_router::{
38
39
        approx::ApproxKvIndexer,
        indexer::{
40
41
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
42
        },
43
        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
44
        scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
45
        scoring::ProcessedEndpoints,
46
        subscriber::start_kv_router_background,
47
    },
48
    local_model::runtime_config::ModelRuntimeConfig,
49
    preprocessor::PreprocessedRequest,
50
    protocols::common::llm_backend::LLMEngineOutput,
51
52
};

53
54
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

// ---- Metric scraping (pull-based) ----

/// Endpoint name used for pull-based load-metric scraping.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

// ---- Metric publishing (push-based) ----

/// Subject name for pushed KV events.
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject name for pushed KV hit-rate updates.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject name for pushed KV metrics.
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

// ---- Inter-router communication ----

/// Subject name for prefill events exchanged between routers.
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject name for active-sequence events exchanged between routers.
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";

// ---- Radix tree snapshot storage ----

/// Bucket name holding radix tree snapshots.
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
/// File name of the radix tree snapshot within the bucket.
pub const RADIX_STATE_FILE: &str = "radix-state";
/// Lock name guarding snapshot creation.
pub const ROUTER_SNAPSHOT_LOCK: &str = "router-snapshot-lock";
/// Lock name guarding cleanup.
pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";

74
75
76
77
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
78
        workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
79
        request: &SchedulingRequest,
80
        block_size: u32,
81
82
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
83

84
85
86
87
88
89
90
91
92
93
/// Override configuration for router settings that can be specified per-request.
///
/// A field left as `None` requests no override for that setting.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request replacement for [`KvRouterConfig::overlap_score_weight`].
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request replacement for [`KvRouterConfig::router_temperature`].
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

94
/// KV Router configuration parameters
95
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
96
97
98
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

99
    pub router_temperature: f64,
100

101
102
    pub use_kv_events: bool,

103
104
    pub router_replica_sync: bool,

105
106
    // TODO: this is not actually used for now
    // Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
107
    pub max_num_batched_tokens: u32,
108
109
110
111

    /// Threshold for triggering snapshots. If None, no snapshots will be performed.
    pub router_snapshot_threshold: Option<u32>,

112
    /// Whether to reset the router state on startup (default: false)
113
    pub router_reset_states: bool,
114
115
116
117
118
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
119
            overlap_score_weight: 1.0,
120
            router_temperature: 0.0,
121
            use_kv_events: true,
122
            router_replica_sync: false,
123
            max_num_batched_tokens: 8192,
124
            router_snapshot_threshold: Some(10000),
125
            router_reset_states: false,
126
127
128
129
130
131
132
133
134
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
    pub fn new(
        overlap_score_weight: Option<f64>,
135
        temperature: Option<f64>,
136
        use_kv_events: Option<bool>,
137
        replica_sync: Option<bool>,
138
        max_num_batched_tokens: Option<u32>,
139
140
        router_snapshot_threshold: Option<Option<u32>>,
        router_reset_states: Option<bool>,
141
142
143
144
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
145
            router_temperature: temperature.unwrap_or(default.router_temperature),
146
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
147
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
148
149
            max_num_batched_tokens: max_num_batched_tokens
                .unwrap_or(default.max_num_batched_tokens),
150
151
152
            router_snapshot_threshold: router_snapshot_threshold
                .unwrap_or(default.router_snapshot_threshold),
            router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
153
154
155
156
        }
    }
}

157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
    KvIndexer(KvIndexer),
    ApproxKvIndexer(ApproxKvIndexer),
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
            Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
        }
    }
174
175
176
177
178
179
180

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
            Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
        }
    }
181
182
}

183
184
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
185
pub struct KvRouter {
186
187
188
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
189
    scheduler: KvScheduler,
190

191
    block_size: u32,
192
193
194
195
}

impl KvRouter {
    pub async fn new(
196
        component: Component,
197
        block_size: u32,
198
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
199
        kv_router_config: Option<KvRouterConfig>,
200
        consumer_uuid: String,
201
    ) -> Result<Self> {
202
203
        let kv_router_config = kv_router_config.unwrap_or_default();

204
205
206
207
208
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();
209
210
211
212
213
214
215
216
217
218

        let generate_endpoint = component.endpoint("generate");
        let client = generate_endpoint.client().await?;

        let instances_rx = match client.instance_source.as_ref() {
            InstanceSource::Dynamic(rx) => rx.clone(),
            InstanceSource::Static => {
                panic!("Expected dynamic instance source for KV routing");
            }
        };
219

220
        // Create runtime config watcher using the generic etcd watcher
221
222
223
224
225
        // TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality
        let etcd_client = component
            .drt()
            .etcd_client()
            .expect("Cannot KV route without etcd client");
226
227
228
229
230
231
232
233
234
235

        let runtime_configs_watcher = watch_prefix_with_extraction(
            etcd_client,
            MODEL_ROOT_PATH,
            key_extractors::lease_id,
            |model_entry: ModelEntry| model_entry.runtime_config,
            cancellation_token.clone(),
        )
        .await?;
        let runtime_configs_rx = runtime_configs_watcher.receiver();
236

237
        let indexer = if kv_router_config.use_kv_events {
238
239
240
241
242
243
244
245
246
            Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
        } else {
            // hard code 120 seconds for now
            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
                cancellation_token.clone(),
                block_size,
                Duration::from_secs(120),
            ))
        };
247

248
        let scheduler = KvScheduler::start(
249
            component.clone(),
250
            block_size,
251
            instances_rx,
252
            runtime_configs_rx,
253
            selector,
254
            kv_router_config.router_replica_sync,
255
256
        )
        .await?;
257

258
        // Start unified background process if using KvIndexer
259
        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
260
261
262
263
264
265
266
267
268
269
270
271
            start_kv_router_background(
                component.clone(),
                consumer_uuid,
                kv_indexer.event_sender(),
                kv_router_config
                    .router_snapshot_threshold
                    .map(|_| kv_indexer.snapshot_event_sender()),
                cancellation_token.clone(),
                kv_router_config.router_snapshot_threshold,
                kv_router_config.router_reset_states,
            )
            .await?;
272
        }
273

274
        tracing::info!("KV Routing initialized");
275
        Ok(Self {
276
            indexer,
277
            scheduler,
278
            block_size,
279
        })
280
281
    }

282
    /// Give these tokens, find the worker with the best match in it's KV cache.
283
    /// Returned overlap amount is in number of blocks.
284
285
286
287
288
    /// Now also takes context_id for request tracking
    async fn find_best_match(
        &self,
        context_id: &str,
        tokens: &[u32],
289
        router_config_override: Option<&RouterConfigOverride>,
290
        update_states: bool,
291
    ) -> anyhow::Result<(i64, u32)> {
292
        let isl_tokens = tokens.len();
293

294
295
296
297
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
298
299

        let best_worker_id = self
300
            .scheduler
301
302
303
            .schedule(
                context_id.to_string(),
                isl_tokens,
304
                seq_hashes.clone(),
305
                overlap_scores.clone(),
306
                router_config_override,
307
                update_states,
308
            )
309
            .await?;
310

311
312
        if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
            indexer
313
                .process_routing_decision(best_worker_id, block_hashes, seq_hashes)
314
315
316
317
                .await
                .unwrap();
        };

318
319
320
321
322
323
324
325
        let overlap_amount = overlap_scores
            .scores
            .get(&best_worker_id)
            .copied()
            .unwrap_or(0);
        Ok((best_worker_id, overlap_amount))
    }

326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
    pub async fn add_request(
        &self,
        request_id: String,
        tokens: &[u32],
        overlap_blocks: u32,
        worker_id: i64,
    ) {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        self.scheduler
            .add_request(
                request_id,
                seq_hashes,
                isl_tokens,
                overlap_blocks,
                worker_id,
            )
            .await;
    }

348
    pub async fn mark_prefill_completed(&self, request_id: &str) {
349
        self.scheduler.mark_prefill_completed(request_id).await
350
351
    }

352
    pub async fn free(&self, request_id: &str) {
353
        self.scheduler.free(request_id).await
354
    }
355

356
    pub fn block_size(&self) -> u32 {
357
358
        self.block_size
    }
359

360
361
362
363
364
365
366
367
368
369
370
371
372
    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        let isl_tokens = tokens.len();
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);
        let overlap_scores = self.indexer.find_matches(block_hashes).await?;

        Ok(self
            .scheduler
            .get_potential_loads(seq_hashes, isl_tokens, overlap_scores)
            .await)
    }

373
374
375
376
    /// Dump all events from the indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.indexer.dump_events().await
    }
377
378
}

379
// NOTE: this would not be usable for now, should deprecate
380
381
382
383
384
385
386
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
387
        let (worker_id, _) = self
388
            .find_best_match(ctx.id(), &request.tokens, None, true)
389
            .await?;
390
391
392
393
394
395
396

        let response = RouterResponse { worker_id };
        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
397
398

pub struct KvPushRouter {
399
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
400
401
402
403
404
    chooser: Arc<KvRouter>,
}

impl KvPushRouter {
    pub fn new(
405
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
406
407
408
409
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
410

411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
    /// Find the best matching worker for the given tokens without updating states
    pub async fn find_best_match(
        &self,
        context_id: &str,
        tokens: &[u32],
        router_config_override: Option<&RouterConfigOverride>,
    ) -> Result<(i64, u32)> {
        self.chooser
            .find_best_match(context_id, tokens, router_config_override, false)
            .await
    }

    /// Get potential prefill and decode loads for all workers
    pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
        self.chooser.get_potential_loads(tokens).await
    }

428
429
430
431
    /// Dump all events from the KV router's indexer
    pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        self.chooser.dump_events().await
    }
432
433
434
}

#[async_trait]
435
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
436
437
    for KvPushRouter
{
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
    /// Generate method that handles KV-aware routing with three distinct behaviors:
    ///
    /// 1. **If `query_instance_id` annotation is set**:
    ///    - Returns the best matching worker ID without routing the request
    ///    - Does NOT update any router local states
    ///    - Response includes worker_instance_id and token_data annotations
    ///
    /// 2. **If `backend_instance_id` is set in the request**:
    ///    - Routes directly to the specified backend instance
    ///    - DOES update router states to track this request (unless query_instance_id is also set)
    ///    - Bypasses the normal KV matching logic
    ///
    /// 3. **If neither are set (default behavior)**:
    ///    - Finds the best worker based on KV cache overlap
    ///    - Updates router states to track the request
    ///    - Routes to the selected worker
    ///
    /// The router state updates include tracking active sequences and managing
    /// prefill/completion lifecycle for proper KV cache management.
457
458
    async fn generate(
        &self,
459
        request: SingleIn<PreprocessedRequest>,
460
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
461
        match self.inner.client.instance_source.as_ref() {
462
463
            InstanceSource::Static => self.inner.r#static(request).await,
            InstanceSource::Dynamic(_) => {
464
465
                // Extract context ID for request tracking
                let context_id = request.context().id().to_string();
466
467
468
469

                // Check if this is a query_instance_id request first
                let query_instance_id = request.has_annotation("query_instance_id");

470
                let (instance_id, overlap_amount) = if let Some(id) = request.backend_instance_id {
471
472
473
474
475
476
                    // If instance_id is set, use it and manually add the request to track it
                    if !query_instance_id {
                        self.chooser
                            .add_request(context_id.clone(), &request.token_ids, 0, id)
                            .await;
                    }
477
478
479
480
                    (id, 0)
                } else {
                    // Otherwise, find the best match
                    self.chooser
481
482
483
484
                        .find_best_match(
                            &context_id,
                            &request.token_ids,
                            request.router_config_override.as_ref(),
485
                            !query_instance_id, // Don't update states if query_instance_id
486
                        )
487
488
489
                        .await?
                };

490
491
492
                // if request has the annotation "query_instance_id",
                // then the request will not be routed to the worker,
                // and instead the worker_instance_id will be returned.
493
494
495
496
497
                let stream_context = request.context().clone();
                if query_instance_id {
                    let instance_id_str = instance_id.to_string();
                    let response =
                        Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
498
499
500
501
502
503
504
505
506

                    // Return the tokens in nvext.token_data format
                    let response_tokens =
                        Annotated::from_annotation("token_data", &request.token_ids)?;
                    tracing::trace!(
                        "Tokens requested in the response through the query_instance_id annotation: {:?}",
                        response_tokens
                    );
                    let stream = stream::iter(vec![response, response_tokens]);
507
508
                    return Ok(ResponseStream::new(Box::pin(stream), stream_context));
                }
509
510
511
                let (mut backend_input, context) = request.into_parts();
                backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
                let updated_request = context.map(|_| backend_input);
512

513
                let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
514
515
516
517
                let stream_context = response_stream.context();
                let chooser = self.chooser.clone();

                let wrapped_stream = Box::pin(async_stream::stream! {
518
519
520
521
                    if let Some(first_item) = response_stream.next().await {
                        chooser.mark_prefill_completed(&context_id).await;
                        yield first_item;
                    }
522
523
524
525
526

                    while let Some(item) = response_stream.next().await {
                        yield item;
                    }

527
                    chooser.free(&context_id).await;
528
529
                });
                Ok(ResponseStream::new(wrapped_stream, stream_context))
530
531
532
533
            }
        }
    }
}