scheduler.rs 10.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Neelay Shah's avatar
Neelay Shah committed
16
17
use dynamo_runtime::component::Namespace;
use dynamo_runtime::traits::events::EventPublisher;
18
19
20
21
22
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::cmp::min;

use crate::kv_router::indexer::OverlapScores;
23
pub use crate::kv_router::protocols::ForwardPassMetrics;
24
use crate::kv_router::scoring::ProcessedEndpoints;
25
use crate::kv_router::KV_HIT_RATE_SUBJECT;
26

27
28
29
30
31
32
33
/// Event payload published on [`KV_HIT_RATE_SUBJECT`] each time a worker is
/// selected, reporting how much of the request's input was already cached there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    // Hex-parsed id of the worker the request was routed to.
    pub worker_id: i64,
    // Input sequence length of the request, measured in KV blocks.
    pub isl_blocks: usize,
    // Number of those blocks already resident on the selected worker.
    pub overlap_blocks: usize,
}

34
35
36
37
38
39
40
41
42
43
44
45
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    #[error("no endpoints aviailable to route work")]
    NoEndpoints,

    #[error("all workers busy")]
    AllWorkersBusy,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
}

46
47
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
    // Human-readable endpoint name.
    pub name: String,
    // Subject string; the trailing "-<hex>" segment encodes the worker id
    // (see `Endpoint::worker_id`).
    pub subject: String,
    // Latest forward-pass metrics snapshot reported by this worker.
    pub data: ForwardPassMetrics,
}

impl Endpoint {
GuanLuo's avatar
GuanLuo committed
56
57
    pub fn worker_id(&self) -> i64 {
        i64::from_str_radix(
58
            self.subject
GuanLuo's avatar
GuanLuo committed
59
                .split("-")
60
61
62
63
                .last()
                .expect("invalid subject")
                .to_string()
                .as_str(),
GuanLuo's avatar
GuanLuo committed
64
            16,
65
        )
GuanLuo's avatar
GuanLuo committed
66
        .expect("invalid worker id")
67
68
69
70
71
72
    }
}

/// A single routing request waiting for a worker assignment.
pub struct SchedulingRequest {
    // Input sequence length of the request, in tokens.
    isl_tokens: usize,
    // Per-worker prefix-cache overlap scores for this request's input.
    overlap: OverlapScores,
    // One-shot channel used to hand the chosen worker id back to the caller.
    resp_tx: tokio::sync::oneshot::Sender<i64>,
}

impl SchedulingRequest {
GuanLuo's avatar
GuanLuo committed
77
    pub fn respond(self, worker_id: i64) {
78
        if self.resp_tx.send(worker_id).is_err() {
79
            tracing::trace!("failed to send response to requestor");
80
81
82
83
84
85
86
87
88
89
90
        }
    }
}

/// Handle to the background KV-aware scheduling task started by [`KvScheduler::start`].
pub struct KvScheduler {
    // Feeds scheduling requests into the background task; bounded (capacity 16).
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
}

impl KvScheduler {
    /// Spawns the scheduler's background tasks and returns a handle for
    /// submitting requests.
    ///
    /// Blocks until the first `ProcessedEndpoints` snapshot arrives on
    /// `endpoints_rx`; returns `SubscriberShutdown` if that channel closes
    /// before producing one. Two tasks are spawned: one publishing
    /// [`KVHitRateEvent`]s on the namespace, and one running the scheduling
    /// loop that matches requests against the latest endpoint snapshot.
    pub async fn start(
        endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
        ns: Namespace,
        kv_block_size: usize,
    ) -> Result<Self, KvSchedulerError> {
        let mut endpoints_rx = endpoints_rx;

        tracing::trace!("awaiting the start of the background endpoint subscriber");
        // Wait for the initial endpoint snapshot so the scheduling loop never
        // runs without worker data.
        let mut endpoints = match endpoints_rx.recv().await {
            Some(endpoints) => endpoints,
            None => {
                return Err(KvSchedulerError::SubscriberShutdown);
            }
        };

        // Channel to asynchronously publish metric events on
        let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();

        // Publisher task: drains hit-rate events and publishes them on the
        // namespace so event emission never blocks the scheduling loop.
        tokio::spawn(async move {
            let mut event_rx = event_rx;
            while let Some(event) = event_rx.recv().await {
                if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await {
                    tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
                }
            }
        });

        // Channel to accept new scheduling requests
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16);
        tracing::debug!("scheduler starting");

        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request: SchedulingRequest;
            let mut request_rx = request_rx;
            tracing::debug!("scheduler background task started");

