scheduler.rs 6.54 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
pub use dynamo_kv_router::scheduling::policy::RouterSchedulingPolicy;
5
pub use dynamo_kv_router::scheduling::{
6
    KvSchedulerError, LocalScheduler, PotentialLoad, SchedulingRequest, SchedulingResponse,
7
8
};
pub use dynamo_kv_router::selector::DefaultWorkerSelector;
9
use dynamo_kv_router::selector::WorkerSelector as WorkerSelectorTrait;
10

11
use super::metrics::ROUTER_QUEUE_METRICS;
12
use super::sequence::{
13
    RuntimeSequencePublisher, SequenceError, SequenceRequest, create_multi_worker_sequences,
14
};
15
use crate::discovery::RuntimeConfigWatch;
16
use crate::local_model::runtime_config::ModelRuntimeConfig;
17
use anyhow::Result;
18
use dynamo_kv_router::{
19
    PrefillLoadEstimator,
20
21
22
    config::{KvRouterConfig, RouterConfigOverride},
    protocols::{OverlapScores, WorkerId},
};
23
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
24
use dynamo_runtime::traits::DistributedRuntimeProvider;
25
use dynamo_tokens::SequenceHash;
26
use std::collections::{HashMap, HashSet};
27
28
use std::sync::Arc;
use std::time::Duration;
29

30
31
32
33
34
35
36
/// Scheduler handle that wraps a shared [`LocalScheduler`].
///
/// The wrapped scheduler is specialised with [`RuntimeSequencePublisher`],
/// [`ModelRuntimeConfig`] and [`RouterSchedulingPolicy`]; only the worker selector `Sel`
/// is left generic, defaulting to [`DefaultWorkerSelector`].
pub struct KvScheduler<Sel: WorkerSelectorTrait<ModelRuntimeConfig> = DefaultWorkerSelector> {
    /// Reference-counted scheduler that every method of this type delegates to.
    inner: Arc<
        LocalScheduler<RuntimeSequencePublisher, ModelRuntimeConfig, RouterSchedulingPolicy, Sel>,
    >,
}

39
40
41
42
impl<Sel> KvScheduler<Sel>
where
    Sel: WorkerSelectorTrait<ModelRuntimeConfig> + Send + Sync + 'static,
{
43
    pub async fn start(
44
        component: Component,
45
        block_size: u32,
46
        workers_with_configs: RuntimeConfigWatch,
47
        selector: Sel,
48
        kv_router_config: &KvRouterConfig,
49
        prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
50
        worker_type: &'static str,
51
    ) -> Result<Self, KvSchedulerError> {
52
53
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
54

55
        let router_id = component.drt().discovery().instance_id();
56
57
58
59
60
61
62
63
64
65
        let slots = create_multi_worker_sequences(
            component.clone(),
            block_size as usize,
            initial_workers,
            kv_router_config.router_replica_sync,
            router_id,
            worker_type,
        )
        .await
        .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
66

67
68
        let watch_worker_configs = !kv_router_config.skip_initial_worker_wait;
        if !watch_worker_configs {
69
70
            tracing::info!("skipping discovery-based worker monitoring");
        }
Yan Ru Pei's avatar
Yan Ru Pei committed
71

72
73
74
75
76
77
78
        let policy =
            RouterSchedulingPolicy::new(kv_router_config.router_queue_policy, block_size as usize);
        tracing::info!(
            "Router queue policy: {}",
            kv_router_config.router_queue_policy
        );

79
80
        let inner = Arc::new(LocalScheduler::new(
            slots,
81
            workers_with_configs.clone(),
82
            kv_router_config.router_queue_threshold,
83
84
            block_size,
            selector,
85
            policy,
86
87
            prefill_load_estimator,
            kv_router_config.router_queue_recheck_interval(),
88
            kv_router_config.router_track_prefill_tokens,
89
90
91
            component.drt().child_token(),
            worker_type,
            watch_worker_configs,
92
93
        ));

94
95
        let metrics_scheduler = Arc::clone(&inner);
        let metrics_cancel_token = component.drt().child_token();
Yan Ru Pei's avatar
Yan Ru Pei committed
96
        tokio::spawn(async move {
97
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
98
            ROUTER_QUEUE_METRICS.set_pending(worker_type, metrics_scheduler.pending_count());
Yan Ru Pei's avatar
Yan Ru Pei committed
99
100

            loop {
101
                tokio::select! {
102
                    _ = metrics_cancel_token.cancelled() => break,
103
                    _ = recheck_interval.tick() => {
104
105
                        ROUTER_QUEUE_METRICS
                            .set_pending(worker_type, metrics_scheduler.pending_count());
106
107
                    }
                }
108
109
110
            }
        });

111
        Ok(Self { inner })
112
113
    }

114
    #[expect(clippy::too_many_arguments)]
115
116
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
117
        maybe_request_id: Option<String>,
118
        isl_tokens: usize,
119
        token_seq: Option<Vec<SequenceHash>>,
120
        overlaps: OverlapScores,
121
        router_config_override: Option<&RouterConfigOverride>,
122
        update_states: bool,
123
        lora_name: Option<String>,
124
        priority_jump: f64,
125
        expected_output_tokens: Option<u32>,
126
127
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> Result<SchedulingResponse, KvSchedulerError> {
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
        let response = self
            .inner
            .schedule(
                maybe_request_id,
                isl_tokens,
                token_seq,
                overlaps,
                router_config_override,
                update_states,
                lora_name,
                priority_jump,
                expected_output_tokens,
                allowed_worker_ids,
            )
            .await;
        ROUTER_QUEUE_METRICS.set_pending(self.worker_type(), self.pending_count());
        response
145
146
    }

147
    pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
148
        self.inner.register_workers(worker_ids);
149
150
    }

