policy.rs 13 KB
Newer Older
1
2
3
4
5
6
7
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::time::Duration;

use super::config::RouterQueuePolicy;
use super::types::SchedulingRequest;
8
use ordered_float::OrderedFloat;
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/// Strategy interface that decides the order in which queued requests are served.
///
/// Implementations are intended to be used as a generic parameter so the key
/// comparison can be inlined on the hot path.
///
/// Keys follow natural max-heap ordering: the entry with the greatest key is
/// the one with the highest priority.
pub trait SchedulingPolicy: Send + Sync + 'static {
    /// Type of the priority key attached to every queue entry.
    type Key: Ord + Eq + Clone + Send + 'static;

    /// Whether keys change over time.
    ///
    /// When `true`, the queue rebuilds its heap through [`Self::rekey`] on each
    /// `update()` call. When `false` (the default), the rekey path is compiled
    /// out entirely.
    const DYNAMIC: bool = false;

    /// Produces the priority key for `request` at the moment it is enqueued.
    fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key;

    /// Produces a refreshed priority key during `update()`.
    ///
    /// The default implementation ignores the clock and the request and hands
    /// back a copy of the previous key.
    fn rekey(&self, _now: Duration, old_key: &Self::Key, _req: &SchedulingRequest) -> Self::Key {
        Clone::clone(old_key)
    }
}

/// First-come-first-served ordering with priority bumps.
///
/// The key is `max(priority_jump, 0) - arrival_offset` (arrival offset taken in
/// seconds), so an earlier arrival or a larger `priority_jump` yields a higher
/// key and is scheduled first. Negative `priority_jump` values are clamped to
/// zero and therefore never demote a request.
///
/// Optimizes for tail TTFT — no request waits longer than necessary,
/// since ordering is purely by (adjusted) arrival time.
pub struct FcfsPolicy;

impl SchedulingPolicy for FcfsPolicy {
    type Key = OrderedFloat<f64>;

    /// Returns the clamped priority bump minus the arrival offset in seconds.
    fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
        let boost = request.priority_jump.max(0.0);
        let arrival_secs = arrival_offset.as_secs_f64();
        OrderedFloat(boost - arrival_secs)
    }
}

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/// Last-come-first-served ordering with priority bumps.
///
/// The key is `max(priority_jump, 0) + arrival_offset` (arrival offset taken in
/// seconds), so a later arrival or a larger `priority_jump` yields a higher key
/// and is scheduled first. Negative `priority_jump` values are clamped to zero.
///
/// This intentionally favors newer arrivals under saturation and is mainly useful
/// for policy comparison experiments.
pub struct LcfsPolicy;

impl SchedulingPolicy for LcfsPolicy {
    type Key = OrderedFloat<f64>;

    /// Returns the clamped priority bump plus the arrival offset in seconds.
    fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
        let boost = request.priority_jump.max(0.0);
        let arrival_secs = arrival_offset.as_secs_f64();
        OrderedFloat(boost + arrival_secs)
    }
}

60
61
/// Weighted Shortest Processing Time (Smith's rule):
/// key = (1 + priority_jump) / new_tokens, where new_tokens estimates the
62
63
64
/// actual prefill cost by subtracting the effective KV cache overlap from ISL.
/// Unpinned requests use the best available overlap. Pinned requests use only
/// the overlap for their exact target worker so queue ordering matches routing.
65
66
67
68
///
/// Optimizes for average TTFT — minimizes total weighted completion time
/// (Smith 1956). Short or high-priority requests are scheduled before
/// long low-priority ones, reducing mean latency across the batch.
69
pub struct WsptPolicy;
70
71
72
73
74
75

impl SchedulingPolicy for WsptPolicy {
    type Key = OrderedFloat<f64>;

