queue.rs 5.89 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};

use tokio::sync::{Mutex, Notify};

use crate::discovery::RuntimeConfigWatch;

use super::protocols::WorkerWithDpRank;
use super::scheduler::SchedulingRequest;
use super::sequence::ActiveSequencesMultiWorker;

/// Fallback for a worker's `max_num_batched_tokens` when its runtime config leaves it unset.
/// The value is deliberately large: the busy check compares active tokens against a fraction
/// of this limit, so such a worker effectively never counts as busy (which effectively
/// disables queueing for that worker).
const DEFAULT_MAX_BATCHED_TOKENS: u64 = 10_000_000;

/// Entry in the pending priority queue, ordered by effective arrival offset
/// (lower = higher priority).
///
/// Effective offset = time elapsed since the queue's reference instant at enqueue time,
/// minus the request's `priority_jump` (interpreted as seconds), saturating at zero.
struct QueueEntry {
    /// Sort key for the heap; the only field the comparison impls look at.
    effective_offset: Duration,
    /// The request being held until a worker has capacity.
    request: SchedulingRequest,
}

// Marker impl required by `Ord`; equality itself is defined by the `PartialEq` impl,
// which compares `effective_offset` only.
impl Eq for QueueEntry {}

impl PartialEq for QueueEntry {
    /// Two entries are equal when their effective offsets match; the wrapped
    /// request takes no part in the comparison.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.effective_offset, &other.effective_offset);
        lhs == rhs
    }
}

impl Ord for QueueEntry {
    /// Inverted comparison on `effective_offset`.
    ///
    /// `BinaryHeap` pops its greatest element first, so the natural ordering is
    /// reversed to make the entry with the smallest offset the "greatest" one.
    fn cmp(&self, other: &Self) -> Ordering {
        self.effective_offset
            .cmp(&other.effective_offset)
            .reverse()
    }
}

impl PartialOrd for QueueEntry {
    /// Always `Some`: delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Queue for managing scheduling requests with interior mutability.
///
/// Requests are held in `pending` when all workers are busy, and moved to `ready` when
/// capacity frees up. If queueing is disabled (`threshold_frac` is `None`), all requests go
/// directly to `ready`.
///
/// Pending requests are ordered by effective arrival offset (`arrival_offset - priority_jump`,
/// smallest first); `ready` is consumed in insertion (FIFO) order.
pub struct SchedulerQueue {
    /// Requests waiting for capacity, popped smallest effective offset first.
    pending: Mutex<BinaryHeap<QueueEntry>>,
    /// Requests cleared for scheduling; pushed at the back, popped from the front.
    ready: Mutex<VecDeque<SchedulingRequest>>,
    /// Source of the per-worker active token counts used by the busy check.
    slots: Arc<ActiveSequencesMultiWorker>,
    /// Per-worker runtime configs (`data_parallel_size`, `max_num_batched_tokens`)
    /// read by the busy check.
    workers_with_configs: RuntimeConfigWatch,
    /// Notified once by `update` when at least one request moved from `pending` to `ready`.
    ready_notify: Arc<Notify>,
    /// Cached threshold fraction; None means queueing is disabled.
    threshold_frac: Option<f64>,
    /// Reference instant for computing arrival offsets.
    start_time: Instant,
}

impl SchedulerQueue {
    /// Create an empty queue.
    ///
    /// * `slots` - source of per-worker active token counts for the busy check.
    /// * `workers_with_configs` - per-worker runtime configs read by the busy check.
    /// * `ready_notify` - notified by [`update`](Self::update) when requests become ready.
    /// * `threshold_frac` - fraction of a worker's `max_num_batched_tokens` above which that
    ///   worker counts as busy; `None` disables queueing entirely.
    ///
    /// The construction instant becomes the reference point for arrival offsets.
    pub fn new(
        slots: Arc<ActiveSequencesMultiWorker>,
        workers_with_configs: RuntimeConfigWatch,
        ready_notify: Arc<Notify>,
        threshold_frac: Option<f64>,
    ) -> Self {
        if let Some(frac) = threshold_frac {
            tracing::info!("Router queue enabled with threshold fraction {frac}");
        }
        Self {
            pending: Mutex::new(BinaryHeap::new()),
            ready: Mutex::new(VecDeque::new()),
            slots,
            workers_with_configs,
            ready_notify,
            threshold_frac,
            start_time: Instant::now(),
        }
    }

