scheduler.rs 11.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
5
6
7
8
9
pub use dynamo_kv_router::scheduling::{
    KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse,
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;

10
use super::WorkerSelector;
11
use super::metrics::ROUTER_QUEUE_METRICS;
12
use super::queue::SchedulerQueue;
13
14
15
use super::sequence::{
    ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
16
use crate::discovery::RuntimeConfigWatch;
17
use crate::local_model::runtime_config::ModelRuntimeConfig;
18
use anyhow::Result;
19
20
21
22
use dynamo_kv_router::{
    config::{KvRouterConfig, RouterConfigOverride},
    protocols::{OverlapScores, WorkerId},
};
23
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
24
use dynamo_runtime::traits::DistributedRuntimeProvider;
25
use std::collections::{HashMap, HashSet};
26
27
use std::sync::Arc;
use std::time::Duration;
28
29
#[cfg(feature = "bench")]
use std::time::Instant;
30

31
use dynamo_tokens::SequenceHash;
32
33
34

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
35
    slots: Arc<ActiveSequencesMulti>,
36
    queue: Arc<SchedulerQueue>,
37
38
39
40
}

impl KvScheduler {
    pub async fn start(
41
        component: Component,
42
        block_size: u32,
43
        workers_with_configs: RuntimeConfigWatch,
44
        selector: Option<Box<WorkerSelector>>,
45
        kv_router_config: &KvRouterConfig,
46
        worker_type: &'static str,
47
    ) -> Result<Self, KvSchedulerError> {
48
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::new(None, worker_type)));
49

50
        // Get initial workers from watch receiver.
51
52
53
        // When skip_initial_worker_wait is false, the caller ensures at least one
        // worker is present (via wait_for). When true the map may be empty;
        // workers will be lazily registered via allowed_worker_ids per-request.
54
55
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
56

57
        let router_id = component.drt().discovery().instance_id();
58
59
60
61
62
63
64
65
66
67
        let slots = create_multi_worker_sequences(
            component.clone(),
            block_size as usize,
            initial_workers,
            kv_router_config.router_replica_sync,
            router_id,
            worker_type,
        )
        .await
        .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
68

69
        // Spawn background task to sync slots when the watch value changes.
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
        //
        // In EPP mode (skip_initial_worker_wait=true) we skip the monitoring task:
        // the per-request allowed_worker_ids is the source of truth, workers are
        // lazily registered via register_external_workers() from the C bindings,
        // and update_workers() would impose discovery-based lifecycle (add/remove)
        // on the slot tracker, conflicting with EPP ownership.
        if kv_router_config.skip_initial_worker_wait {
            tracing::info!("skipping discovery-based worker monitoring");
        } else {
            let slots_monitor = slots.clone();
            let mut monitor_rx = workers_with_configs.clone();
            let monitor_cancel_token = component.drt().child_token();
            tokio::spawn(async move {
                tracing::trace!("KvScheduler workers monitoring task started");
                let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();

                loop {
                    tokio::select! {
                        _ = monitor_cancel_token.cancelled() => {
                            tracing::trace!("KvScheduler workers monitoring task shutting down");
90
91
                            break;
                        }
92
93
94
95
96
97
                        result = monitor_rx.changed() => {
                            if result.is_err() {
                                tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                                break;
                            }
                        }
98
                    }
99

100
101
102
103
104
105
106
107
108
109
110
111
                    let current_workers = monitor_rx.borrow_and_update().clone();

                    if current_workers != last_workers {
                        let dp_range: HashMap<u64, (u32, u32)> = current_workers
                            .iter()
                            .map(|(&id, c)| {
                                (id, (c.data_parallel_start_rank, c.data_parallel_size))
                            })
                            .collect();
                        slots_monitor.update_workers(&dp_range);
                        last_workers = current_workers;
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
112
                }
113
114
            });
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
115
116
117
118

        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();

119
120
121
122
123
124
125
        let policy =
            RouterSchedulingPolicy::new(kv_router_config.router_queue_policy, block_size as usize);
        tracing::info!(
            "Router queue policy: {}",
            kv_router_config.router_queue_policy
        );

126
127
128
        let queue = Arc::new(SchedulerQueue::new(
            slots.clone(),
            workers_with_configs.clone(),
129
            kv_router_config.router_queue_threshold,
130
131
            block_size,
            selector,
132
            policy,
133
134
135
        ));
        let queue_clone = queue.clone();

136
        // Background task: receive requests and periodically recheck pending
Yan Ru Pei's avatar
Yan Ru Pei committed
137
138
        tokio::spawn(async move {
            let mut request_rx = request_rx;
139
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
Yan Ru Pei's avatar
Yan Ru Pei committed
140
141
142
            tracing::trace!("scheduler background task started");

            loop {
143
144
145
146
147
148
149
150
151
                tokio::select! {
                    _ = scheduler_cancel_token.cancelled() => {
                        tracing::trace!("scheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        let Some(request) = request else {
                            tracing::warn!("scheduler shutdown");
                            break;
152
                        };
153
154
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
155
                        ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
156
157
158
                    }
                    _ = recheck_interval.tick() => {
                        queue_clone.update().await;
159
                        ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
160
161
                    }
                }
162
163
            }

164
            tracing::trace!("background endpoint subscriber shutting down");
165
166
        });

167
168
169
170
171
        Ok(KvScheduler {
            request_tx,
            slots,
            queue,
        })
172
173
    }

174
    #[expect(clippy::too_many_arguments)]
175
176
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
177
        maybe_request_id: Option<String>,
178
        isl_tokens: usize,
179
        token_seq: Option<Vec<SequenceHash>>,
180
        overlaps: OverlapScores,
181
        router_config_override: Option<&RouterConfigOverride>,
182
        update_states: bool,
183
        lora_name: Option<String>,
184
        priority_jump: f64,
185
        expected_output_tokens: Option<u32>,
186
187
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> Result<SchedulingResponse, KvSchedulerError> {
188
189
190
        #[cfg(feature = "bench")]
        let start = Instant::now();

191
192
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
193
            maybe_request_id,
194
            token_seq,
195
            isl_tokens,
196
            overlaps,
197
198
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
199
            router_config_override: router_config_override.cloned(),
200
            update_states,
201
            lora_name,
202
            priority_jump,
203
            expected_output_tokens,
204
            allowed_worker_ids,
205
            resp_tx: Some(resp_tx),
206
        };
207

208
209
210
211
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
212
213
214
215

        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();

216
        let response = resp_rx
217
            .await
218
            .map_err(|_| KvSchedulerError::SubscriberShutdown)??;
219

220
221
222
223
224
225
226
227
228
229
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );

230
        Ok(response)
231
232
    }

