scheduler.rs 10.1 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
5
6
7
8
pub use dynamo_kv_router::scheduling::{
    KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse,
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;

9
10
11
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
12
use super::metrics::ROUTER_QUEUE_METRICS;
13
use super::protocols::{OverlapScores, WorkerId};
14
use super::queue::SchedulerQueue;
15
16
17
use super::sequence::{
    ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
18
use crate::discovery::RuntimeConfigWatch;
19
use crate::local_model::runtime_config::ModelRuntimeConfig;
20
use anyhow::Result;
21
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
22
use dynamo_runtime::traits::DistributedRuntimeProvider;
23
use std::collections::{HashMap, HashSet};
24
25
use std::sync::Arc;
use std::time::Duration;
26
27
#[cfg(feature = "bench")]
use std::time::Instant;
28

29
use dynamo_tokens::SequenceHash;
30
31
32

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
33
    slots: Arc<ActiveSequencesMulti>,
34
    queue: Arc<SchedulerQueue>,
35
36
37
38
}

impl KvScheduler {
    pub async fn start(
39
        component: Component,
40
        block_size: u32,
41
        workers_with_configs: RuntimeConfigWatch,
42
        selector: Option<Box<WorkerSelector>>,
43
        kv_router_config: &KvRouterConfig,
44
        worker_type: &'static str,
45
    ) -> Result<Self, KvSchedulerError> {
46
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::new(None, worker_type)));
47

48
49
50
51
        // Get initial workers from watch receiver.
        // Caller must ensure at least one worker is present (via wait_for).
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
52

53
        let router_id = component.drt().discovery().instance_id();
54
55
56
57
58
59
60
61
62
63
        let slots = create_multi_worker_sequences(
            component.clone(),
            block_size as usize,
            initial_workers,
            kv_router_config.router_replica_sync,
            router_id,
            worker_type,
        )
        .await
        .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
64

65
        // Spawn background task to sync slots when the watch value changes.
Yan Ru Pei's avatar
Yan Ru Pei committed
66
        let slots_monitor = slots.clone();
67
        let mut monitor_rx = workers_with_configs.clone();
68
        let monitor_cancel_token = component.drt().child_token();
69
        tokio::spawn(async move {
70
            tracing::trace!("KvScheduler workers monitoring task started");
71
            let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
72

73
            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
74
75
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
76
                        tracing::trace!("KvScheduler workers monitoring task shutting down");
77
78
                        break;
                    }
79
                    result = monitor_rx.changed() => {
80
81
82
83
84
                        if result.is_err() {
                            tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                            break;
                        }
                    }
85
86
                }

87
88
89
                let current_workers = monitor_rx.borrow_and_update().clone();

                if current_workers != last_workers {
90
                    let dp_range: HashMap<u64, (u32, u32)> = current_workers
91
                        .iter()
92
                        .map(|(&id, c)| (id, (c.data_parallel_start_rank, c.data_parallel_size)))
93
                        .collect();
94
                    slots_monitor.update_workers(&dp_range);
95
                    last_workers = current_workers;
Yan Ru Pei's avatar
Yan Ru Pei committed
96
97
98
99
100
101
102
                }
            }
        });

        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();

103
104
105
        let queue = Arc::new(SchedulerQueue::new(
            slots.clone(),
            workers_with_configs.clone(),
106
            kv_router_config.router_queue_threshold,
107
108
            block_size,
            selector,
109
110
111
        ));
        let queue_clone = queue.clone();

112
        // Background task: receive requests and periodically recheck pending
Yan Ru Pei's avatar
Yan Ru Pei committed
113
114
        tokio::spawn(async move {
            let mut request_rx = request_rx;
115
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
Yan Ru Pei's avatar
Yan Ru Pei committed
116
117
118
            tracing::trace!("scheduler background task started");

            loop {
119
120
121
122
123
124
125
126
127
                tokio::select! {
                    _ = scheduler_cancel_token.cancelled() => {
                        tracing::trace!("scheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        let Some(request) = request else {
                            tracing::warn!("scheduler shutdown");
                            break;
128
                        };
129
130
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
131
                        ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
132
133
134
                    }
                    _ = recheck_interval.tick() => {
                        queue_clone.update().await;
135
                        ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
136
137
                    }
                }
138
139
            }

140
            tracing::trace!("background endpoint subscriber shutting down");
141
142
        });

143
144
145
146
147
        Ok(KvScheduler {
            request_tx,
            slots,
            queue,
        })
148
149
    }

150
    #[expect(clippy::too_many_arguments)]
151
152
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
153
        maybe_request_id: Option<String>,
154
        isl_tokens: usize,
155
        token_seq: Option<Vec<SequenceHash>>,
156
        overlaps: OverlapScores,
157
        router_config_override: Option<&RouterConfigOverride>,
158
        update_states: bool,
159
        lora_name: Option<String>,
160
        priority_jump: f64,
161
        expected_output_tokens: Option<u32>,
162
163
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> Result<SchedulingResponse, KvSchedulerError> {
164
165
166
        #[cfg(feature = "bench")]
        let start = Instant::now();

167
168
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
169
            maybe_request_id,
170
            token_seq,
171
            isl_tokens,
172
            overlaps,
173
174
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
175
            router_config_override: router_config_override.cloned(),
176
            update_states,
177
            lora_name,
178
            priority_jump,
179
            expected_output_tokens,
180
            allowed_worker_ids,
181
            resp_tx: Some(resp_tx),
182
        };
183

184
185
186
187
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
188
189
190
191

        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();

192
        let response = resp_rx
193
            .await
194
            .map_err(|_| KvSchedulerError::SubscriberShutdown)??;
195

196
197
198
199
200
201
202
203
204
205
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );

206
        Ok(response)
207
208
    }

209
210
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        self.slots.add_request(req).await
211
212
    }

213
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
214
        self.slots
215
            .mark_prefill_completed(&request_id.to_string())
216
217
            .await?;
        self.queue.update().await;
218
        ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
219
        Ok(())
220
221
    }

222
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
223
224
        self.slots.free(&request_id.to_string()).await?;
        self.queue.update().await;
225
        ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
226
        Ok(())
227
    }
228

229
230
231
232
233
    /// Number of requests currently parked in the scheduler queue.
    pub fn pending_count(&self) -> usize {
        self.queue.pending_count()
    }

234
235
236
237
238
239
    /// Get the worker type for this scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }

240
    pub fn add_output_block(
241
242
243
244
245
246
247
248
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
    }

249
    pub fn get_potential_loads(
250
        &self,
251
        token_seq: Option<Vec<SequenceHash>>,
252
253
254
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
255
256
257
        let (decode_blocks, prefill_tokens) =
            self.slots
                .potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
258

Yan Ru Pei's avatar
Yan Ru Pei committed
259
        // Get all unique WorkerWithDpRank from both hashmaps
260
        let mut workers: HashSet<dynamo_kv_router::protocols::WorkerWithDpRank> = HashSet::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
261
262
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
263
264
265

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
266
        for worker in workers {
267
            loads.push(PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
268
269
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
270
                potential_prefill_tokens: prefill_tokens
Yan Ru Pei's avatar
Yan Ru Pei committed
271
                    .get(&worker)
272
273
                    .copied()
                    .unwrap_or(isl_tokens),
Yan Ru Pei's avatar
Yan Ru Pei committed
274
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
275
276
277
278
279
            });
        }

        loads
    }
280
281
282
283
284

    /// Get active request counts grouped by LORA name
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
285
}