kv_router.rs 13.2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::sync::Arc;
5
use std::time::Duration;
6

7
use anyhow::Result;
8
use dynamo_runtime::{
9
    component::{Component, InstanceSource},
10
    pipeline::{
11
12
        async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
        ResponseStream, SingleIn,
13
14
15
16
17
    },
    prelude::*,
    protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
18
use tokio::sync::Mutex;
19

20
pub mod approx;
21
pub mod indexer;
22
pub mod metrics_aggregator;
23
24
pub mod protocols;
pub mod publisher;
25
pub mod recorder;
26
27
pub mod scheduler;
pub mod scoring;
28
pub mod sequence;
29

30
31
use crate::{
    kv_router::{
32
33
34
35
36
        approx::ApproxKvIndexer,
        indexer::{
            compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError,
            OverlapScores, RouterEvent,
        },
37
        metrics_aggregator::EndpointCollector,
38
39
40
41
        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
        scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
        scoring::ProcessedEndpoints,
    },
42
    preprocessor::PreprocessedRequest,
43
    protocols::common::llm_backend::LLMEngineOutput,
44
45
};

46
use dynamo_runtime::traits::events::EventSubscriber;
47
48
49

// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Subject that `KvRouter::new` subscribes to (on its component) for worker KV events
/// when event-based indexing is enabled.
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Subject name for KV hit-rate reporting (not referenced in this file).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";

/// Endpoint name for worker load metrics (not referenced in this file; presumably consumed
/// by the metrics aggregator — confirm).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
53

54
55
56
57
58
59
/// Extension point for custom worker-selection policies.
///
/// An implementation can be passed to `KvRouter::new`, which hands it on to
/// `KvScheduler::start` in place of the scheduler's built-in selection logic.
pub trait WorkerSelector {
    /// Chooses a worker for `request` among `workers`.
    ///
    /// `block_size` is the router's KV block size in tokens. Returns a
    /// [`KvSchedulerError`] when no worker can be selected.
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: u32,
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
63

64
/// Tunable parameters for KV-aware routing.
#[derive(Debug, Clone, Copy)]
pub struct KvRouterConfig {
    /// Weight given to the prefix-overlap score (default `1.0`; presumably scales the
    /// overlap term of the cost function — confirm in the scheduler).
    pub overlap_score_weight: f64,

    /// Routing temperature (default `0.0`; presumably `0.0` means deterministic selection).
    pub router_temperature: f64,

    /// Whether to index worker caches from published KV events (default `true`).
    pub use_kv_events: bool,

    // note: this is not actually used for now
    /// Maximum number of tokens per batch (default `8192`).
    pub max_num_batched_tokens: u32,
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        KvRouterConfig {
            use_kv_events: true,
            overlap_score_weight: 1.0,
            router_temperature: 0.0,
            max_num_batched_tokens: 8192,
        }
    }
}

impl KvRouterConfig {
    /// Builds a config from optional overrides.
    ///
    /// Every argument that is `None` keeps the value from [`KvRouterConfig::default`].
    /// `temperature` sets [`KvRouterConfig::router_temperature`].
    pub fn new(
        overlap_score_weight: Option<f64>,
        temperature: Option<f64>,
        use_kv_events: Option<bool>,
        max_num_batched_tokens: Option<u32>,
    ) -> Self {
        let mut config = Self::default();
        if let Some(weight) = overlap_score_weight {
            config.overlap_score_weight = weight;
        }
        if let Some(temp) = temperature {
            config.router_temperature = temp;
        }
        if let Some(enabled) = use_kv_events {
            config.use_kv_events = enabled;
        }
        if let Some(max_tokens) = max_num_batched_tokens {
            config.max_num_batched_tokens = max_tokens;
        }
        config
    }
}

108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
/// The prefix-overlap index used by `KvRouter`.
pub enum Indexer {
    /// Index built from KV events that workers publish.
    KvIndexer(KvIndexer),
    /// Index that estimates worker cache contents without worker events.
    ApproxKvIndexer(ApproxKvIndexer),
}

impl Indexer {
    /// Computes per-worker overlap scores for `sequence` (a list of local block hashes)
    /// by delegating to the active variant.
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Self::KvIndexer(inner) => inner.find_matches(sequence).await,
            Self::ApproxKvIndexer(inner) => inner.find_matches(sequence).await,
        }
    }
}

127
128
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
129
pub struct KvRouter {
130
131
132
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
133
    scheduler: KvScheduler,
134

135
    block_size: u32,
136
137
138
139

    // To ensure blocking reads / writes
    // TODO: benchmark tradeoffs
    find_best_match_mutex: Mutex<()>,
140
141
142
143
}

impl KvRouter {
    pub async fn new(
144
        component: Component,
145
        block_size: u32,
146
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
147
        use_kv_events: bool,
148
    ) -> Result<Self> {
149
150
151
152
153
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();
154
        let metrics_aggregator =
155
156
            EndpointCollector::new(component.clone(), cancellation_token.clone()).await;

157
158
159
160
161
162
163
164
165
166
        let indexer = if use_kv_events {
            Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
        } else {
            // hard code 120 seconds for now
            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
                cancellation_token.clone(),
                block_size,
                Duration::from_secs(120),
            ))
        };
167

168
169
170
171
172
173
174
        let scheduler = KvScheduler::start(
            component.namespace().clone(),
            block_size,
            metrics_aggregator.endpoints_watcher(),
            selector,
        )
        .await?;
175

176
177
        // [gluo TODO] try subscribe_with_type::<RouterEvent>,
        // error checking below will be different.
178
        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
179
            let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