233
234
235
236
237
    /// Register externally-provided workers in the slot tracker.
    ///
    /// Forwards `worker_ids` unchanged to the scheduler queue's
    /// `register_workers`.
    // NOTE(review): how the queue hands these ids to the slot tracker is not
    // visible in this file -- confirm against `SchedulerQueue::register_workers`.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        self.queue.register_workers(worker_ids);
    }

238
239
    /// Hand `req` to the slot tracker.
    ///
    /// # Errors
    ///
    /// Propagates the [`SequenceError`] returned by the tracker.
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        let tracker = &self.slots;
        tracker.add_request(req).await
    }

242
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
243
        self.slots
244
            .mark_prefill_completed(&request_id.to_string())
245
246
            .await?;
        self.queue.update().await;
247
        ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
248
        Ok(())
249
250
    }

251
    /// Free `request_id` in the slot tracker, then update the queue and
    /// publish the new pending count.
    ///
    /// # Errors
    ///
    /// Propagates the [`SequenceError`] returned by the tracker; in that case
    /// the queue is not updated and no metric is published.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        let request_key = request_id.to_owned();
        self.slots.free(&request_key).await?;

        self.queue.update().await;

        let label = self.worker_type();
        let pending = self.queue.pending_count();
        ROUTER_QUEUE_METRICS.set_pending(label, pending);
        Ok(())
    }
257

258
259
260
261
262
    /// Number of requests currently parked in the scheduler queue.
    ///
    /// Returns the value reported by the queue's own `pending_count`.
    pub fn pending_count(&self) -> usize {
        self.queue.pending_count()
    }

263
264
265
266
267
268
    /// Get the worker type for this scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    ///
    /// The value is read from the slot tracker. Within this file it is passed
    /// as the label argument of `ROUTER_QUEUE_METRICS.set_pending`.
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }

269
    pub fn add_output_block(
270
271
272
273
274
275
276
277
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
    }

278
    pub fn get_potential_loads(
279
        &self,
280
        token_seq: Option<Vec<SequenceHash>>,
281
282
283
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
284
285
286
        let (decode_blocks, prefill_tokens) =
            self.slots
                .potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
287

Yan Ru Pei's avatar
Yan Ru Pei committed
288
        // Get all unique WorkerWithDpRank from both hashmaps
289
        let mut workers: HashSet<dynamo_kv_router::protocols::WorkerWithDpRank> = HashSet::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
290
291
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
292
293
294

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
295
        for worker in workers {
296
            loads.push(PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
297
298
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
299
                potential_prefill_tokens: prefill_tokens
Yan Ru Pei's avatar
Yan Ru Pei committed
300
                    .get(&worker)
301
302
                    .copied()
                    .unwrap_or(isl_tokens),
Yan Ru Pei's avatar
Yan Ru Pei committed
303
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
304
305
306
307
308
            });
        }

        loads
    }
309
310
311
312
313

    /// Get active request counts grouped by LORA name
    ///
    /// Returns the map produced by the slot tracker's `get_active_lora_counts`.
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
314
}