scheduler.rs 9.53 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::cmp::min;

use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints;

Ryan Olson's avatar
Ryan Olson committed
24
#[allow(dead_code)]
25
26
27
28
29
30
31
32
33
34
35
36
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    #[error("no endpoints aviailable to route work")]
    NoEndpoints,

    #[error("all workers busy")]
    AllWorkersBusy,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
}

/// Endpoint record whose metrics payload may be absent — used where a
/// worker can be listed before it has reported any `ForwardPassMetrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlexibleEndpoint {
    // Endpoint name as advertised by the service.
    pub name: String,
    // Addressing subject for this endpoint (presumably a message-bus
    // subject; its trailing segment encodes the worker id — see `Endpoint`).
    pub subject: String,
    // Most recent forward-pass metrics, if any have been reported yet.
    pub data: Option<ForwardPassMetrics>,
}

/// Endpoint record with metrics guaranteed present; the form the scheduler
/// actually scores (compare `FlexibleEndpoint`, where `data` is optional).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
    // Endpoint name as advertised by the service.
    pub name: String,
    // Addressing subject; its final `-`-separated segment is the worker id
    // in hex (parsed by `worker_id()` below).
    pub subject: String,
    // Latest forward-pass metrics used for load/capacity scoring.
    pub data: ForwardPassMetrics,
}

impl Endpoint {
GuanLuo's avatar
GuanLuo committed
52
53
    pub fn worker_id(&self) -> i64 {
        i64::from_str_radix(
54
            self.subject
GuanLuo's avatar
GuanLuo committed
55
                .split("-")
56
57
58
59
                .last()
                .expect("invalid subject")
                .to_string()
                .as_str(),
GuanLuo's avatar
GuanLuo committed
60
            16,
61
        )
GuanLuo's avatar
GuanLuo committed
62
        .expect("invalid worker id")
63
64
65
66
67
68
69
70
71
    }
}

/// Discovery record describing a service instance and the endpoints it
/// advertises.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Service {
    // Service name.
    pub name: String,
    // Instance identifier.
    pub id: String,
    // Service version string.
    pub version: String,
    // Start time, serialized as a string (format decided by the publisher —
    // not visible here).
    pub started: String,
    // Endpoints exposed by this instance; metrics may be missing per entry.
    pub endpoints: Vec<FlexibleEndpoint>,
}

/// A single routing request awaiting a worker assignment from the
/// scheduler's background task.
pub struct SchedulingRequest {
    // Input sequence length, in tokens, of the request being placed.
    isl_tokens: usize,
    // Per-worker KV-cache overlap scores from the indexer (higher overlap
    // means fewer new tokens to prefill on that worker).
    overlap: OverlapScores,
    // One-shot channel on which the chosen worker id is returned to the
    // caller of `KvScheduler::schedule`.
    resp_tx: tokio::sync::oneshot::Sender<i64>,
}

impl SchedulingRequest {
GuanLuo's avatar
GuanLuo committed
82
    pub fn respond(self, worker_id: i64) {
83
        if self.resp_tx.send(worker_id).is_err() {
84
            tracing::trace!("failed to send response to requestor");
85
86
87
88
89
90
91
92
93
94
95
96
97
98
        }
    }
}

/// Handle to the scheduler background task; cloning-free entry point for
/// submitting `SchedulingRequest`s via `schedule`.
pub struct KvScheduler {
    // Sender side of the request channel consumed by the background task
    // spawned in `KvScheduler::start`.
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
}

impl KvScheduler {
    /// Spawns the scheduler's background task and returns a handle for
    /// submitting scheduling requests.
    ///
    /// Waits for the first `ProcessedEndpoints` snapshot before returning;
    /// if `endpoints_rx` closes before producing one, returns
    /// `KvSchedulerError::SubscriberShutdown`.
    pub async fn start(
        endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
    ) -> Result<Self, KvSchedulerError> {
        let mut endpoints_rx = endpoints_rx;

        tracing::trace!("awaiting the start of the background endpoint subscriber");
        // No routing decision is possible without at least one endpoint
        // snapshot, so block for the initial one up front.
        let mut endpoints = match endpoints_rx.recv().await {
            Some(endpoints) => endpoints,
            None => {
                return Err(KvSchedulerError::SubscriberShutdown);
            }
        };

        // Channel to accept new scheduling requests
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16);
        tracing::debug!("scheduler starting");

        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request: SchedulingRequest;
            let mut request_rx = request_rx;
            tracing::debug!("scheduler background task started");

