scheduler.rs 8.33 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
use dynamo_kv_router::protocols::SharedCacheHits;
5
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
6
pub use dynamo_kv_router::scheduling::{
7
    KvSchedulerError, LocalScheduler, PotentialLoad, SchedulingRequest, SchedulingResponse,
8
    TierOverlapBlocks,
9
10
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;
11
use dynamo_kv_router::selector::WorkerSelector as WorkerSelectorTrait;
12

13
use super::metrics::ROUTER_QUEUE_METRICS;
14
use super::sequence::{
15
    RuntimeSequencePublisher, SequenceError, SequenceRequest, create_multi_worker_sequences,
16
};
17
use crate::discovery::RuntimeConfigWatch;
18
use crate::local_model::runtime_config::ModelRuntimeConfig;
19
use anyhow::Result;
20
use dynamo_kv_router::{
21
    PrefillLoadEstimator,
22
    config::{KvRouterConfig, RouterConfigOverride},
23
    protocols::{WorkerId, WorkerWithDpRank},
24
};
25
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
26
use dynamo_runtime::traits::DistributedRuntimeProvider;
27
use dynamo_tokens::SequenceHash;
28
use std::collections::{HashMap, HashSet};
29
30
use std::sync::Arc;
use std::time::Duration;
31

32
33
34
35
36
37
38
/// Router-side scheduler handle that wraps a shared [`LocalScheduler`].
///
/// `Sel` is the worker-selection strategy handed to the inner scheduler; when the caller does
/// not name one it defaults to [`DefaultWorkerSelector`].
pub struct KvScheduler<Sel: WorkerSelectorTrait<ModelRuntimeConfig> = DefaultWorkerSelector> {
    /// Scheduler that every method on this type delegates to. It lives behind an `Arc`
    /// because `start` gives a second handle to its background metrics task.
    inner: Arc<
        LocalScheduler<RuntimeSequencePublisher, ModelRuntimeConfig, RouterSchedulingPolicy, Sel>,
    >,
}

41
42
43
44
impl<Sel> KvScheduler<Sel>
where
    Sel: WorkerSelectorTrait<ModelRuntimeConfig> + Send + Sync + 'static,
{
45
    pub async fn start(
46
        component: Component,
47
        block_size: u32,
48
        workers_with_configs: RuntimeConfigWatch,
49
        selector: Sel,
50
        kv_router_config: &KvRouterConfig,
51
        prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
52
        worker_type: &'static str,
53
    ) -> Result<Self, KvSchedulerError> {
54
55
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
56

57
        let router_id = component.drt().discovery().instance_id();
58
59
60
61
62
63
64
65
66
67
        let slots = create_multi_worker_sequences(
            component.clone(),
            block_size as usize,
            initial_workers,
            kv_router_config.router_replica_sync,
            router_id,
            worker_type,
        )
        .await
        .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
68

69
70
        let watch_worker_configs = !kv_router_config.skip_initial_worker_wait;
        if !watch_worker_configs {
71
72
            tracing::info!("skipping discovery-based worker monitoring");
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
73

74
        let policy = RouterSchedulingPolicy::new(kv_router_config.router_queue_policy);
75
76
77
78
79
        tracing::info!(
            "Router queue policy: {}",
            kv_router_config.router_queue_policy
        );

80
81
        let inner = Arc::new(LocalScheduler::new(
            slots,
82
            workers_with_configs.clone(),
83
            kv_router_config.router_queue_threshold,
84
85
            block_size,
            selector,
86
            policy,
87
88
            prefill_load_estimator,
            kv_router_config.router_queue_recheck_interval(),
89
            kv_router_config.router_track_prefill_tokens,
90
91
92
            component.drt().child_token(),
            worker_type,
            watch_worker_configs,
93
94
        ));

95
96
        let metrics_scheduler = Arc::clone(&inner);
        let metrics_cancel_token = component.drt().child_token();
97
        let mut queue_updates = inner.subscribe_queue_updates();
Yan Ru Pei's avatar
Yan Ru Pei committed
98
        tokio::spawn(async move {
99
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
100
            ROUTER_QUEUE_METRICS.set_pending(worker_type, metrics_scheduler.pending_count());
101
102
            ROUTER_QUEUE_METRICS
                .set_pending_isl_tokens(worker_type, metrics_scheduler.pending_isl_tokens());
Yan Ru Pei's avatar
Yan Ru Pei committed
103
104

            loop {
105
                tokio::select! {
106
                    _ = metrics_cancel_token.cancelled() => break,
107
108
109
110
111
112
113
                    result = queue_updates.changed() => {
                        if result.is_err() {
                            break;
                        }
                        ROUTER_QUEUE_METRICS
                            .set_pending(worker_type, metrics_scheduler.pending_count());
                    }
114
                    _ = recheck_interval.tick() => {
115
116
117
118
119
                        ROUTER_QUEUE_METRICS.set_pending(worker_type, metrics_scheduler.pending_count());
                        ROUTER_QUEUE_METRICS.set_pending_isl_tokens(
                            worker_type,
                            metrics_scheduler.pending_isl_tokens(),
                        );
120
121
                    }
                }
122
123
124
            }
        });

125
        Ok(Self { inner })
126
127
    }

128
    #[expect(clippy::too_many_arguments)]
129
130
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
131
        maybe_request_id: Option<String>,
132
        isl_tokens: usize,
133
        token_seq: Option<Vec<SequenceHash>>,
134
135
136
137
        tier_overlap_blocks: TierOverlapBlocks,
        effective_overlap_blocks: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, f64>,
        effective_cached_tokens: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, usize>,
        tree_sizes: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, usize>,
138
        router_config_override: Option<&RouterConfigOverride>,
139
        update_states: bool,
140
        lora_name: Option<String>,
141
        priority_jump: f64,
142
        expected_output_tokens: Option<u32>,
143
        pinned_worker: Option<WorkerWithDpRank>,
144
        allowed_worker_ids: Option<HashSet<WorkerId>>,
145
        shared_cache_hits: Option<SharedCacheHits>,
146
    ) -> Result<SchedulingResponse, KvSchedulerError> {
147
148
149
150
151
152
        let response = self
            .inner
            .schedule(
                maybe_request_id,
                isl_tokens,
                token_seq,
153
154
155
156
                tier_overlap_blocks,
                effective_overlap_blocks,
                effective_cached_tokens,
                tree_sizes,
157
158
159
160
161
                router_config_override,
                update_states,
                lora_name,
                priority_jump,
                expected_output_tokens,
162
                pinned_worker,
163
                allowed_worker_ids,
164
                shared_cache_hits,
165
166
167
            )
            .await;
        ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count());
168
        ROUTER_QUEUE_METRICS.set_pending_isl_tokens(self.worker_type(), self.pending_isl_tokens());
169
        response
170
171
    }

172
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
173
        self.inner.register_workers(worker_ids);
174
175
    }

