kv_router.rs 14.5 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::sync::Arc;
6
use std::time::Duration;
7

8
use anyhow::Result;
9
use derive_builder::Builder;
10
use dynamo_runtime::{
11
    component::{Component, InstanceSource},
12
    pipeline::{
13
14
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
        SingleIn, async_trait,
15
16
17
18
19
    },
    prelude::*,
    protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
20
use serde::{Deserialize, Serialize};
21

22
pub mod approx;
23
pub mod indexer;
24
pub mod metrics_aggregator;
25
pub mod prefill_counter;
26
27
pub mod protocols;
pub mod publisher;
28
pub mod recorder;
29
30
pub mod scheduler;
pub mod scoring;
31
pub mod sequence;
32

33
use crate::{
34
    discovery::{MODEL_ROOT_PATH, ModelEntry},
35
    kv_router::{
36
37
        approx::ApproxKvIndexer,
        indexer::{
38
39
            KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
            compute_block_hash_for_seq, compute_seq_hash_for_block,
40
        },
41
42
43
44
        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
        scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
        scoring::ProcessedEndpoints,
    },
45
    local_model::runtime_config::ModelRuntimeConfig,
46
    preprocessor::PreprocessedRequest,
47
    protocols::common::llm_backend::LLMEngineOutput,
48
49
};

50
use dynamo_runtime::traits::events::EventSubscriber;
51
52
53

// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Endpoint name used for metric scraping (pull-based).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";

/// Subject for KV events (push-based metric publishing).
pub const KV_EVENT_SUBJECT: &str = "kv_events";
/// Subject for KV hit-rate publishing (push-based metric publishing).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
/// Subject for KV metrics publishing (push-based metric publishing).
pub const KV_METRICS_SUBJECT: &str = "kv_metrics";

/// Subject for prefill events (inter-router communication).
pub const PREFILL_SUBJECT: &str = "prefill_events";
/// Subject for active-sequence events (inter-router communication).
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
66

67
68
69
70
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
71
        workers: &HashMap<i64, Option<ModelRuntimeConfig>>,
72
        request: &SchedulingRequest,
73
        block_size: u32,
74
75
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
76

77
78
79
80
81
82
83
84
85
86
/// Override configuration for router settings that can be specified per-request
///
/// Every field is optional. `KvPushRouter` passes the request's override (if any) through
/// `KvRouter::find_best_match` to the scheduler.
/// NOTE(review): a `None` field presumably means "use the router-wide `KvRouterConfig`
/// value" — confirm against the scheduler.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
pub struct RouterConfigOverride {
    /// Per-request value for the overlap score weight; `None` by default.
    #[builder(default)]
    pub overlap_score_weight: Option<f64>,

    /// Per-request value for the router temperature; `None` by default.
    #[builder(default)]
    pub router_temperature: Option<f64>,
}

87
/// KV Router configuration parameters
88
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
89
90
91
pub struct KvRouterConfig {
    pub overlap_score_weight: f64,

92
    pub router_temperature: f64,
93

94
95
    pub use_kv_events: bool,

96
97
    pub router_replica_sync: bool,

98
99
    // TODO: this is not actually used for now
    // Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
100
    pub max_num_batched_tokens: u32,
101
102
103
104
105
}

impl Default for KvRouterConfig {
    fn default() -> Self {
        Self {
106
            overlap_score_weight: 1.0,
107
            router_temperature: 0.0,
108
            use_kv_events: true,
109
            router_replica_sync: false,
110
            max_num_batched_tokens: 8192,
111
112
113
114
115
116
117
118
119
        }
    }
}

impl KvRouterConfig {
    /// Create a new KvRouterConfig with optional weight values.
    /// If a weight is None, the default value will be used.
    pub fn new(
        overlap_score_weight: Option<f64>,
120
        temperature: Option<f64>,
121
        use_kv_events: Option<bool>,
122
        replica_sync: Option<bool>,
123
        max_num_batched_tokens: Option<u32>,
124
125
126
127
    ) -> Self {
        let default = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(default.overlap_score_weight),
128
            router_temperature: temperature.unwrap_or(default.router_temperature),
129
            use_kv_events: use_kv_events.unwrap_or(default.use_kv_events),
130
            router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
131
132
            max_num_batched_tokens: max_num_batched_tokens
                .unwrap_or(default.max_num_batched_tokens),
133
134
135
136
        }
    }
}

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/// The indexer backing a `KvRouter`.
///
/// `KvRouter::new` picks the variant from `KvRouterConfig::use_kv_events`.
// TODO: is there a way (macro) to auto-derive the KvIndexerInterface trait for this
// since both variants implement it
pub enum Indexer {
    /// Used when KV events are enabled; `KvRouter::new` forwards events from the
    /// `KV_EVENT_SUBJECT` subscription into it.
    KvIndexer(KvIndexer),
    /// Used when KV events are disabled; `KvRouter::find_best_match` reports each routing
    /// decision to it instead.
    ApproxKvIndexer(ApproxKvIndexer),
}

