"docs/vscode:/vscode.git/clone" did not exist on "0b8b7ffb7337005f18a8a1611ee0d41213642a16"
kv_router.rs 8.95 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use anyhow::Result;
7
use dynamo_runtime::{
8
    component::{Component, InstanceSource},
9
    pipeline::{
10
11
        async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
        ResponseStream, SingleIn,
12
13
14
15
16
    },
    prelude::*,
    protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
17
18

pub mod indexer;
19
pub mod metrics_aggregator;
20
21
pub mod protocols;
pub mod publisher;
22
pub mod recorder;
23
24
pub mod scheduler;
pub mod scoring;
25

26
27
28
29
30
31
32
33
use crate::{
    kv_router::{
        indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
        metrics_aggregator::KvMetricsAggregator,
        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
        scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
        scoring::ProcessedEndpoints,
    },
34
    preprocessor::PreprocessedRequest,
35
    protocols::common::llm_backend::LLMEngineOutput,
Ryan Olson's avatar
Ryan Olson committed
36
    tokens::TokenBlockSequence,
37
38
};

39
use dynamo_runtime::traits::events::EventSubscriber;
40
41
42

// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Subject that [`KvRouter::new`] subscribes to on its component in order to
/// receive serialized `RouterEvent`s.
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Subject name for KV hit-rate reporting.
// NOTE(review): not referenced in this part of the file; the publisher/consumer
// presumably lives in another module — confirm before changing the value.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";

/// Endpoint name for load metrics.
// NOTE(review): not referenced in this part of the file; presumably consumed by
// the metrics aggregator — confirm before changing the value.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
46

47
48
49
50
51
52
/// A trait that users can implement to define custom worker selection logic.
///
/// A boxed implementation can be passed to [`KvRouter::new`], which hands it
/// to the scheduler.
pub trait WorkerSelector {
    /// Picks a worker for `request` out of the candidate `workers`.
    ///
    /// `block_size` is the block size the router was configured with
    /// (presumably the number of tokens per KV block — confirm against the
    /// scheduler).
    ///
    /// # Errors
    ///
    /// Returns a [`KvSchedulerError`] when no selection can be made.
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: u32,
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
/// KV Router configuration parameters.
///
/// Each field is a weight applied during worker selection; see
/// [`KvRouterConfig::default`] for the values used when a weight is not
/// supplied.
#[derive(Debug, Clone)]
pub struct KvRouterConfig {
    /// Weight for overlap score in worker selection.
    /// Higher values prioritize KV cache reuse. Default: 1.0
    pub overlap_score_weight: f64,

    /// Weight for GPU cache usage in worker selection.
    /// Higher values avoid workers with nearly full KV caches. Default: 1.0
    pub gpu_cache_usage_weight: f64,

    /// Weight for waiting requests in worker selection.
    /// Higher values avoid workers with queued requests. Default: 1.0
    pub waiting_requests_weight: f64,
}

impl Default for KvRouterConfig {
    /// Returns a configuration in which every weight is `1.0`.
    fn default() -> Self {
        Self {
            overlap_score_weight: 1.0,
            gpu_cache_usage_weight: 1.0,
            waiting_requests_weight: 1.0,
        }
    }
}

impl KvRouterConfig {
    /// Create a new `KvRouterConfig` with optional weight values.
    ///
    /// Any weight passed as `None` falls back to the corresponding value from
    /// [`KvRouterConfig::default`].
    pub fn new(
        overlap_score_weight: Option<f64>,
        gpu_cache_usage_weight: Option<f64>,
        waiting_requests_weight: Option<f64>,
    ) -> Self {
        let defaults = Self::default();
        Self {
            overlap_score_weight: overlap_score_weight.unwrap_or(defaults.overlap_score_weight),
            gpu_cache_usage_weight: gpu_cache_usage_weight
                .unwrap_or(defaults.gpu_cache_usage_weight),
            waiting_requests_weight: waiting_requests_weight
                .unwrap_or(defaults.waiting_requests_weight),
        }
    }
}

102
103
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
104
pub struct KvRouter {
105
    indexer: KvIndexer,
106
    scheduler: KvScheduler,
107
    block_size: u32,
108
109
110
111
}

impl KvRouter {
    pub async fn new(
112
        component: Component,
113
        block_size: u32,
114
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
115
    ) -> Result<Self> {
116
117
118
119
120
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();
121
        tracing::info!("KV Routing initialized");
122
123
124
125
126
127
128
129
130
131
        let metrics_aggregator =
            KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
        let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
        let scheduler = KvScheduler::start(
            component.namespace().clone(),
            block_size,
            metrics_aggregator.endpoints_watcher(),
            selector,
        )
        .await?;
132

133
134
135
        // [gluo TODO] try subscribe_with_type::<RouterEvent>,
        // error checking below will be different.
        let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
136
137
138
139
        let kv_events_tx = indexer.event_sender();

        tokio::spawn(async move {
            while let Some(event) = kv_events_rx.next().await {
Alec's avatar
Alec committed
140
                let event: RouterEvent = match serde_json::from_slice(&event.payload) {
141
                    Ok(event) => event,
Alec's avatar
Alec committed
142
143
144
145
146
147
148
                    Err(e) => {
                        tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                        // Choosing warn and continue to process other events from other workers
                        // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
                        continue;
                    }
                };
149
                if let Err(e) = kv_events_tx.send(event).await {
150
                    tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e);
151
152
153
154
                }
            }
        });

155
        Ok(Self {
156
157
            scheduler,
            indexer,
158
            block_size,
159
        })
160
161
162
    }

    // [TODO] indexer needs to take 'lora_id' as parameter
GuanLuo's avatar
GuanLuo committed
163
    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
164
165
166
167
168
169
170
        // Extracting part of the code in KvRouter::generate() for only
        // the decision making part, routing is done by the caller
        let isl_tokens = token_ids.len();
        let overlap_scores = self
            .indexer
            .find_matches_for_request(token_ids.as_slice())
            .await?;
Alec's avatar
Alec committed
171
        tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
GuanLuo's avatar
GuanLuo committed
172
173
        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
        Ok(worker_id)
174
    }
175

176
    /// Give these tokens, find the worker with the best match in it's KV cache.
177
178
    /// Returned overlap amount is in number of blocks.
    async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> {
179
        let isl_tokens = tokens.len();
180
181
        let block_size = self.block_size;

Ryan Olson's avatar
Ryan Olson committed
182
        let (complete_blocks, _partial_block) =
183
            TokenBlockSequence::split_tokens(tokens, block_size, 1337_u64);
Ryan Olson's avatar
Ryan Olson committed
184
185
186
187
188

        let local_block_hashes = complete_blocks
            .into_iter()
            .map(|block| LocalBlockHash(block.block_hash()))
            .collect();
189
        let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
190
191
192
193
194
195
        let worker_id = self
            .scheduler
            .schedule(overlap_scores.clone(), isl_tokens)
            .await?;
        let overlap_amount = overlap_scores.scores.get(&worker_id).copied().unwrap_or(0);
        Ok((worker_id, overlap_amount))
196
    }
197
198

    /// Get the block size this router was configured with
199
    pub fn block_size(&self) -> u32 {
200
201
        self.block_size
    }
202
203
204
205
206
207
208
209
210
}

