scheduler.rs 10.4 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
5
6
7
8
9
pub use dynamo_kv_router::scheduling::{
    KvSchedulerError, PotentialLoad, SchedulingRequest, SchedulingResponse,
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;

10
11
12
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
13
use super::metrics::ROUTER_QUEUE_METRICS;
14
use super::protocols::{OverlapScores, WorkerId};
15
use super::queue::SchedulerQueue;
16
17
18
use super::sequence::{
    ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
19
use crate::discovery::RuntimeConfigWatch;
20
use crate::local_model::runtime_config::ModelRuntimeConfig;
21
use anyhow::Result;
22
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
23
use dynamo_runtime::traits::DistributedRuntimeProvider;
24
use std::collections::{HashMap, HashSet};
25
26
use std::sync::Arc;
use std::time::Duration;
27
28
#[cfg(feature = "bench")]
use std::time::Instant;
29

30
use dynamo_tokens::SequenceHash;
31
32
33

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
34
    slots: Arc<ActiveSequencesMulti>,
35
    queue: Arc<SchedulerQueue>,
36
37
38
39
}

impl KvScheduler {
    pub async fn start(
40
        component: Component,
41
        block_size: u32,
42
        workers_with_configs: RuntimeConfigWatch,
43
        selector: Option<Box<WorkerSelector>>,
44
        kv_router_config: &KvRouterConfig,
45
        worker_type: &'static str,
46
    ) -> Result<Self, KvSchedulerError> {
47
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::new(None, worker_type)));
48

49
50
51
52
        // Get initial workers from watch receiver.
        // Caller must ensure at least one worker is present (via wait_for).
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
53

54
        let router_id = component.drt().discovery().instance_id();
55
56
57
58
59
60
61
62
63
64
        let slots = create_multi_worker_sequences(
            component.clone(),
            block_size as usize,
            initial_workers,
            kv_router_config.router_replica_sync,
            router_id,
            worker_type,
        )
        .await
        .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
65

66
        // Spawn background task to sync slots when the watch value changes.
Yan Ru Pei's avatar
Yan Ru Pei committed
67
        let slots_monitor = slots.clone();
68
        let mut monitor_rx = workers_with_configs.clone();
69
        let monitor_cancel_token = component.drt().child_token();
70
        tokio::spawn(async move {
71
            tracing::trace!("KvScheduler workers monitoring task started");
72
            let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
73

74
            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
75
76
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
77
                        tracing::trace!("KvScheduler workers monitoring task shutting down");
78
79
                        break;
                    }
80
                    result = monitor_rx.changed() => {
81
82
83
84
85
                        if result.is_err() {
                            tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                            break;
                        }
                    }
86
87
                }

88
89
90
                let current_workers = monitor_rx.borrow_and_update().clone();

                if current_workers != last_workers {
91
                    let dp_range: HashMap<u64, (u32, u32)> = current_workers
92
                        .iter()
93
                        .map(|(&id, c)| (id, (c.data_parallel_start_rank, c.data_parallel_size)))
94
                        .collect();
95
                    slots_monitor.update_workers(&dp_range);
96
                    last_workers = current_workers;
Yan Ru Pei's avatar
Yan Ru Pei committed
97
98
99
100
101
102
103
                }
            }
        });

        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();

104
105
106
107
108
109
110
        let policy =
            RouterSchedulingPolicy::new(kv_router_config.router_queue_policy, block_size as usize);
        tracing::info!(
            "Router queue policy: {}",
            kv_router_config.router_queue_policy
        );

111
112
113
        let queue = Arc::new(SchedulerQueue::new(
            slots.clone(),
            workers_with_configs.clone(),
114
            kv_router_config.router_queue_threshold,
115
116
            block_size,
            selector,
117
            policy,
118
119
120
        ));
        let queue_clone = queue.clone();

121
        // Background task: receive requests and periodically recheck pending
