scheduler.rs 11.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
5
6
7
8
9
pub use dynamo_kv_router::scheduling::{
    KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse,
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;

10
11
12
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
13
use super::metrics::ROUTER_QUEUE_METRICS;
14
use super::protocols::{OverlapScores, WorkerId};
15
use super::queue::SchedulerQueue;
16
17
18
use super::sequence::{
    ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
19
use crate::discovery::RuntimeConfigWatch;
20
use crate::local_model::runtime_config::ModelRuntimeConfig;
21
use anyhow::Result;
22
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
23
use dynamo_runtime::traits::DistributedRuntimeProvider;
24
use std::collections::{HashMap, HashSet};
25
26
use std::sync::Arc;
use std::time::Duration;
27
28
#[cfg(feature = "bench")]
use std::time::Instant;
29

30
use dynamo_tokens::SequenceHash;
31
32
33

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
34
    slots: Arc<ActiveSequencesMulti>,
35
    queue: Arc<SchedulerQueue>,
36
37
38
39
}

impl KvScheduler {
    pub async fn start(
40
        component: Component,
41
        block_size: u32,
42
        workers_with_configs: RuntimeConfigWatch,
43
        selector: Option<Box<WorkerSelector>>,
44
        kv_router_config: &KvRouterConfig,
45
        worker_type: &'static str,
46
    ) -> Result<Self, KvSchedulerError> {
47
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::new(None, worker_type)));
48

49
        // Get initial workers from watch receiver.
50
51
52
        // When skip_initial_worker_wait is false, the caller ensures at least one
        // worker is present (via wait_for). When true the map may be empty;
        // workers will be lazily registered via allowed_worker_ids per-request.
53
54
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
55

56
        let router_id = component.drt().discovery().instance_id();
57
58
59
60
61
62
63
64
65
66
        let slots = create_multi_worker_sequences(
            component.clone(),
            block_size as usize,
            initial_workers,
            kv_router_config.router_replica_sync,
            router_id,
            worker_type,
        )
        .await
        .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
67

68
        // Spawn background task to sync slots when the watch value changes.
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
        //
        // In EPP mode (skip_initial_worker_wait=true) we skip the monitoring task:
        // the per-request allowed_worker_ids is the source of truth, workers are
        // lazily registered via register_external_workers() from the C bindings,
        // and update_workers() would impose discovery-based lifecycle (add/remove)
        // on the slot tracker, conflicting with EPP ownership.
        if kv_router_config.skip_initial_worker_wait {
            tracing::info!("skipping discovery-based worker monitoring");
        } else {
            let slots_monitor = slots.clone();
            let mut monitor_rx = workers_with_configs.clone();
            let monitor_cancel_token = component.drt().child_token();
            tokio::spawn(async move {
                tracing::trace!("KvScheduler workers monitoring task started");
                let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();

                loop {
                    tokio::select! {
                        _ = monitor_cancel_token.cancelled() => {
                            tracing::trace!("KvScheduler workers monitoring task shutting down");
89
90
                            break;
                        }
91
92
93
94
95
96
                        result = monitor_rx.changed() => {
                            if result.is_err() {
                                tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                                break;
                            }
                        }
97
                    }
98

99
100
101
102
103
104
105
106
107
108
109
110
                    let current_workers = monitor_rx.borrow_and_update().clone();

                    if current_workers != last_workers {
                        let dp_range: HashMap<u64, (u32, u32)> = current_workers
                            .iter()
                            .map(|(&id, c)| {
                                (id, (c.data_parallel_start_rank, c.data_parallel_size))
                            })
                            .collect();
                        slots_monitor.update_workers(&dp_range);
                        last_workers = current_workers;
                    }
Yan Ru Pei's avatar
Yan Ru Pei committed
111
                }
112
113
            });
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
114
115
116
117

        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();

118
119
120
121
122
123
124
        let policy =
            RouterSchedulingPolicy::new(kv_router_config.router_queue_policy, block_size as usize);
        tracing::info!(
            "Router queue policy: {}",
            kv_router_config.router_queue_policy
        );

125
126
127
        let queue = Arc::new(SchedulerQueue::new(
            slots.clone(),
            workers_with_configs.clone(),
128
            kv_router_config.router_queue_threshold,
129
130
            block_size,
            selector,
131
            policy,
132
133
134
        ));
        let queue_clone = queue.clone();

135
        // Background task: receive requests and periodically recheck pending
Yan Ru Pei's avatar
Yan Ru Pei committed
136
137
        tokio::spawn(async move {
            let mut request_rx = request_rx;
138
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
Yan Ru Pei's avatar
Yan Ru Pei committed
139
140
141
            tracing::trace!("scheduler background task started");

            loop {
142
143
144
145
146
147
148
149
150
                tokio::select! {
                    _ = scheduler_cancel_token.cancelled() => {
                        tracing::trace!("scheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        let Some(request) = request else {
                            tracing::warn!("scheduler shutdown");
                            break;
151
                        };
152
153
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
154
                        ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
155
156
157
                    }
                    _ = recheck_interval.tick() => {
                        queue_clone.update().await;
158
                        ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
159
160
                    }
                }
161
162
            }

163
            tracing::trace!("background endpoint subscriber shutting down");
164
165
        });

166
167
168
169
170
        Ok(KvScheduler {
            request_tx,
            slots,
            queue,
        })
171
172
    }

173
    #[expect(clippy::too_many_arguments)]
174
175
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
176
        maybe_request_id: Option<String>,
177
        isl_tokens: usize,
178
        token_seq: Option<Vec<SequenceHash>>,
179
        overlaps: OverlapScores,
180
        router_config_override: Option<&RouterConfigOverride>,
181
        update_states: bool,
182
        lora_name: Option<String>,
183
        priority_jump: f64,
184
        expected_output_tokens: Option<u32>,
185
186
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> Result<SchedulingResponse, KvSchedulerError> {
187
188
189
        #[cfg(feature = "bench")]
        let start = Instant::now();

190
191
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
192
            maybe_request_id,
193
            token_seq,
194
            isl_tokens,
195
            overlaps,
196
197
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
198
            router_config_override: router_config_override.cloned(),
199
            update_states,
200
            lora_name,
201
            priority_jump,
202
            expected_output_tokens,
203
            allowed_worker_ids,
204
            resp_tx: Some(resp_tx),
205
        };
206

207
208
209
210
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
211
212
213
214

        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();

215
        let response = resp_rx
216
            .await
217
            .map_err(|_| KvSchedulerError::SubscriberShutdown)??;
218

219
220
221
222
223
224
225
226
227
228
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );

