kv_router.rs 14.2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::sync::Arc;
5
use std::time::Duration;
6

7
use anyhow::Result;
8
use dynamo_runtime::{
9
    component::{Component, InstanceSource},
10
    pipeline::{
11
12
        async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
        ResponseStream, SingleIn,
13
14
15
16
17
    },
    prelude::*,
    protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
18
use tokio::sync::Mutex;
19

20
pub mod approx;
21
pub mod indexer;
22
pub mod metrics_aggregator;
23
24
pub mod protocols;
pub mod publisher;
25
pub mod recorder;
26
27
pub mod scheduler;
pub mod scoring;
28
pub mod sequence;
29

30
31
use crate::{
    kv_router::{
32
33
34
35
36
        approx::ApproxKvIndexer,
        indexer::{
            compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError,
            OverlapScores, RouterEvent,
        },
37
        metrics_aggregator::EndpointCollector,
38
39
40
41
        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
        scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
        scoring::ProcessedEndpoints,
    },
42
    preprocessor::PreprocessedRequest,
43
    protocols::common::llm_backend::LLMEngineOutput,
44
45
};

46
use dynamo_runtime::traits::events::EventSubscriber;
47
48
49

// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Subject the router subscribes to for KV cache events (see `KvRouter::new`).
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Subject name for KV hit-rate reporting (not referenced elsewhere in this file).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";

/// Endpoint name for worker load metrics (not referenced elsewhere in this file).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
60
        block_size: u32,
61
62
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
63

64
/// KV Router configuration parameters.
///
/// Every field has a default (see the `Default` impl below); [`KvRouterConfig::new`]
/// lets callers override any subset of them.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KvRouterConfig {
    /// Weight for the KV-cache overlap term when scoring workers (default `1.0`).
    /// NOTE(review): exact use lives in the scheduler/scoring code — confirm there.
    pub overlap_score_weight: f64,

    /// Temperature for worker selection (default `0.0`).
    /// NOTE(review): presumably `0.0` means "always pick the best score" — confirm.
    pub router_temperature: f64,

    /// Whether to rely on KV events published by workers (default `true`).
    /// Presumably forwarded to `KvRouter::new`'s `use_kv_events` — verify against callers.
    pub use_kv_events: bool,

    /// Batch token budget (default `8192`).
    // note: this is not actually used for now
    pub max_num_batched_tokens: u32,
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
            overlap_score_weight: 1.0,
            router_temperature: 0.0,
            use_kv_events: true,
            max_num_batched_tokens: 8192,
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    ///
    /// Each `None` argument falls back to the corresponding value from
    /// [`KvRouterConfig::default`]; each `Some` overrides it.
    pub fn new(
        overlap_score_weight: Option<f64>,
        temperature: Option<f64>,
        use_kv_events: Option<bool>,
        max_num_batched_tokens: Option<u32>,
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
            router_temperature: temperature.unwrap_or(default.router_temperature),
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
            max_num_batched_tokens: max_num_batched_tokens
                .unwrap_or(default.max_num_batched_tokens),
        }
    }
}
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
    KvIndexer(KvIndexer),
    ApproxKvIndexer(ApproxKvIndexer),
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
            Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
        }
    }
}

127
128
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
129
pub struct KvRouter {
130
131
132
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
133
    scheduler: KvScheduler,
134

135
    block_size: u32,
136
137
138
139

    // To ensure blocking reads / writes
    // TODO: benchmark tradeoffs
    find_best_match_mutex: Mutex<()>,
140
141
142
143
}

impl KvRouter {
    pub async fn new(
144
        component: Component,
145
        block_size: u32,
146
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
147
        use_kv_events: bool,
148
    ) -> Result<Self> {
149
150
151
152
153
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();
154
        let metrics_aggregator =
155
156
            EndpointCollector::new(component.clone(), cancellation_token.clone()).await;

157
158
159
160
161
162
163
164
165
166
        let indexer = if use_kv_events {
            Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
        } else {
            // hard code 120 seconds for now
            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
                cancellation_token.clone(),
                block_size,
                Duration::from_secs(120),
            ))
        };
167

168
169
170
171
172
173
174
        let scheduler = KvScheduler::start(
            component.namespace().clone(),
            block_size,
            metrics_aggregator.endpoints_watcher(),
            selector,
        )
        .await?;
175

176
177
        // [gluo TODO] try subscribe_with_type::<RouterEvent>,
        // error checking below will be different.
178
        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
179
            let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
180
            let kv_events_tx = kv_indexer.event_sender();
181
182
183
184
185
186
187
188
189
190
191
192
193

            tokio::spawn(async move {
                while let Some(event) = kv_events_rx.next().await {
                    let event: RouterEvent = match serde_json::from_slice(&event.payload) {
                        Ok(event) => event,
                        Err(e) => {
                            tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                            // Choosing warn and continue to process other events from other workers
                            // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
                            continue;
                        }
                    };
                    if let Err(e) = kv_events_tx.send(event).await {
194
                        tracing::warn!(
195
196
197
                            "failed to send kv event to indexer; shutting down: {:?}",
                            e
                        );
Alec's avatar
Alec committed
198
                    }
199
                }
200
201
            });
        }
202

203
        tracing::info!("KV Routing initialized");
204
        Ok(Self {
205
            indexer,
206
            scheduler,
207
            block_size,
208
            find_best_match_mutex: Mutex::new(()), // Add this
209
        })
