lib.rs 9.52 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod logging;
3
use std::collections::HashMap;
4
5
6
pub mod openai_api_types;
pub mod pd_router;
pub mod pd_types;
7
pub mod prometheus;
8
pub mod request_adapter;
9
pub mod router;
10
pub mod server;
11
pub mod service_discovery;
12
pub mod tree;
13
use crate::prometheus::PrometheusConfig;
14

15
#[pyclass(eq)]
16
#[derive(Clone, PartialEq, Debug)]
17
18
19
pub enum PolicyType {
    Random,
    RoundRobin,
20
    CacheAware,
21
    PowerOfTwo, // Moved from PD-specific, now shared
22
23
}

24
#[pyclass]
25
#[derive(Debug, Clone, PartialEq)]
26
27
28
29
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
30
    policy: PolicyType,
31
    worker_startup_timeout_secs: u64,
32
    worker_startup_check_interval: u64,
33
    cache_threshold: f32,
34
35
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
36
37
    eviction_interval_secs: u64,
    max_tree_size: usize,
38
    max_payload_size: usize,
39
    verbose: bool,
40
    log_dir: Option<String>,
41
42
43
44
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
45
46
47
48
    // PD service discovery fields
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
49
50
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
51
52
    request_timeout_secs: u64,
    // PD mode flag
53
54
    pd_disaggregation: bool,
    // PD-specific fields (only used when pd_disaggregation is true)
55
56
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
57
58
59
60
61
}

#[pymethods]
impl Router {
    #[new]
62
63
64
65
66
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
67
        worker_startup_timeout_secs = 300,
68
        worker_startup_check_interval = 10,
69
        cache_threshold = 0.50,
70
71
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
72
        eviction_interval_secs = 60,
73
        max_tree_size = 2usize.pow(24),
74
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
75
76
        verbose = false,
        log_dir = None,
77
78
79
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
80
        service_discovery_namespace = None,
81
82
83
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
84
        prometheus_port = None,
85
86
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
87
        pd_disaggregation = false,  // New flag for PD mode
88
89
        prefill_urls = None,
        decode_urls = None
90
91
92
93
94
95
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
96
        worker_startup_timeout_secs: u64,
97
        worker_startup_check_interval: u64,
98
        cache_threshold: f32,
99
100
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
101
102
        eviction_interval_secs: u64,
        max_tree_size: usize,
103
        max_payload_size: usize,
104
        verbose: bool,
105
        log_dir: Option<String>,
106
107
108
109
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
110
111
112
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
113
114
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
115
        request_timeout_secs: u64,
116
        pd_disaggregation: bool,
117
118
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
119
120
    ) -> PyResult<Self> {
        Ok(Router {
121
122
123
            host,
            port,
            worker_urls,
124
            policy,
125
            worker_startup_timeout_secs,
126
            worker_startup_check_interval,
127
            cache_threshold,
128
129
            balance_abs_threshold,
            balance_rel_threshold,
130
131
            eviction_interval_secs,
            max_tree_size,
132
            max_payload_size,
133
            verbose,
134
            log_dir,
135
136
137
138
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
139
140
141
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
142
143
            prometheus_port,
            prometheus_host,
144
            request_timeout_secs,
145
            pd_disaggregation,
146
147
            prefill_urls,
            decode_urls,
148
        })
149
150
151
    }

    fn start(&self) -> PyResult<()> {
152
        let policy_config = if self.pd_disaggregation {
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
            // PD mode - map PolicyType to PDSelectionPolicy
            let pd_selection_policy = match &self.policy {
                PolicyType::Random => pd_types::PDSelectionPolicy::Random,
                PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
                PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                },
                PolicyType::RoundRobin => {
                    return Err(pyo3::exceptions::PyValueError::new_err(
                        "RoundRobin policy is not supported in PD disaggregated mode",
                    ));
                }
            };

            let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
                pyo3::exceptions::PyValueError::new_err(
                    "PD disaggregated mode requires prefill_urls",
                )
            })?;
            let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
                pyo3::exceptions::PyValueError::new_err(
                    "PD disaggregated mode requires decode_urls",
                )
            })?;

            router::PolicyConfig::PrefillDecodeConfig {
                selection_policy: pd_selection_policy,
                prefill_urls: prefill_urls.clone(),
                decode_urls: decode_urls.clone(),
184
                timeout_secs: self.worker_startup_timeout_secs,
185
                interval_secs: self.worker_startup_check_interval,
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
            }
        } else {
            // Regular mode
            match &self.policy {
                PolicyType::Random => router::PolicyConfig::RandomConfig {
                    timeout_secs: self.worker_startup_timeout_secs,
                    interval_secs: self.worker_startup_check_interval,
                },
                PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
                    timeout_secs: self.worker_startup_timeout_secs,
                    interval_secs: self.worker_startup_check_interval,
                },
                PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
                    timeout_secs: self.worker_startup_timeout_secs,
                    interval_secs: self.worker_startup_check_interval,
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => {
                    return Err(pyo3::exceptions::PyValueError::new_err(
                        "PowerOfTwo policy is only supported in PD disaggregated mode",
                    ));
                }
            }
213
        };
214

215
216
217
218
219
220
221
222
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
223
224
225
226
227
                // PD mode configuration
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
228
229
230
231
232
            })
        } else {
            None
        };

233
234
235
236
237
238
239
240
241
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

242
        actix_web::rt::System::new().block_on(async move {
243
244
245
246
247
248
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
                worker_urls: self.worker_urls.clone(),
                policy_config,
                verbose: self.verbose,
249
                max_payload_size: self.max_payload_size,
250
                log_dir: self.log_dir.clone(),
251
                service_discovery_config,
252
                prometheus_config,
253
                request_timeout_secs: self.request_timeout_secs,
254
255
            })
            .await
256
257
258
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Ok(())
        })
259
260
261
262
    }
}

#[pymodule]
263
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
264
    m.add_class::<PolicyType>()?;
265
266
    m.add_class::<Router>()?;
    Ok(())
267
}