scheduler.rs 9.4 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
5
6
7
8
pub use dynamo_kv_router::scheduling::{
    KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse,
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;

9
10
11
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
12
use super::protocols::{OverlapScores, WorkerId};
13
use super::queue::SchedulerQueue;
14
15
16
use super::sequence::{
    ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
17
use crate::discovery::RuntimeConfigWatch;
18
use crate::local_model::runtime_config::ModelRuntimeConfig;
19
use anyhow::Result;
20
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
21
use dynamo_runtime::traits::DistributedRuntimeProvider;
22
use std::collections::{HashMap, HashSet};
23
24
use std::sync::Arc;
use std::time::Duration;
25
26
#[cfg(feature = "bench")]
use std::time::Instant;
27

28
use dynamo_tokens::SequenceHash;
29
30
31

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
32
    slots: Arc<ActiveSequencesMulti>,
33
    queue: Arc<SchedulerQueue>,
34
35
36
37
}

impl KvScheduler {
    pub async fn start(
38
        component: Component,
39
        block_size: u32,
40
        workers_with_configs: RuntimeConfigWatch,
41
        selector: Option<Box<WorkerSelector>>,
42
        kv_router_config: &KvRouterConfig,
43
        worker_type: &'static str,
44
    ) -> Result<Self, KvSchedulerError> {
45
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
46

47
48
49
50
        // Get initial workers from watch receiver.
        // Caller must ensure at least one worker is present (via wait_for).
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
51

52
        let router_id = component.drt().discovery().instance_id();
53
54
55
56
57
58
59
60
61
62
        let slots = create_multi_worker_sequences(
            component.clone(),
            block_size as usize,
            initial_workers,
            kv_router_config.router_replica_sync,
            router_id,
            worker_type,
        )
        .await
        .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
63

64
        // Spawn background task to sync slots when the watch value changes.
Yan Ru Pei's avatar
Yan Ru Pei committed
65
        let slots_monitor = slots.clone();
66
        let mut monitor_rx = workers_with_configs.clone();
67
        let monitor_cancel_token = component.drt().child_token();
68
        tokio::spawn(async move {
69
            tracing::trace!("KvScheduler workers monitoring task started");
70
            let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
71

72
            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
73
74
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
75
                        tracing::trace!("KvScheduler workers monitoring task shutting down");
76
77
                        break;
                    }
78
                    result = monitor_rx.changed() => {
79
80
81
82
83
                        if result.is_err() {
                            tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                            break;
                        }
                    }
84
85
                }

86
87
88
                let current_workers = monitor_rx.borrow_and_update().clone();

                if current_workers != last_workers {
89
90
91
92
                    let dp_sizes: HashMap<u64, u32> = current_workers
                        .iter()
                        .map(|(&id, c)| (id, c.data_parallel_size))
                        .collect();
93
                    slots_monitor.update_workers(&dp_sizes);
94
                    last_workers = current_workers;
Yan Ru Pei's avatar
Yan Ru Pei committed
95
96
97
98
99
100
101
                }
            }
        });

        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();

102
103
104
        let queue = Arc::new(SchedulerQueue::new(
            slots.clone(),
            workers_with_configs.clone(),
105
            kv_router_config.router_queue_threshold,
106
107
            block_size,
            selector,
108
109
110
        ));
        let queue_clone = queue.clone();

111
        // Background task: receive requests and periodically recheck pending
Yan Ru Pei's avatar
Yan Ru Pei committed
112
113
        tokio::spawn(async move {
            let mut request_rx = request_rx;
114
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
Yan Ru Pei's avatar
Yan Ru Pei committed
115
116
117
            tracing::trace!("scheduler background task started");

            loop {
118
119
120
121
122
123
124
125
126
                tokio::select! {
                    _ = scheduler_cancel_token.cancelled() => {
                        tracing::trace!("scheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        let Some(request) = request else {
                            tracing::warn!("scheduler shutdown");
                            break;
127
                        };
128
129
130
131
132
133
134
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
                    }
                    _ = recheck_interval.tick() => {
                        queue_clone.update().await;
                    }
                }
135
136
            }

137
            tracing::trace!("background endpoint subscriber shutting down");
138
139
        });

140
141
142
143
144
        Ok(KvScheduler {
            request_tx,
            slots,
            queue,
        })
145
146
    }

147
    #[allow(clippy::too_many_arguments)]
148
149
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
150
        maybe_request_id: Option<String>,
151
        isl_tokens: usize,
152
        token_seq: Option<Vec<SequenceHash>>,
153
        overlaps: OverlapScores,
154
        router_config_override: Option<&RouterConfigOverride>,
155
        update_states: bool,
156
        lora_name: Option<String>,
157
        priority_jump: f64,
158
159
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> Result<SchedulingResponse, KvSchedulerError> {
160
161
162
        #[cfg(feature = "bench")]
        let start = Instant::now();

163
164
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
165
            maybe_request_id,
166
            token_seq,
167
            isl_tokens,
168
            overlaps,
169
170
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
171
            router_config_override: router_config_override.cloned(),
172
            update_states,
173
            lora_name,
174
            priority_jump,
175
            allowed_worker_ids,
176
            resp_tx: Some(resp_tx),
177
        };
178

179
180
181
182
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
183
184
185
186

        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();

187
        let response = resp_rx
188
            .await
189
            .map_err(|_| KvSchedulerError::SubscriberShutdown)??;
190

191
192
193
194
195
196
197
198
199
200
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );

201
        Ok(response)
202
203
    }

204
205
    /// Registers `req` with the active-sequence tracker (`slots`).
    ///
    /// # Errors
    ///
    /// Returns whatever [`SequenceError`] the tracker reports.
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        self.slots.add_request(req).await
    }

208
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
209
        self.slots
210
            .mark_prefill_completed(&request_id.to_string())
211
212
213
            .await?;
        self.queue.update().await;
        Ok(())
214
215
    }

216
    /// Frees `request_id` in the sequence tracker, then asks the queue to
    /// re-evaluate via `update()`.
    ///
    /// # Errors
    ///
    /// Returns the tracker's [`SequenceError`]; in that case the queue is not
    /// updated.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        self.slots.free(&request_id.to_string()).await?;
        self.queue.update().await;
        Ok(())
    }
221

222
223
224
225
226
227
    /// Get the worker type for this scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    ///
    /// The value is read from the underlying sequence tracker (`slots`).
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }

228
    pub fn add_output_block(
229
230
231
232
233
234
235
236
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
    }

237
    pub fn get_potential_loads(
238
        &self,
239
        token_seq: Option<Vec<SequenceHash>>,
240
241
242
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
243
244
245
        let (decode_blocks, prefill_tokens) =
            self.slots
                .potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
246

Yan Ru Pei's avatar
Yan Ru Pei committed
247
        // Get all unique WorkerWithDpRank from both hashmaps
248
        let mut workers: HashSet<dynamo_kv_router::protocols::WorkerWithDpRank> = HashSet::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
249
250
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
251
252
253

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
254
        for worker in workers {
255
            loads.push(PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
256
257
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
258
                potential_prefill_tokens: prefill_tokens
Yan Ru Pei's avatar
Yan Ru Pei committed
259
                    .get(&worker)
260
261
                    .copied()
                    .unwrap_or(isl_tokens),
Yan Ru Pei's avatar
Yan Ru Pei committed
262
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
263
264
265
266
267
            });
        }

        loads
    }
268
269
270
271
272

    /// Get active request counts grouped by LoRA name.
    ///
    /// Delegates to the sequence tracker (`slots`); the returned map is keyed
    /// by LoRA name, with the count as the value.
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
273
}