210
211
    }

212
    /// Give these tokens, find the worker with the best match in it's KV cache.
213
    /// Returned overlap amount is in number of blocks.
214
215
216
217
218
219
    /// Now also takes context_id for request tracking
    async fn find_best_match(
        &self,
        context_id: &str,
        tokens: &[u32],
    ) -> anyhow::Result<(i64, u32)> {
220
221
222
223
        // Acquire mutex to serialize access
        // TODO: may as well make all the subroutines synchronous if benchmarking favors this
        let _guard = self.find_best_match_mutex.lock().await;

224
        let isl_tokens = tokens.len();
225
226
        let block_size = self.block_size;

227
228
        let local_block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
229
230

        let best_worker_id = self
231
            .scheduler
232
233
234
235
236
237
238
            .schedule(
                context_id.to_string(),
                isl_tokens,
                block_size,
                tokens,
                overlap_scores.clone(),
            )
239
            .await?;
240

241
242
243
244
245
246
247
        if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
            indexer
                .process_routing_decision_for_request(tokens, best_worker_id)
                .await
                .unwrap();
        };

248
249
250
251
252
253
254
255
        let overlap_amount = overlap_scores
            .scores
            .get(&best_worker_id)
            .copied()
            .unwrap_or(0);
        Ok((best_worker_id, overlap_amount))
    }

256
257
258
    /// Push tokens to a specific request's sequence
    pub async fn push(&self, request_id: &String, tokens: &[u32]) {
        self.scheduler.push(request_id, tokens).await
259
260
261
262
263
    }

    /// Free all blocks associated with a request
    pub async fn free(&self, request_id: &String) {
        self.scheduler.free(request_id).await
264
    }
265
266

    /// Get the block size this router was configured with
267
    pub fn block_size(&self) -> u32 {
268
269
        self.block_size
    }
270
271
}

272
// NOTE: this would not be usable for now, should deprecate
273
274
275
276
277
278
279
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
280
        let (worker_id, _) = self.find_best_match(ctx.id(), &request.tokens).await?;
281
282
283
284
285
286
287

        let response = RouterResponse { worker_id };
        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
288
289

pub struct KvPushRouter {
290
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
291
292
293
294
295
    chooser: Arc<KvRouter>,
}

impl KvPushRouter {
    pub fn new(
296
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
297
298
299
300
301
302
303
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
304
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
305
306
307
308
    for KvPushRouter
{
    async fn generate(
        &self,
309
        request: SingleIn<PreprocessedRequest>,
310
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
311
        match self.inner.client.instance_source.as_ref() {
312
313
            InstanceSource::Static => self.inner.r#static(request).await,
            InstanceSource::Dynamic(_) => {
314
315
316
317
318
319
                // Extract context ID for request tracking
                let context_id = request.context().id().to_string();
                let (instance_id, overlap_amount) = self
                    .chooser
                    .find_best_match(&context_id, &request.token_ids)
                    .await?;
320
321
322
                let query_instance_id = request.has_annotation("query_instance_id");
                // Extract context information before moving the request
                let stream_context = request.context().clone();
323
324
                // Update the request with the estimated prefix hit blocks
                let (mut backend_input, context) = request.into_parts();
325
                let isl = backend_input.token_ids.len();
326
327
                backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
                let updated_request = context.map(|_| backend_input);
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
                // if request has the annotation "query_instance_id", for example
                // curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}'
                // request will not be routed to worker immediately
                if query_instance_id {
                    let instance_id_str = instance_id.to_string();
                    let response =
                        Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
                    let stream = stream::iter(vec![response]);
                    Ok(ResponseStream::new(Box::pin(stream), stream_context))
                } else {
                    // Get the response stream from the worker
                    let mut response_stream =
                        self.inner.direct(updated_request, instance_id).await?;

                    // Wrap the stream to track tokens
                    let stream_context = response_stream.context();
                    let chooser = self.chooser.clone();
                    let request_id = context_id.clone();
                    let block_size = chooser.block_size() as usize;

                    let wrapped_stream = Box::pin(async_stream::stream! {
                        let mut accumulated_tokens = Vec::new();
                        let mut total_output_length = 0usize;
                        let mut last_block_index = (isl.saturating_sub(1)) / block_size;
                        let mut first_push_done = false;

                        while let Some(item) = response_stream.next().await {
                            // Track tokens if they exist in the response
                            let Some(ref output) = item.data else {
                                yield item;
                                continue;
                            };
                            if output.token_ids.is_empty() {
                                yield item;
                                continue;
                            }
364

365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
                            // Add tokens to accumulator
                            accumulated_tokens.extend_from_slice(&output.token_ids);
                            total_output_length += output.token_ids.len();

                            // Always push for the first generated token (to mark prefill done)
                            // or when we've moved to a new block
                            let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
                            let should_push = (!first_push_done && total_output_length >= 1) ||
                                        (first_push_done && current_block_index > last_block_index);

                            if should_push {
                                chooser.push(&request_id, &accumulated_tokens).await;
                                accumulated_tokens.clear();
                                last_block_index = current_block_index;
                                if !first_push_done {
                                    first_push_done = true;
                                }
                            }
383
384
385
386

                            yield item;
                        }

387
388
389
390
                        chooser.free(&request_id).await;
                    });
                    Ok(ResponseStream::new(wrapped_stream, stream_context))
                }
391
392
393
394
            }
        }
    }
}