//! Python bindings (PyO3) for the router: registers `PolicyType` and `Router`
//! in the `sglang_router_rs` extension module.
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
pub mod metrics;
7
pub mod middleware;
8
pub mod openai_api_types;
9
10
pub mod policies;
pub mod routers;
11
pub mod server;
12
pub mod service_discovery;
13
pub mod tree;
14
use crate::metrics::PrometheusConfig;
15

16
#[pyclass(eq)]
17
#[derive(Clone, PartialEq, Debug)]
18
19
20
pub enum PolicyType {
    Random,
    RoundRobin,
21
    CacheAware,
22
    PowerOfTwo, // Moved from PD-specific, now shared
23
24
}

25
#[pyclass]
26
#[derive(Debug, Clone, PartialEq)]
27
28
29
30
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
31
    policy: PolicyType,
32
    worker_startup_timeout_secs: u64,
33
    worker_startup_check_interval: u64,
34
    cache_threshold: f32,
35
36
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
37
38
    eviction_interval_secs: u64,
    max_tree_size: usize,
39
    max_payload_size: usize,
40
41
    dp_aware: bool,
    api_key: Option<String>,
42
    log_dir: Option<String>,
43
    log_level: Option<String>,
44
45
46
47
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
48
49
50
51
    // PD service discovery fields
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
52
53
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
54
    request_timeout_secs: u64,
55
    request_id_headers: Option<Vec<String>>,
56
    // PD mode flag
57
58
    pd_disaggregation: bool,
    // PD-specific fields (only used when pd_disaggregation is true)
59
60
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
61
62
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
63
64
65
    // Additional server config fields
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
66
67
}

68
69
70
71
72
73
74
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

93
94
95
96
97
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
98
99
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
100
101
102
103
104
105
106
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

107
108
        // Convert main policy
        let policy = convert_policy(&self.policy);
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
144
145
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
146
147
148
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
149
            log_level: self.log_level.clone(),
150
            request_id_headers: self.request_id_headers.clone(),
151
152
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
153
154
155
156
        })
    }
}

157
158
159
#[pymethods]
impl Router {
    #[new]
160
161
162
163
164
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
165
        worker_startup_timeout_secs = 300,
166
        worker_startup_check_interval = 10,
167
        cache_threshold = 0.50,
168
169
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
170
        eviction_interval_secs = 60,
171
        max_tree_size = 2usize.pow(24),
172
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
173
174
        dp_aware = false,
        api_key = None,
175
        log_dir = None,
176
        log_level = None,
177
178
179
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
180
        service_discovery_namespace = None,
181
182
183
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
184
        prometheus_port = None,
185
186
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
187
        request_id_headers = None,  // Custom request ID headers
188
        pd_disaggregation = false,  // New flag for PD mode
189
        prefill_urls = None,
190
191
        decode_urls = None,
        prefill_policy = None,
192
193
194
        decode_policy = None,
        max_concurrent_requests = 64,
        cors_allowed_origins = vec![]
195
196
197
198
199
200
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
201
        worker_startup_timeout_secs: u64,
202
        worker_startup_check_interval: u64,
203
        cache_threshold: f32,
204
205
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
206
207
        eviction_interval_secs: u64,
        max_tree_size: usize,
208
        max_payload_size: usize,
209
210
        dp_aware: bool,
        api_key: Option<String>,
211
        log_dir: Option<String>,
212
        log_level: Option<String>,
213
214
215
216
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
217
218
219
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
220
221
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
222
        request_timeout_secs: u64,
223
        request_id_headers: Option<Vec<String>>,
224
        pd_disaggregation: bool,
225
226
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
227
228
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
229
230
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
231
232
    ) -> PyResult<Self> {
        Ok(Router {
233
234
235
            host,
            port,
            worker_urls,
236
            policy,
237
            worker_startup_timeout_secs,
238
            worker_startup_check_interval,
239
            cache_threshold,
240
241
            balance_abs_threshold,
            balance_rel_threshold,
242
243
            eviction_interval_secs,
            max_tree_size,
244
            max_payload_size,
245
246
            dp_aware,
            api_key,
247
            log_dir,
248
            log_level,
249
250
251
252
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
253
254
255
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
256
257
            prometheus_port,
            prometheus_host,
258
            request_timeout_secs,
259
            request_id_headers,
260
            pd_disaggregation,
261
262
            prefill_urls,
            decode_urls,
263
264
            prefill_policy,
            decode_policy,
265
266
            max_concurrent_requests,
            cors_allowed_origins,
267
        })
268
269
270
    }

    fn start(&self) -> PyResult<()> {
271
272
273
274
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
275

276
277
278
279
280
281
282
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
283

284
285
286
287
288
289
290
291
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
292
293
294
295
296
                // PD mode configuration
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
297
298
299
300
301
            })
        } else {
            None
        };

302
303
304
305
306
307
308
309
310
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

311
312
313
314
315
316
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
317
318
319
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
320
                router_config,
321
                max_payload_size: self.max_payload_size,
322
                log_dir: self.log_dir.clone(),
323
                log_level: self.log_level.clone(),
324
                service_discovery_config,
325
                prometheus_config,
326
                request_timeout_secs: self.request_timeout_secs,
327
                request_id_headers: self.request_id_headers.clone(),
328
329
            })
            .await
330
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
331
        })
332
333
334
335
    }
}

#[pymodule]
336
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
337
    m.add_class::<PolicyType>()?;
338
339
    m.add_class::<Router>()?;
    Ok(())
340
}