            // Outer loop: wait for either a new request or a fresh endpoint
            // snapshot. `biased` makes the select check arms in order, so
            // pending requests are served before endpoint updates.
            'outer: loop {
                request = tokio::select! {
                    biased;

                    new_request = request_rx.recv() => {
                        match new_request {
                            Some(new_request) => {
                                tracing::trace!("received request to be scheduled");
                                new_request
                            },
                            None => {
                                // All request senders dropped: scheduler is done.
                                tracing::trace!("scheduler shutdown");
                                break 'outer;
                            }
                        }
                    }

                    new_endpoints = endpoints_rx.recv() => {
                        match new_endpoints {
                            Some(new_endpoints) => {
                                // Replace the working snapshot and go back to
                                // waiting; no request to serve yet.
                                tracing::trace!("updated endpoints");
                                endpoints = new_endpoints;
                                continue 'outer;
                            }
                            None => {
                                tracing::trace!("endpoint subscriber shutdown");
                                break 'outer;
                            }
                        }
                    }
                };
                tracing::debug!("selected");
                // Inner loop: keep trying to place the current request. When
                // all workers are busy, block for the next endpoint snapshot
                // (which may free capacity) and retry.
                loop {
                    match select_worker(endpoints.borrow_mut(), &request) {
                        Ok(worker_id) => {
                            request.respond(worker_id);
                            continue 'outer;
                        }
                        Err(KvSchedulerError::AllWorkersBusy) => {
                            tracing::trace!("all workers busy; waiting for more capacity");
                            endpoints = match endpoints_rx.recv().await {
                                Some(endpoints) => endpoints,
                                None => {
                                    tracing::trace!("endpoint subscriber shutdown");
                                    break 'outer;
                                }
                            };
                        }
                        Err(e) => {
                            // Any other error (e.g. NoEndpoints) is fatal for
                            // the background task; the in-flight request's
                            // oneshot is dropped, which the caller observes
                            // as SubscriberShutdown.
                            tracing::error!("error scheduling request: {:?}", e);
                            break 'outer;
                        }
                    }
                }
            }

            tracing::trace!("background endpoint subscriber shutting down");
        });

        Ok(KvScheduler { request_tx })
    }

    /// Requests a worker assignment for a request with `isl_tokens` input
    /// tokens and the given per-worker overlap scores.
    ///
    /// # Errors
    /// Returns `SubscriberShutdown` if the background task has exited
    /// (request channel closed, or the response oneshot was dropped).
    #[allow(dead_code)]
    pub async fn schedule(
        &self,
        overlap: OverlapScores,
        isl_tokens: usize,
    ) -> Result<i64, KvSchedulerError> {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            isl_tokens,
            overlap,
            resp_tx,
        };
        tracing::debug!("before sending request");
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        tracing::debug!("after sending request");

        // Await the background task's decision on the oneshot channel.
        let res = resp_rx
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        tracing::debug!("after receiving response");
        Ok(res)
    }
}

pub fn select_worker(
    workers: &mut ProcessedEndpoints,
    request: &SchedulingRequest,
GuanLuo's avatar
GuanLuo committed
208
) -> Result<i64, KvSchedulerError> {
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
    // balance mode prioritizes balancing load across workers
    let balance_threshold: f64 = 0.1;
    let balance_mode = workers.load_std > balance_threshold * workers.load_avg;

    // Determine alpha based on mode
    let alpha = if balance_mode { 0.7 } else { 0.3 };
    let gamma = 0.1; // example tuning param

    // Compute each worker's score
    let mut best_index = None;
    let mut best_cost = f64::INFINITY;

    if workers.endpoints.is_empty() {
        return Err(KvSchedulerError::NoEndpoints);
    }

    for (i, w) in workers.endpoints.iter().enumerate() {
        // Exclude workers that are at capacity
        if w.data.request_active_slots >= w.data.request_total_slots
            || w.data.kv_active_blocks >= w.data.kv_total_blocks
        {
            continue;
        }

        let kv_load_ratio = w.data.kv_active_blocks as f64 / w.data.kv_total_blocks as f64;
        let load_deviation = kv_load_ratio - workers.load_avg;

GuanLuo's avatar
GuanLuo committed
236
        // [FIXME] multiple endpoints of the same worker cause out of bound error
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
        let worker_id = workers.worker_ids[i];
        let overlap_score = request.overlap.scores.get(&worker_id).map_or(0, |x| *x);
        let overlap_score = overlap_score as usize * KV_BLOCK_SIZE;

        let new_tokens = request.isl_tokens.saturating_sub(overlap_score);
        let normalized_new_tokens = new_tokens as f64 / request.isl_tokens as f64;

        let request_load_ratio =
            w.data.request_active_slots as f64 / w.data.request_total_slots as f64;

        // cost = alpha * load_deviation + (1 - alpha)*normalized_new_tokens + gamma * request_load_ratio
        let cost = alpha * load_deviation
            + (1.0 - alpha) * normalized_new_tokens
            + gamma * request_load_ratio;

252
        tracing::debug!("worker: {}; load_deviation: {}; normalized new blocks: {}; request_load_ratio: {} cost: {}",
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
                worker_id,
                load_deviation,
                normalized_new_tokens,
                request_load_ratio,
                cost
            );

        if cost < best_cost {
            best_cost = cost;
            best_index = Some(i);
        }
    }

    if let Some(best_index) = best_index {
        let total_blocks = min(request.isl_tokens / KV_BLOCK_SIZE, 1);

        workers.endpoints[best_index].data.request_active_slots += 1;
        workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
    }

    match best_index {
        Some(i) => {
275
            tracing::info!(
276
                "selected worker: {}; cost: {}",
GuanLuo's avatar
GuanLuo committed
277
                workers.endpoints[i].worker_id(),
278
279
                best_cost
            );
GuanLuo's avatar
GuanLuo committed
280
            Ok(workers.endpoints[i].worker_id())
281
282
        }
        None => {
283
            tracing::debug!("all workers busy");
284
285
286
287
            Err(KvSchedulerError::AllWorkersBusy)
        }
    }
}