151
    /// Submits `req` to the wrapped scheduler and returns its result unchanged.
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        let submission = self.inner.add_request(req);
        submission.await
    }

155
    /// Marks prefill as completed for `request_id` on the wrapped scheduler.
    ///
    /// On success the pending-count metric in `ROUTER_QUEUE_METRICS` is refreshed.
    ///
    /// # Errors
    ///
    /// Propagates the error from the wrapped scheduler; in that case the metric is left
    /// untouched.
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
        let completion = self.inner.mark_prefill_completed(request_id);
        completion.await?;

        let worker_type = self.worker_type();
        let pending = self.pending_count();
        ROUTER_QUEUE_METRICS.set_pending(worker_type, pending);
        Ok(())
    }

161
    /// Calls `free` for `request_id` on the wrapped scheduler.
    ///
    /// On success the pending-count metric in `ROUTER_QUEUE_METRICS` is refreshed.
    ///
    /// # Errors
    ///
    /// Propagates the error from the wrapped scheduler; in that case the metric is left
    /// untouched.
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
        let release = self.inner.free(request_id);
        release.await?;

        let worker_type = self.worker_type();
        let pending = self.pending_count();
        ROUTER_QUEUE_METRICS.set_pending(worker_type, pending);
        Ok(())
    }
166

167
    pub fn pending_count(&self) -> usize {
168
        self.inner.pending_count()
169
170
    }

171
    pub fn worker_type(&self) -> &'static str {
172
        self.inner.worker_type()
173
174
    }

175
    pub fn add_output_block(
176
177
178
179
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
180
        self.inner.add_output_block(request_id, decay_fraction)
181
182
    }

183
    pub fn get_potential_loads(
184
        &self,
185
        token_seq: Option<Vec<SequenceHash>>,
186
187
        isl_tokens: usize,
        overlaps: OverlapScores,
188
        track_prefill_tokens: bool,
189
    ) -> Vec<PotentialLoad> {
190
        self.inner
191
            .get_potential_loads(token_seq, isl_tokens, overlaps, track_prefill_tokens)
192
    }
193
194

    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
195
        self.inner.get_active_lora_counts()
196
    }
197
}