    fn enqueue_key(&self, _arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
        let weight = 1.0 + request.priority_jump.max(0.0);
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
        let allowed_ids = request.allowed_worker_ids.as_ref();
        let cached_tokens = request.pinned_worker.map_or_else(
            || {
                request
                    .effective_cached_tokens
                    .iter()
                    .filter(|(worker, _)| {
                        allowed_ids.is_none_or(|ids| ids.contains(&worker.worker_id))
                    })
                    .map(|(_, tokens)| *tokens)
                    .max()
                    .unwrap_or(0)
            },
            |worker| {
                request
                    .effective_cached_tokens
                    .get(&worker)
                    .copied()
                    .unwrap_or(0)
            },
        );
97
98
99
100
101
102
103
104
105
106
        let new_tokens = request.isl_tokens.saturating_sub(cached_tokens).max(1);
        OrderedFloat(weight / new_tokens as f64)
    }
}

/// Runtime-dispatched scheduling policy selected via configuration.
/// Delegates to the concrete policy variant; the branch is fully predictable
/// since the variant is fixed at queue construction time.
pub enum RouterSchedulingPolicy {
    Fcfs(FcfsPolicy),
107
    Lcfs(LcfsPolicy),
108
109
110
111
    Wspt(WsptPolicy),
}

impl RouterSchedulingPolicy {
112
    pub fn new(kind: RouterQueuePolicy) -> Self {
113
114
        match kind {
            RouterQueuePolicy::Fcfs => Self::Fcfs(FcfsPolicy),
115
            RouterQueuePolicy::Lcfs => Self::Lcfs(LcfsPolicy),
116
            RouterQueuePolicy::Wspt => Self::Wspt(WsptPolicy),
117
118
119
120
121
122
123
124
125
126
        }
    }
}

impl SchedulingPolicy for RouterSchedulingPolicy {
    type Key = OrderedFloat<f64>;

    /// Forwards to the wrapped policy's `enqueue_key`.
    ///
    /// `rekey` and `DYNAMIC` are not overridden here, so the trait defaults
    /// apply regardless of the selected variant.
    fn enqueue_key(&self, arrival_offset: Duration, request: &SchedulingRequest) -> Self::Key {
        match self {
            Self::Fcfs(p) => p.enqueue_key(arrival_offset, request),
            Self::Lcfs(p) => p.enqueue_key(arrival_offset, request),
            Self::Wspt(p) => p.enqueue_key(arrival_offset, request),
        }
    }
}