Yan Ru Pei's avatar
Yan Ru Pei committed
122
123
        tokio::spawn(async move {
            let mut request_rx = request_rx;
124
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
Yan Ru Pei's avatar
Yan Ru Pei committed
125
126
127
            tracing::trace!("scheduler background task started");

            loop {
128
129
130
131
132
133
134
135
136
                tokio::select! {
                    _ = scheduler_cancel_token.cancelled() => {
                        tracing::trace!("scheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        let Some(request) = request else {
                            tracing::warn!("scheduler shutdown");
                            break;
137
                        };
138
139
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
140
                        ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
141
142
143
                    }
                    _ = recheck_interval.tick() => {
                        queue_clone.update().await;
144
                        ROUTER_QUEUE_METRICS.set_pending(worker_type, queue_clone.pending_count());
145
146
                    }
                }
147
148
            }

149
            tracing::trace!("background endpoint subscriber shutting down");
150
151
        });

152
153
154
155
156
        Ok(KvScheduler {
            request_tx,
            slots,
            queue,
        })
157
158
    }

159
    #[expect(clippy::too_many_arguments)]
160
161
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
162
        maybe_request_id: Option<String>,
163
        isl_tokens: usize,
164
        token_seq: Option<Vec<SequenceHash>>,
165
        overlaps: OverlapScores,
166
        router_config_override: Option<&RouterConfigOverride>,
167
        update_states: bool,
168
        lora_name: Option<String>,
169
        priority_jump: f64,
170
        expected_output_tokens: Option<u32>,
171
172
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> Result<SchedulingResponse, KvSchedulerError> {
173
174
175
        #[cfg(feature = "bench")]
        let start = Instant::now();

176
177
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
178
            maybe_request_id,
179
            token_seq,
180
            isl_tokens,
181
            overlaps,
182
183
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
184
            router_config_override: router_config_override.cloned(),
185
            update_states,
186
            lora_name,
187
            priority_jump,
188
            expected_output_tokens,
189
            allowed_worker_ids,
190
            resp_tx: Some(resp_tx),
191
        };
192

193
194
195
196
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
197
198
199
200

        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();

201
        let response = resp_rx
202
            .await
203
            .map_err(|_| KvSchedulerError::SubscriberShutdown)??;
204

205
206
207
208
209
210
211
212
213
214
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );

215
        Ok(response)
216
217
    }

218
219
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        self.slots.add_request(req).await
220
221
    }

222
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
223
        self.slots
224
            .mark_prefill_completed(&request_id.to_string())
225
226
            .await?;
        self.queue.update().await;
227
        ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
228
        Ok(())
229
230
    }

231
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
232
233
        self.slots.free(&request_id.to_string()).await?;
        self.queue.update().await;
234
        ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.queue.pending_count());
235
        Ok(())
236
    }
237

238
239
240
241
242
    /// Number of requests currently parked in the scheduler queue.
    pub fn pending_count(&self) -> usize {
        self.queue.pending_count()
    }

243
244
245
246
247
248
    /// Get the worker type for this scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }

249
    pub fn add_output_block(
250
251
252
253
254
255
256
257
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
    }

258
    pub fn get_potential_loads(
259
        &self,
260
        token_seq: Option<Vec<SequenceHash>>,
261
262
263
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
264
265
266
        let (decode_blocks, prefill_tokens) =
            self.slots
                .potential_blocks_and_tokens(token_seq.as_deref(), isl_tokens, overlaps);
267

Yan Ru Pei's avatar
Yan Ru Pei committed
268
        // Get all unique WorkerWithDpRank from both hashmaps
269
        let mut workers: HashSet<dynamo_kv_router::protocols::WorkerWithDpRank> = HashSet::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
270
271
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
272
273
274

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
275
        for worker in workers {
276
            loads.push(PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
277
278
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
279
                potential_prefill_tokens: prefill_tokens
Yan Ru Pei's avatar
Yan Ru Pei committed
280
                    .get(&worker)
281
282
                    .copied()
                    .unwrap_or(isl_tokens),
Yan Ru Pei's avatar
Yan Ru Pei committed
283
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
284
285
286
287
288
            });
        }

        loads
    }
289
290
291
292
293

    /// Get active request counts grouped by LORA name
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
294
}