229
        Ok(response)
230
231
    }

232
233
234
235
236
    /// Register externally-provided workers in the slot tracker.
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
        self.queue.register_workers(worker_ids);
    }

237
238
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        self.slots.add_request(req).await
239
240
    }

241
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
242
        self.slots
243
            .mark_prefill_completed(&request_id.to_string())
244
245
            .await?;
        self.queue.update().await;
246
        ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
247
        Ok(())
248
249
    }

250
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
251
252
        self.slots.free(&request_id.to_string()).await?;
        self.queue.update().await;
253
        ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
254
        Ok(())
255
    }
256

257
258
259
260
261
    /// Number of requests currently parked in the scheduler queue.
    pub fn pending_count(&self) -> usize {
        self.queue.pending_count()
    }

262
263
264
265
266
267
    /// Get the worker type for this scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }

268
    pub fn add_output_block(
269
270
271
272
273
274
275
276
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
    }

277
    pub fn get_potential_loads(
278
        &self,
279
        token_seq: Option<Vec<SequenceHash>>,
280
281
282
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
283
284
285
        let (decode_blocks, prefill_tokens) =
            self.slots
                .potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
286

Yan Ru Pei's avatar
Yan Ru Pei committed
287
        // Get all unique WorkerWithDpRank from both hashmaps
288
        let mut workers: HashSet<dynamo_kv_router::protocols::WorkerWithDpRank> = HashSet::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
289
290
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
291
292
293

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
294
        for worker in workers {
295
            loads.push(PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
296
297
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
298
                potential_prefill_tokens: prefill_tokens
Yan Ru Pei's avatar
Yan Ru Pei committed
299
                    .get(&worker)
300
301
                    .copied()
                    .unwrap_or(isl_tokens),
Yan Ru Pei's avatar
Yan Ru Pei committed
302
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
303
304
305
306
307
            });
        }

        loads
    }
308
309
310
311
312

    /// Get active request counts grouped by LORA name
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
313
}