#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap;

    use super::*;
    use crate::protocols::{OverlapScores, WorkerWithDpRank};

    /// Builds a minimal `SchedulingRequest` for policy tests.
    ///
    /// Each overlap score in `overlaps` is copied into `effective_overlap_blocks`
    /// as-is and into `effective_cached_tokens` multiplied by 16 (the tokens per
    /// block assumed by these tests). All other fields are left empty/default.
    fn request_with(
        isl_tokens: usize,
        priority_jump: f64,
        overlaps: OverlapScores,
    ) -> SchedulingRequest {
        let effective_overlap_blocks = overlaps
            .scores
            .iter()
            .map(|(worker, overlap)| (*worker, *overlap as f64))
            .collect();
        let effective_cached_tokens = overlaps
            .scores
            .iter()
            .map(|(worker, overlap)| (*worker, *overlap as usize * 16))
            .collect();
        SchedulingRequest {
            maybe_request_id: None,
            token_seq: None,
            isl_tokens,
            tier_overlap_blocks: Default::default(),
            effective_overlap_blocks,
            effective_cached_tokens,
            tree_sizes: std::collections::HashMap::new(),
            decode_blocks: FxHashMap::default(),
            prefill_tokens: FxHashMap::default(),
            track_prefill_tokens: true,
            router_config_override: None,
            update_states: false,
            lora_name: None,
            priority_jump,
            expected_output_tokens: None,
            pinned_worker: None,
            allowed_worker_ids: None,
            shared_cache_hits: None,
            resp_tx: None,
        }
    }

    /// Builds `OverlapScores` from `(worker_id, overlap_blocks)` pairs, all on dp rank 0.
    fn overlaps_from(scores: &[(u64, u32)]) -> OverlapScores {
        let mut map = FxHashMap::default();
        for &(worker_id, score) in scores {
            map.insert(WorkerWithDpRank::new(worker_id, 0), score);
        }
        OverlapScores {
            scores: map,
            frequencies: vec![],
            tree_sizes: FxHashMap::default(),
        }
    }

    // ---- FCFS policy tests ----

    #[test]
    fn fcfs_earlier_arrival_scheduled_first() {
        let policy = FcfsPolicy;
        let req = request_with(512, 0.0, OverlapScores::default());
        let early = policy.enqueue_key(Duration::from_secs(1), &req);
        let late = policy.enqueue_key(Duration::from_secs(10), &req);
        assert!(early > late, "earlier arrival should have higher key");
    }

    #[test]
    fn fcfs_priority_jump_promotes() {
        let policy = FcfsPolicy;
        // Both arrive at the same wall-clock offset (10s), but one has priority.
        let normal = request_with(512, 0.0, OverlapScores::default());
        let boosted = request_with(512, 100.0, OverlapScores::default());
        let t = Duration::from_secs(10);
        let key_normal = policy.enqueue_key(t, &normal);
        let key_boosted = policy.enqueue_key(t, &boosted);
        assert!(
            key_boosted > key_normal,
            "priority_jump should produce a higher key"
        );
    }

    #[test]
    fn fcfs_priority_jump_beats_earlier_arrival() {
        let policy = FcfsPolicy;
        // Request A arrived at t=0 with no priority.
        // Request B arrived at t=5 with priority_jump=50s.
        // B should be scheduled first despite arriving later.
        let a = request_with(512, 0.0, OverlapScores::default());
        let b = request_with(512, 50.0, OverlapScores::default());
        let key_a = policy.enqueue_key(Duration::from_secs(0), &a);
        let key_b = policy.enqueue_key(Duration::from_secs(5), &b);
        assert!(key_b > key_a);
    }

    #[test]
    fn fcfs_negative_priority_jump_is_clamped() {
        let policy = FcfsPolicy;
        // A negative priority_jump is clamped to zero, so it must not demote the request.
        let neutral = request_with(512, 0.0, OverlapScores::default());
        let negative = request_with(512, -5.0, OverlapScores::default());
        let t = Duration::from_secs(10);
        assert_eq!(
            policy.enqueue_key(t, &negative),
            policy.enqueue_key(t, &neutral)
        );
    }

    // ---- LCFS policy tests ----

    #[test]
    fn lcfs_later_arrival_scheduled_first() {
        let policy = LcfsPolicy;
        let req = request_with(512, 0.0, OverlapScores::default());
        let early = policy.enqueue_key(Duration::from_secs(1), &req);
        let late = policy.enqueue_key(Duration::from_secs(10), &req);
        assert!(late > early, "later arrival should have higher key");
    }

    #[test]
    fn lcfs_priority_jump_promotes() {
        let policy = LcfsPolicy;
        let normal = request_with(512, 0.0, OverlapScores::default());
        let boosted = request_with(512, 100.0, OverlapScores::default());
        let t = Duration::from_secs(10);
        let key_normal = policy.enqueue_key(t, &normal);
        let key_boosted = policy.enqueue_key(t, &boosted);
        assert!(
            key_boosted > key_normal,
            "priority_jump should produce a higher key"
        );
    }

    // ---- Runtime-dispatched policy tests ----

    #[test]
    fn router_scheduling_policy_matches_fcfs_and_lcfs_ordering() {
        let req = request_with(512, 0.0, OverlapScores::default());
        let early = Duration::from_secs(1);
        let late = Duration::from_secs(10);

        let fcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Fcfs);
        assert!(fcfs.enqueue_key(early, &req) > fcfs.enqueue_key(late, &req));

        let lcfs = RouterSchedulingPolicy::new(RouterQueuePolicy::Lcfs);
        assert!(lcfs.enqueue_key(late, &req) > lcfs.enqueue_key(early, &req));
    }

    #[test]
    fn router_scheduling_policy_matches_wspt_ordering() {
        let short = request_with(100, 0.0, OverlapScores::default());
        let long = request_with(1000, 0.0, OverlapScores::default());
        let t = Duration::ZERO;

        let wspt = RouterSchedulingPolicy::new(RouterQueuePolicy::Wspt);
        assert!(wspt.enqueue_key(t, &short) > wspt.enqueue_key(t, &long));
    }

    // ---- WSPT policy tests ----

    #[test]
    fn wspt_shorter_request_scheduled_first() {
        let policy = WsptPolicy;
        let short = request_with(100, 0.0, OverlapScores::default());
        let long = request_with(1000, 0.0, OverlapScores::default());
        let t = Duration::ZERO;
        assert!(
            policy.enqueue_key(t, &short) > policy.enqueue_key(t, &long),
            "shorter request should have higher key"
        );
    }

    #[test]
    fn wspt_overlap_reduces_effective_cost() {
        let policy = WsptPolicy;
        // Both 1024 ISL tokens, but one has 60 blocks cached (960 tokens).
        let no_cache = request_with(1024, 0.0, OverlapScores::default());
        let cached = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
        let t = Duration::ZERO;
        let key_no_cache = policy.enqueue_key(t, &no_cache);
        let key_cached = policy.enqueue_key(t, &cached);
        assert!(
            key_cached > key_no_cache,
            "request with overlap should have higher key (fewer new tokens)"
        );
    }

    #[test]
    fn wspt_priority_promotes() {
        let policy = WsptPolicy;
        let normal = request_with(512, 0.0, OverlapScores::default());
        let boosted = request_with(512, 5.0, OverlapScores::default());
        let t = Duration::ZERO;
        assert!(
            policy.enqueue_key(t, &boosted) > policy.enqueue_key(t, &normal),
            "priority_jump should increase key"
        );
    }

    #[test]
    fn wspt_negative_priority_jump_is_clamped() {
        let policy = WsptPolicy;
        // A negative priority_jump is clamped to zero, leaving the weight at 1.
        let req = request_with(512, -3.0, OverlapScores::default());
        let key = policy.enqueue_key(Duration::ZERO, &req);
        let expected = OrderedFloat(1.0 / 512.0);
        assert_eq!(key, expected);
    }

    #[test]
    fn wspt_uses_max_overlap() {
        let policy = WsptPolicy;
        // 4 workers with overlaps [10, 20, 50, 60]. max = 60.
        // new_tokens = 1024 - 60*16 = 64
        let req = request_with(
            1024,
            0.0,
            overlaps_from(&[(0, 10), (1, 20), (2, 50), (3, 60)]),
        );
        let key = policy.enqueue_key(Duration::ZERO, &req);
        let expected = OrderedFloat(1.0 / 64.0);
        assert_eq!(key, expected);
    }

    #[test]
    fn wspt_uses_pinned_worker_overlap_when_present() {
        let policy = WsptPolicy;
        // Pinned to worker 1 (1 block = 16 tokens): new_tokens = 1024 - 16 = 1008,
        // even though worker 0 has a much larger overlap.
        let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60), (1, 1)]));
        req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));

        let key = policy.enqueue_key(Duration::ZERO, &req);
        let expected = OrderedFloat(1.0 / 1008.0);
        assert_eq!(key, expected);
    }

    #[test]
    fn wspt_missing_pinned_overlap_uses_zero() {
        let policy = WsptPolicy;
        // Pinned worker has no overlap entry, so nothing is subtracted from ISL.
        let mut req = request_with(1024, 0.0, overlaps_from(&[(0, 60)]));
        req.pinned_worker = Some(WorkerWithDpRank::new(1, 0));

        let key = policy.enqueue_key(Duration::ZERO, &req);
        let expected = OrderedFloat(1.0 / 1024.0);
        assert_eq!(key, expected);
    }

    #[test]
    fn wspt_no_overlap_falls_back_to_isl() {
        let policy = WsptPolicy;
        let req = request_with(512, 0.0, OverlapScores::default());
        let key = policy.enqueue_key(Duration::ZERO, &req);
        let expected = OrderedFloat(1.0 / 512.0);
        assert_eq!(key, expected);
    }

    #[test]
    fn wspt_full_overlap_clamps_to_one() {
        let policy = WsptPolicy;
        // 512 tokens, 64 blocks cached = 1024 cached tokens > ISL → saturating_sub → 0 → max(1)
        let req = request_with(512, 0.0, overlaps_from(&[(0, 64)]));
        let key = policy.enqueue_key(Duration::ZERO, &req);
        let expected = OrderedFloat(1.0 / 1.0);
        assert_eq!(key, expected);
    }
}