kv_router.rs 11.3 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use anyhow::Result;
7
use dynamo_runtime::{
8
    component::{Component, InstanceSource},
9
    pipeline::{
10
11
        async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
        ResponseStream, SingleIn,
12
13
14
15
16
    },
    prelude::*,
    protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
17

18
pub mod approx;
19
pub mod indexer;
20
pub mod metrics_aggregator;
21
22
pub mod protocols;
pub mod publisher;
23
pub mod recorder;
24
25
pub mod scheduler;
pub mod scoring;
26
pub mod sequence;
27

28
29
30
use crate::{
    kv_router::{
        indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
31
        metrics_aggregator::EndpointCollector,
32
33
34
35
        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
        scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
        scoring::ProcessedEndpoints,
    },
36
    preprocessor::PreprocessedRequest,
37
    protocols::common::llm_backend::LLMEngineOutput,
Ryan Olson's avatar
Ryan Olson committed
38
    tokens::TokenBlockSequence,
39
40
};

41
use dynamo_runtime::traits::events::EventSubscriber;
42
43
44

// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Subject that [`KvRouter::new`] subscribes to for KV cache events, which are
/// forwarded to the router's indexer.
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Subject name for KV hit-rate reporting (not referenced within this file).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";

/// Endpoint name for worker load metrics (not referenced within this file).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
55
        block_size: u32,
56
57
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
58

59
60
61
62
63
/// KV Router configuration parameters.
#[derive(Debug, Clone)]
pub struct KvRouterConfig {
    /// Weight given to KV prefix overlap when scoring workers (consumed outside this file).
    pub overlap_score_weight: f64,

    /// Temperature used for worker selection (consumed outside this file).
    pub router_temperature: f64,

    /// Batched-token budget.
    // note: this is not actually used for now
    pub max_num_batched_tokens: u32,
}

impl Default for KvRouterConfig {
    /// Defaults: overlap weight `1.0`, temperature `0.5`, `8192` batched tokens.
    fn default() -> Self {
        KvRouterConfig {
            max_num_batched_tokens: 8192,
            router_temperature: 0.5,
            overlap_score_weight: 1.0,
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional values.
    ///
    /// Every argument that is `None` falls back to the corresponding field of
    /// [`KvRouterConfig::default`].
    pub fn new(
        overlap_score_weight: Option<f64>,
        temperature: Option<f64>,
        max_num_batched_tokens: Option<u32>,
    ) -> Self {
        let KvRouterConfig {
            overlap_score_weight: fallback_weight,
            router_temperature: fallback_temperature,
            max_num_batched_tokens: fallback_max_tokens,
        } = Self::default();

        KvRouterConfig {
            overlap_score_weight: overlap_score_weight.unwrap_or(fallback_weight),
            router_temperature: temperature.unwrap_or(fallback_temperature),
            max_num_batched_tokens: max_num_batched_tokens.unwrap_or(fallback_max_tokens),
        }
    }
}
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
100
pub struct KvRouter {
101
    indexer: Option<KvIndexer>,
102
    scheduler: KvScheduler,
103
    block_size: u32,
104
105
106
107
}

impl KvRouter {
    pub async fn new(
108
        component: Component,
109
        block_size: u32,
110
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
111
        use_kv_events: bool,
112
    ) -> Result<Self> {
113
114
115
116
117
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();
118
        let metrics_aggregator =
119
120
121
122
123
            EndpointCollector::new(component.clone(), cancellation_token.clone()).await;

        let maybe_indexer =
            use_kv_events.then(|| KvIndexer::new(cancellation_token.clone(), block_size));

124
125
126
127
128
129
130
        let scheduler = KvScheduler::start(
            component.namespace().clone(),
            block_size,
            metrics_aggregator.endpoints_watcher(),
            selector,
        )
        .await?;
131

132
133
        // [gluo TODO] try subscribe_with_type::<RouterEvent>,
        // error checking below will be different.
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
        if let Some(ref indexer) = maybe_indexer {
            let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
            let kv_events_tx = indexer.event_sender();

            tokio::spawn(async move {
                while let Some(event) = kv_events_rx.next().await {
                    let event: RouterEvent = match serde_json::from_slice(&event.payload) {
                        Ok(event) => event,
                        Err(e) => {
                            tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                            // Choosing warn and continue to process other events from other workers
                            // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
                            continue;
                        }
                    };
                    if let Err(e) = kv_events_tx.send(event).await {
                        tracing::debug!(
                            "failed to send kv event to indexer; shutting down: {:?}",
                            e
                        );
Alec's avatar
Alec committed
154
                    }
155
                }
156
157
            });
        }
158

159
        tracing::info!("KV Routing initialized");
160
        Ok(Self {
161
            indexer: maybe_indexer,
162
            scheduler,
163
            block_size,
164
        })
165
166
    }

167
    /// Give these tokens, find the worker with the best match in it's KV cache.
168
    /// Returned overlap amount is in number of blocks.
169
170
171
172
173
174
    /// Now also takes context_id for request tracking
    async fn find_best_match(
        &self,
        context_id: &str,
        tokens: &[u32],
    ) -> anyhow::Result<(i64, u32)> {
175
        let isl_tokens = tokens.len();
176
177
        let block_size = self.block_size;

Ryan Olson's avatar
Ryan Olson committed
178
        let (complete_blocks, _partial_block) =
179
            TokenBlockSequence::split_tokens(tokens, block_size, 1337_u64);
Ryan Olson's avatar
Ryan Olson committed
180
181
182
183
184

        let local_block_hashes = complete_blocks
            .into_iter()
            .map(|block| LocalBlockHash(block.block_hash()))
            .collect();
185
186
187
188
189
190
        let overlap_scores = match &self.indexer {
            Some(indexer) => indexer.find_matches(local_block_hashes).await?,
            None => Default::default(), // Returns empty/default instance
        };

        let best_worker_id = self
191
            .scheduler
192
193
194
195
196
197
198
            .schedule(
                context_id.to_string(),
                isl_tokens,
                block_size,
                tokens,
                overlap_scores.clone(),
            )
199
            .await?;
200
201
202
203
204
205
206
207
208

        let overlap_amount = overlap_scores
            .scores
            .get(&best_worker_id)
            .copied()
            .unwrap_or(0);
        Ok((best_worker_id, overlap_amount))
    }

209
210
211
    /// Push tokens to a specific request's sequence
    pub async fn push(&self, request_id: &String, tokens: &[u32]) {
        self.scheduler.push(request_id, tokens).await
212
213
214
215
216
    }

    /// Free all blocks associated with a request
    pub async fn free(&self, request_id: &String) {
        self.scheduler.free(request_id).await
217
    }
218
219

    /// Get the block size this router was configured with
220
    pub fn block_size(&self) -> u32 {
221
222
        self.block_size
    }
223
224
}

225
// NOTE: this would not be usable for now, should deprecate
226
227
228
229
230
231
232
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
233
        let (worker_id, _) = self.find_best_match(ctx.id(), &request.tokens).await?;
234
235
236
237
238
239
240

        let response = RouterResponse { worker_id };
        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
241
242

pub struct KvPushRouter {
243
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
244
245
246
247
248
    chooser: Arc<KvRouter>,
}

impl KvPushRouter {
    pub fn new(
249
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
250
251
252
253
254
255
256
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
257
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
258
259
260
261
    for KvPushRouter
{
    async fn generate(
        &self,
262
        request: SingleIn<PreprocessedRequest>,
263
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
264
        match self.inner.client.instance_source.as_ref() {
265
266
            InstanceSource::Static => self.inner.r#static(request).await,
            InstanceSource::Dynamic(_) => {
267
268
269
270
271
272
273
                // Extract context ID for request tracking
                let context_id = request.context().id().to_string();

                let (instance_id, overlap_amount) = self
                    .chooser
                    .find_best_match(&context_id, &request.token_ids)
                    .await?;
274
275
                // Update the request with the estimated prefix hit blocks
                let (mut backend_input, context) = request.into_parts();
276
                let isl = backend_input.token_ids.len();
277
278
                backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
                let updated_request = context.map(|_| backend_input);
279
280
281
282
283
284
285
286

                // Get the response stream from the worker
                let mut response_stream = self.inner.direct(updated_request, instance_id).await?;

                // Wrap the stream to track tokens
                let stream_context = response_stream.context();
                let chooser = self.chooser.clone();
                let request_id = context_id.clone();
287
                let block_size = chooser.block_size() as usize;
288
289

                let wrapped_stream = Box::pin(async_stream::stream! {
290
291
292
293
                    let mut accumulated_tokens = Vec::new();
                    let mut total_output_length = 0usize;
                    let mut last_block_index = (isl.saturating_sub(1)) / block_size;

294
295
                    while let Some(item) = response_stream.next().await {
                        // Track tokens if they exist in the response
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
                        let Some(ref output) = item.data else {
                            yield item;
                            continue;
                        };
                        if output.token_ids.is_empty() {
                            yield item;
                            continue;
                        }

                        // Add tokens to accumulator
                        accumulated_tokens.extend_from_slice(&output.token_ids);
                        total_output_length += output.token_ids.len();

                        // Check if we've moved to a new block
                        let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
                        if current_block_index > last_block_index {
                            chooser.push(&request_id, &accumulated_tokens).await;
                            accumulated_tokens.clear();
                            last_block_index = current_block_index;
315
                        }
316

317
318
                        yield item;
                    }
319

320
321
322
323
                    chooser.free(&request_id).await;
                });

                Ok(ResponseStream::new(wrapped_stream, stream_context))
324
325
326
327
            }
        }
    }
}