#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    /// Answers a routing request with the id of the selected worker.
    ///
    /// The request's tokens are matched via `find_best_match`; the overlap
    /// amount is discarded. The result is returned as a stream containing a
    /// single `RouterResponse`, bound to the incoming request's context.
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
        let (worker_id, _) = self.find_best_match(&request.tokens).await?;

        let response = Annotated::from_data(RouterResponse { worker_id });
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
219
220

pub struct KvPushRouter {
221
    inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
222
223
224
225
226
    chooser: Arc<KvRouter>,
}

impl KvPushRouter {
    pub fn new(
227
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
228
229
230
231
232
233
234
        chooser: Arc<KvRouter>,
    ) -> Self {
        KvPushRouter { inner, chooser }
    }
}

#[async_trait]
235
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
236
237
238
239
    for KvPushRouter
{
    async fn generate(
        &self,
240
        request: SingleIn<PreprocessedRequest>,
241
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
242
        match self.inner.client.instance_source.as_ref() {
243
244
            InstanceSource::Static => self.inner.r#static(request).await,
            InstanceSource::Dynamic(_) => {
245
246
247
248
249
250
251
                let (instance_id, overlap_amount) =
                    self.chooser.find_best_match(&request.token_ids).await?;
                // Update the request with the estimated prefix hit blocks
                let (mut backend_input, context) = request.into_parts();
                backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
                let updated_request = context.map(|_| backend_input);
                self.inner.direct(updated_request, instance_id).await
252
253
254
255
            }
        }
    }
}