impl Indexer {
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
            Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
        }
    }
}

156
157
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
158
pub struct KvRouter {
159
160
161
    indexer: Indexer,

    // How about a Box<dyn KvIndexerInterface>
162
    scheduler: KvScheduler,
163

164
    block_size: u32,
165
166
167
168
}

impl KvRouter {
    pub async fn new(
169
        component: Component,
170
        block_size: u32,
171
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
172
        kv_router_config: Option<KvRouterConfig>,
173
    ) -> Result<Self> {
174
175
        let kv_router_config = kv_router_config.unwrap_or_default();

176
177
178
179
180
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();
181
182
183
184
185
186
187
188
189
190

        let generate_endpoint = component.endpoint("generate");
        let client = generate_endpoint.client().await?;

        let instances_rx = match client.instance_source.as_ref() {
            InstanceSource::Dynamic(rx) => rx.clone(),
            InstanceSource::Static => {
                panic!("Expected dynamic instance source for KV routing");
            }
        };
191

192
        // Create runtime config watcher using the generic etcd watcher
193
194
195
196
197
        // TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality
        let etcd_client = component
            .drt()
            .etcd_client()
            .expect("Cannot KV route without etcd client");
198
199
200
201
202
203
204
205
206
207
208
209
210

        use dynamo_runtime::utils::typed_prefix_watcher::{
            key_extractors, watch_prefix_with_extraction,
        };
        let runtime_configs_watcher = watch_prefix_with_extraction(
            etcd_client,
            MODEL_ROOT_PATH,
            key_extractors::lease_id,
            |model_entry: ModelEntry| model_entry.runtime_config,
            cancellation_token.clone(),
        )
        .await?;
        let runtime_configs_rx = runtime_configs_watcher.receiver();
211

212
        let indexer = if kv_router_config.use_kv_events {
213
214
215
216
217
218
219
220
221
            Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
        } else {
            // hard code 120 seconds for now
            Indexer::ApproxKvIndexer(ApproxKvIndexer::new(
                cancellation_token.clone(),
                block_size,
                Duration::from_secs(120),
            ))
        };
222

223
        let scheduler = KvScheduler::start(
224
            component.clone(),
225
            block_size,
226
            instances_rx,
227
            runtime_configs_rx,
228
            selector,
229
            kv_router_config.router_replica_sync,
230
231
        )
        .await?;
232

233
234
        // [gluo TODO] try subscribe_with_type::<RouterEvent>,
        // error checking below will be different.
235
        if let Indexer::KvIndexer(ref kv_indexer) = indexer {
236
            let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
237
            let kv_events_tx = kv_indexer.event_sender();
238
239
240
241
242
243
244
245
246
247
248
249
250

            tokio::spawn(async move {
                while let Some(event) = kv_events_rx.next().await {
                    let event: RouterEvent = match serde_json::from_slice(&event.payload) {
                        Ok(event) => event,
                        Err(e) => {
                            tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                            // Choosing warn and continue to process other events from other workers
                            // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
                            continue;
                        }
                    };
                    if let Err(e) = kv_events_tx.send(event).await {
251
                        tracing::warn!(
252
253
254
                            "failed to send kv event to indexer; shutting down: {:?}",
                            e
                        );
Alec's avatar
Alec committed
255
                    }
256
                }
257
258
            });
        }
259

260
        tracing::info!("KV Routing initialized");
261
        Ok(Self {
262
            indexer,
263
            scheduler,
264
            block_size,
265
        })
266
267
    }

268
    /// Give these tokens, find the worker with the best match in it's KV cache.
269
    /// Returned overlap amount is in number of blocks.
270
271
272
273
274
    /// Now also takes context_id for request tracking
    async fn find_best_match(
        &self,
        context_id: &str,
        tokens: &[u32],
275
        router_config_override: Option<&RouterConfigOverride>,
276
    ) -> anyhow::Result<(i64, u32)> {
277
        let isl_tokens = tokens.len();
278

279
280
281
282
        let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
        let seq_hashes = compute_seq_hash_for_block(&block_hashes);

        let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
283
284

        let best_worker_id = self
285
            .scheduler
286
287
288
            .schedule(
                context_id.to_string(),
                isl_tokens,
289
                seq_hashes.clone(),
290
                overlap_scores.clone(),
291
                router_config_override,
292
            )
293
            .await?;
294

295
296
        if let Indexer::ApproxKvIndexer(ref indexer) = self.indexer {
            indexer
297
                .process_routing_decision(best_worker_id, block_hashes, seq_hashes)
298
299
300
301
                .await
                .unwrap();
        };

302
303
304
305
306
307
308
309
        let overlap_amount = overlap_scores
            .scores
            .get(&best_worker_id)
            .copied()
            .unwrap_or(0);
        Ok((best_worker_id, overlap_amount))
    }

310
    pub async fn mark_prefill_completed(&self, request_id: &str) {
311
        self.scheduler.mark_prefill_completed(request_id).await
312
313
    }

314
    pub async fn free(&self, request_id: &str) {
315
        self.scheduler.free(request_id).await
316
    }
317

318
    pub fn block_size(&self) -> u32 {
319
320
        self.block_size
    }
321
322
}