    /// Build a `QueueEntry` for a request, computing its effective arrival offset.
    ///
    /// The offset is the time elapsed since `start_time` minus the request's
    /// `priority_jump` (in seconds), saturating at zero. Negative or NaN jumps count as
    /// zero; infinite or out-of-range jumps are clamped to `Duration::MAX`, which yields
    /// an effective offset of zero.
    fn make_entry(&self, request: SchedulingRequest) -> QueueEntry {
        let arrival_offset = self.start_time.elapsed();
        // `max(0.0)` maps negative and NaN jumps to zero, so the fallible conversion can
        // only fail for infinite / too-large values. Clamp those instead of panicking,
        // which is what `Duration::from_secs_f64` would do.
        let jump = Duration::try_from_secs_f64(request.priority_jump.max(0.0))
            .unwrap_or(Duration::MAX);
        QueueEntry {
            effective_offset: arrival_offset.saturating_sub(jump),
            request,
        }
    }

    /// Enqueue a new request.
    ///
    /// With queueing disabled (no threshold) the request goes straight to `ready`.
    /// Otherwise it goes to `pending` when every worker is busy, and to `ready` when at
    /// least one worker has capacity. This method does not signal `ready_notify`.
    pub async fn enqueue(&self, request: SchedulingRequest) {
        let must_wait = match self.threshold_frac {
            Some(threshold) => self.all_workers_busy(threshold).await,
            None => false,
        };

        if must_wait {
            tracing::debug!("all workers busy, queueing request");
            let entry = self.make_entry(request);
            self.pending.lock().await.push(entry);
        } else {
            self.ready.lock().await.push_back(request);
        }
    }

    /// Pop the request at the front of the ready queue, if any.
    ///
    /// `ready` is a FIFO: priority ordering is applied only while requests sit in
    /// `pending`, not here.
    pub async fn try_dequeue(&self) -> Option<SchedulingRequest> {
        self.ready.lock().await.pop_front()
    }

    /// Called on prefill_complete/free. Re-checks pending requests and moves eligible
    /// ones to `ready`, highest priority first.
    ///
    /// Does nothing when queueing is disabled. Notifies the scheduler loop once (via
    /// `ready_notify`) if any request was moved.
    ///
    /// NOTE(review): the busy check only reads `slots` and the worker configs, neither of
    /// which this loop modifies, so once any worker has capacity the loop keeps draining
    /// `pending` until it is empty or another task changes the active token counts.
    pub async fn update(&self) {
        let Some(threshold) = self.threshold_frac else {
            return;
        };

        let mut moved = false;
        loop {
            // Each lock is taken in its own statement so no guard is held across the
            // busy check or while the other queue is locked.
            if self.pending.lock().await.is_empty() {
                break;
            }
            if self.all_workers_busy(threshold).await {
                break;
            }
            let popped = self.pending.lock().await.pop();
            // Another task may have emptied `pending` since the check above.
            let Some(entry) = popped else {
                break;
            };
            tracing::debug!("moving request from pending to ready");
            self.ready.lock().await.push_back(entry.request);
            moved = true;
        }
        if moved {
            self.ready_notify.notify_one();
        }
    }

    /// Check if all workers are busy based on threshold.
    ///
    /// A (worker, dp_rank) pair is busy when its active token count is strictly greater
    /// than `threshold * max_num_batched_tokens` (falling back to
    /// `DEFAULT_MAX_BATCHED_TOKENS` when unset). Pairs with no recorded active tokens count
    /// as zero. Returns `true` only if every pair is busy; note that this is also the
    /// result when there are no workers (or no dp ranks) to inspect.
    async fn all_workers_busy(&self, threshold: f64) -> bool {
        let active_tokens = self.slots.active_tokens().await;
        // Borrowed after the await above; there are no further awaits while it is held.
        let configs = self.workers_with_configs.borrow();

        for (&worker_id, config) in configs.iter() {
            let max_batched = config
                .max_num_batched_tokens
                .unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
            // The limit is identical for every dp rank of this worker.
            let limit = threshold * (max_batched as f64);

            for dp_rank in 0..config.data_parallel_size {
                let worker = WorkerWithDpRank::new(worker_id, dp_rank);
                let tokens = active_tokens.get(&worker).copied().unwrap_or(0);
                if (tokens as f64) <= limit {
                    // Found a rank with spare capacity.
                    return false;
                }
            }
        }
        true
    }
}