180
            let kv_events_tx = kv_indexer.event_sender();
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197

            tokio::spawn(async move {
                while let Some(event) = kv_events_rx.next().await {
                    let event: RouterEvent = match serde_json::from_slice(&event.payload) {
                        Ok(event) => event,
                        Err(e) => {
                            tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                            // Choosing warn and continue to process other events from other workers
                            // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
                            continue;
                        }
                    };
                    if let Err(e) = kv_events_tx.send(event).await {
                        tracing::debug!(
                            "failed to send kv event to indexer; shutting down: {:?}",
                            e
                        );
Alec's avatar
Alec committed
198
                    }
199
                }
200
201
            });
        }
202

203
        tracing::info!("KV Routing initialized");
204
        Ok(Self {
205
            indexer,
206
            scheduler,
207
            block_size,
208
            find_best_match_mutex: Mutex::new(()), // Add this
209
        })
210
211
    }

212
    /// Give these tokens, find the worker with the best match in it's KV cache.
213
    /// Returned overlap amount is in number of blocks.
214
215
216
217
218
219
    /// Now also takes context_id for request tracking
    async fn find_best_match(
        &self,
        context_id: &str,
        tokens: &[u32],
    ) -> anyhow::Result<(i64, u32)> {
220
221
222
223
        // Acquire mutex to serialize access
        // TODO: may as well make all the subroutines synchronous if benchmarking favors this
        let _guard = self.find_best_match_mutex.lock().await;

224
        let isl_tokens = tokens.len();
225
226
        let block_size = self.block_size;

227
228
        let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
229
230

        let best_worker_id = self
231
            .scheduler
232
233
234
235
236
237
238
            .schedule(
                context_id.to_string(),
                isl_tokens,
                block_size,
                tokens,
                overlap_scores.clone(),
            )
239
            .await?;
240

241
242
243
244
245
246
247
        if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
            indexer
                .process_routing_decision_for_request(tokens, best_worker_id)
                .await
                .unwrap();
        };

248
249
250
251
252
253
254
255
        let overlap_amount = overlap_scores
            .scores
            .get(&best_worker_id)
            .copied()
            .unwrap_or(0);
        Ok((best_worker_id, overlap_amount))
    }

256
257
258
    /// Push tokens to a specific request's sequence
    pub async fn push(&self, request_id: &String, tokens: &[u32]) {
        self.scheduler.push(request_id, tokens).await
259
260
261
262
263
    }

    /// Free all blocks associated with a request
    pub async fn free(&self, request_id: &String) {
        self.scheduler.free(request_id).await
264
    }
265
266

    /// Get the block size this router was configured with
267
    pub fn block_size(&self) -> u32 {
268
269
        self.block_size
    }
270
271
}

272
// NOTE: this would not be usable for now, should deprecate
273
274
275
276
277
278
279
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
280
        let (worker_id, _) = self.find_best_match(ctx.id(), &request.tokens).await?;
281
282
283
284
285
286
287

        let response = RouterResponse { worker_id };
        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
288
289

pub struct KvPushRouter {
290
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
291
292
293
294
295
    chooser: Arc<KvRouter>,
}

impl KvPushRouter {
    pub fn new(
296
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
297
298
299
300
301
302
303
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
304
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
305
306
307
308
    for KvPushRouter
{
    async fn generate(
        &self,
309
        request: SingleIn<PreprocessedRequest>,
310
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
311
        match self.inner.client.instance_source.as_ref() {
312
313
            InstanceSource::Static => self.inner.r#static(request).await,
            InstanceSource::Dynamic(_) => {
314
315
316
317
318
319
320
                // Extract context ID for request tracking
                let context_id = request.context().id().to_string();

                let (instance_id, overlap_amount) = self
                    .chooser
                    .find_best_match(&context_id, &request.token_ids)
                    .await?;
321
322
                // Update the request with the estimated prefix hit blocks
                let (mut backend_input, context) = request.into_parts();
323
                let isl = backend_input.token_ids.len();
324
325
                backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
                let updated_request = context.map(|_| backend_input);
326
327
328
329
330
331
332
333

                // Get the response stream from the worker
                let mut response_stream = self.inner.direct(updated_request, instance_id).await?;

                // Wrap the stream to track tokens
                let stream_context = response_stream.context();
                let chooser = self.chooser.clone();
                let request_id = context_id.clone();
334
                let block_size = chooser.block_size() as usize;
335
336

                let wrapped_stream = Box::pin(async_stream::stream! {
337
338
339
                    let mut accumulated_tokens = Vec::new();
                    let mut total_output_length = 0usize;
                    let mut last_block_index = (isl.saturating_sub(1)) / block_size;
340
                    let mut first_push_done = false;
341

342
343
                    while let Some(item) = response_stream.next().await {
                        // Track tokens if they exist in the response
344
345
346
347
348
349
350
351
352
353
354
355
356
                        let Some(ref output) = item.data else {
                            yield item;
                            continue;
                        };
                        if output.token_ids.is_empty() {
                            yield item;
                            continue;
                        }

                        // Add tokens to accumulator
                        accumulated_tokens.extend_from_slice(&output.token_ids);
                        total_output_length += output.token_ids.len();

357
358
                        // Always push for the first generated token (to mark prefill done)
                        // or when we've moved to a new block
359
                        let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
360
361
362
363
                        let should_push = (!first_push_done && total_output_length >= 1) ||
                                      (first_push_done && current_block_index > last_block_index);

                        if should_push {
364
365
366
                            chooser.push(&request_id, &accumulated_tokens).await;
                            accumulated_tokens.clear();
                            last_block_index = current_block_index;
367
368
369
                            if !first_push_done {
                                first_push_done = true;
                            }
370
                        }
371

372
373
                        yield item;
                    }
374

375
376
377
378
                    chooser.free(&request_id).await;
                });

                Ok(ResponseStream::new(wrapped_stream, stream_context))
379
380
381
382
            }
        }
    }
}