323
// NOTE: this would not be usable for now, should deprecate
324
325
326
327
328
329
330
#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
331
332
333
        let (worker_id, _) = self
            .find_best_match(ctx.id(), &request.tokens, None)
            .await?;
334
335
336
337
338
339
340

        let response = RouterResponse { worker_id };
        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
341
342

pub struct KvPushRouter {
343
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
344
345
346
347
348
    chooser: Arc<KvRouter>,
}

impl KvPushRouter {
    pub fn new(
349
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
350
351
352
353
354
355
356
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
357
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
358
359
360
361
    for KvPushRouter
{
    async fn generate(
        &self,
362
        request: SingleIn<PreprocessedRequest>,
363
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
364
        match self.inner.client.instance_source.as_ref() {
365
366
            InstanceSource::Static => self.inner.r#static(request).await,
            InstanceSource::Dynamic(_) => {
367
368
                // Extract context ID for request tracking
                let context_id = request.context().id().to_string();
369
370
371
372
373
374
                let (instance_id, overlap_amount) = if let Some(id) = request.backend_instance_id {
                    // If instance_id is set, use it
                    (id, 0)
                } else {
                    // Otherwise, find the best match
                    self.chooser
375
376
377
378
379
                        .find_best_match(
                            &context_id,
                            &request.token_ids,
                            request.router_config_override.as_ref(),
                        )
380
381
382
                        .await?
                };

383
384
385
                let query_instance_id = request.has_annotation("query_instance_id");
                // Extract context information before moving the request
                let stream_context = request.context().clone();
386
387
388
389
                // Update the request with the estimated prefix hit blocks
                let (mut backend_input, context) = request.into_parts();
                backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
                let updated_request = context.map(|_| backend_input);
390

391
392
393
394
395
396
397
398
                // if request has the annotation "query_instance_id", for example
                // curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}'
                // request will not be routed to worker immediately
                if query_instance_id {
                    let instance_id_str = instance_id.to_string();
                    let response =
                        Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
                    let stream = stream::iter(vec![response]);
399
400
401
                    return Ok(ResponseStream::new(Box::pin(stream), stream_context));
                }

402
                let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
403
404
405
406
                let stream_context = response_stream.context();
                let chooser = self.chooser.clone();

                let wrapped_stream = Box::pin(async_stream::stream! {
407
408
409
410
                    if let Some(first_item) = response_stream.next().await {
                        chooser.mark_prefill_completed(&context_id).await;
                        yield first_item;
                    }
411
412
413
414
415

                    while let Some(item) = response_stream.next().await {
                        yield item;
                    }

416
                    chooser.free(&context_id).await;
417
418
                });
                Ok(ResponseStream::new(wrapped_stream, stream_context))
419
420
421
422
            }
        }
    }
}