            'outer: loop {
                // `biased` gives request servicing priority over endpoint
                // refreshes when both channels are ready.
                request = tokio::select! {
                    biased;

                    new_request = request_rx.recv() => {
                        match new_request {
                            Some(new_request) => {
                                tracing::trace!("received request to be scheduled");
                                new_request
                            },
                            None => {
                                // All KvScheduler handles dropped.
                                tracing::trace!("scheduler shutdown");
                                break 'outer;
                            }
                        }
                    }

                    new_endpoints = endpoints_rx.recv() => {
                        match new_endpoints {
                            Some(new_endpoints) => {
                                tracing::trace!("updated endpoints");
                                endpoints = new_endpoints;
                                continue 'outer;
                            }
                            None => {
                                tracing::trace!("endpoint subscriber shutdown");
                                break 'outer;
                            }
                        }
                    }
                };
                tracing::debug!("selected");
                // Retry loop: if all workers are busy, park until a fresh
                // endpoint snapshot arrives, then try the same request again.
                loop {
                    match select_worker(endpoints.borrow_mut(), &request, &event_tx, kv_block_size)
                    {
                        Ok(worker_id) => {
                            request.respond(worker_id);
                            continue 'outer;
                        }
                        Err(KvSchedulerError::AllWorkersBusy) => {
                            tracing::trace!("all workers busy; waiting for more capacity");
                            endpoints = match endpoints_rx.recv().await {
                                Some(endpoints) => endpoints,
                                None => {
                                    tracing::trace!("endpoint subscriber shutdown");
                                    break 'outer;
                                }
                            };
                        }
                        Err(e) => {
                            // Unrecoverable (e.g. NoEndpoints); the pending
                            // request's caller sees a dropped oneshot.
                            tracing::error!("error scheduling request: {:?}", e);
                            break 'outer;
                        }
                    }
                }
            }

            tracing::trace!("background endpoint subscriber shutting down");
        });

        Ok(KvScheduler { request_tx })
    }

    /// Submits a request to the scheduling loop and awaits the chosen worker id.
    ///
    /// # Errors
    /// Returns `SubscriberShutdown` if the scheduler task has exited (either
    /// the request channel or the response oneshot is closed).
    pub async fn schedule(
        &self,
        overlap: OverlapScores,
        isl_tokens: usize,
    ) -> Result<i64, KvSchedulerError> {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            isl_tokens,
            overlap,
            resp_tx,
        };
        tracing::debug!("before sending request");
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        tracing::debug!("after sending request");

        let res = resp_rx
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        tracing::debug!("after receiving response");
        Ok(res)
    }
}

pub fn select_worker(
    workers: &mut ProcessedEndpoints,
    request: &SchedulingRequest,
218
    event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
219
    kv_block_size: usize,
GuanLuo's avatar
GuanLuo committed
220
) -> Result<i64, KvSchedulerError> {
221
222
223
224
225
226
227
228
229
230
231
    // balance mode prioritizes balancing load across workers
    let balance_threshold: f64 = 0.1;
    let balance_mode = workers.load_std > balance_threshold * workers.load_avg;

    // Determine alpha based on mode
    let alpha = if balance_mode { 0.7 } else { 0.3 };
    let gamma = 0.1; // example tuning param

    // Compute each worker's score
    let mut best_index = None;
    let mut best_cost = f64::INFINITY;
232
    // [FIXME] REMOVE ONLY FOR TESTING
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
    if workers.endpoints.is_empty() {
        return Err(KvSchedulerError::NoEndpoints);
    }

    for (i, w) in workers.endpoints.iter().enumerate() {
        // Exclude workers that are at capacity
        if w.data.request_active_slots >= w.data.request_total_slots
            || w.data.kv_active_blocks >= w.data.kv_total_blocks
        {
            continue;
        }

        let kv_load_ratio = w.data.kv_active_blocks as f64 / w.data.kv_total_blocks as f64;
        let load_deviation = kv_load_ratio - workers.load_avg;

GuanLuo's avatar
GuanLuo committed
248
        // [FIXME] multiple endpoints of the same worker cause out of bound error
249
250
        let worker_id = workers.worker_ids[i];
        let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x);
251
        let overlap_score = overlap_score as usize * kv_block_size;
252
253
254
255
256
257
258
259
260
261
262
263

        let new_tokens = request.isl_tokens.saturating_sub(overlap_score);
        let normalized_new_tokens = new_tokens as f64 / request.isl_tokens as f64;

        let request_load_ratio =
            w.data.request_active_slots as f64 / w.data.request_total_slots as f64;

        // cost = alpha * load_deviation + (1 - alpha)*normalized_new_tokens + gamma * request_load_ratio
        let cost = alpha * load_deviation
            + (1.0 - alpha) * normalized_new_tokens
            + gamma * request_load_ratio;

264
        tracing::debug!("worker: {}; load_deviation: {}; normalized new blocks: {}; request_load_ratio: {} cost: {}",
265
266
267
268
269
270
271
272
273
274
275
276
277
278
                worker_id,
                load_deviation,
                normalized_new_tokens,
                request_load_ratio,
                cost
            );

        if cost < best_cost {
            best_cost = cost;
            best_index = Some(i);
        }
    }

    if let Some(best_index) = best_index {
279
        let total_blocks = min(request.isl_tokens / kv_block_size, 1);
280
281
282

        workers.endpoints[best_index].data.request_active_slots += 1;
        workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
283
284
285

        // Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler
        let best_worker_id = workers.endpoints[best_index].worker_id();
286
        let isl_blocks = request.isl_tokens / kv_block_size;
287
288
289
290
291
292
293
294
295
296
297
298
299
        let overlap_blocks = request
            .overlap
            .scores
            .get(&best_worker_id)
            .copied()
            .unwrap_or(0);
        if let Err(e) = event_tx.send(KVHitRateEvent {
            worker_id: best_worker_id,
            isl_blocks,
            overlap_blocks: overlap_blocks as usize,
        }) {
            tracing::warn!("Failed to send KV hit rate event: {:?}", e);
        }
300
301
302
303
    }

    match best_index {
        Some(i) => {
304
            tracing::info!(
305
                "selected worker: {}; cost: {}",
GuanLuo's avatar
GuanLuo committed
306
                workers.endpoints[i].worker_id(),
307
308
                best_cost
            );
GuanLuo's avatar
GuanLuo committed
309
            Ok(workers.endpoints[i].worker_id())
310
311
        }
        None => {
312
            tracing::debug!("all workers busy");
313
314
315
316
            Err(KvSchedulerError::AllWorkersBusy)
        }
    }
}