176
    /// Hands `req` to the inner scheduler's `add_request` and returns its result.
    ///
    /// Unlike `schedule`, `free` and `mark_prefill_completed`, this method does not touch the
    /// queue gauges.
    ///
    /// # Errors
    ///
    /// Propagates the [`SequenceError`] reported by the inner scheduler.
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        let scheduler = &self.inner;
        scheduler.add_request(req).await
    }

180
    /// Tells the inner scheduler that prefill finished for `request_id`, then refreshes the
    /// pending-request and pending-ISL-token gauges.
    ///
    /// # Errors
    ///
    /// Propagates the inner scheduler's [`SequenceError`]; in that case the gauges are left
    /// untouched because the function returns before reaching them.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.inner;
        scheduler.mark_prefill_completed(request_id).await?;

        let worker_type = self.worker_type();
        ROUTER_QUEUE_METRICS.set_pending(worker_type, self.pending_count());
        ROUTER_QUEUE_METRICS.set_pending_isl_tokens(worker_type, self.pending_isl_tokens());
        Ok(())
    }

187
    /// Asks the inner scheduler to free `request_id`, then refreshes the pending-request and
    /// pending-ISL-token gauges.
    ///
    /// # Errors
    ///
    /// Propagates the inner scheduler's [`SequenceError`]; in that case the gauges are left
    /// untouched because the function returns before reaching them.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        let scheduler = &self.inner;
        scheduler.free(request_id).await?;

        let worker_type = self.worker_type();
        ROUTER_QUEUE_METRICS.set_pending(worker_type, self.pending_count());
        ROUTER_QUEUE_METRICS.set_pending_isl_tokens(worker_type, self.pending_isl_tokens());
        Ok(())
    }
193

194
    pub fn pending_count(&self) -> usize {
195
        self.inner.pending_count()
196
197
    }

198
199
200
201
    /// Returns the inner scheduler's `pending_isl_tokens()`; this is the value published
    /// through `ROUTER_QUEUE_METRICS.set_pending_isl_tokens`.
    pub fn pending_isl_tokens(&self) -> usize {
        let scheduler = &self.inner;
        scheduler.pending_isl_tokens()
    }

202
    pub fn worker_type(&self) -> &'static str {
203
        self.inner.worker_type()
204
205
    }

206
    pub fn add_output_block(
207
208
209
210
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
211
        self.inner.add_output_block(request_id, decay_fraction)
212
213
    }

214
    pub fn get_potential_loads(
215
        &self,
216
        token_seq: Option<Vec<SequenceHash>>,
217
        isl_tokens: usize,
218
        effective_cached_tokens: HashMap<dynamo_kv_router::protocols::WorkerWithDpRank, usize>,
219
        track_prefill_tokens: bool,
220
    ) -> Vec<PotentialLoad> {
221
222
223
224
225
226
        self.inner.get_potential_loads(
            token_seq,
            isl_tokens,
            effective_cached_tokens,
            track_prefill_tokens,
        )
227
    }
228
229

    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
230
        self.inner.get_active_lora_counts()
231
    }
232
}