pub mod config;
pub mod logging;
pub mod core;
pub mod metrics;
pub mod middleware;
pub mod openai_api_types;
pub mod policies;
pub mod routers;
pub mod server;
pub mod service_discovery;
pub mod tree;

use std::collections::HashMap;

use pyo3::prelude::*;

use crate::metrics::PrometheusConfig;

#[pyclass(eq)]
17
#[derive(Clone, PartialEq, Debug)]
18
19
20
pub enum PolicyType {
    Random,
    RoundRobin,
21
    CacheAware,
22
    PowerOfTwo,
23
24
}

25
#[pyclass]
26
#[derive(Debug, Clone, PartialEq)]
27
28
29
30
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
31
    policy: PolicyType,
32
    worker_startup_timeout_secs: u64,
33
    worker_startup_check_interval: u64,
34
    cache_threshold: f32,
35
36
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
37
38
    eviction_interval_secs: u64,
    max_tree_size: usize,
39
    max_payload_size: usize,
40
41
    dp_aware: bool,
    api_key: Option<String>,
42
    log_dir: Option<String>,
43
    log_level: Option<String>,
44
45
46
47
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
48
49
50
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
51
52
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
53
    request_timeout_secs: u64,
54
    request_id_headers: Option<Vec<String>>,
55
    pd_disaggregation: bool,
56
57
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
58
59
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
60
61
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
62
63
}

64
65
66
67
68
69
70
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

89
90
91
92
93
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
94
95
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
96
97
98
99
100
101
102
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

103
104
        // Convert main policy
        let policy = convert_policy(&self.policy);
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
140
141
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
142
143
144
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
145
            log_level: self.log_level.clone(),
146
            request_id_headers: self.request_id_headers.clone(),
147
148
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
149
            retry: config::RetryConfig::default(),
150
            circuit_breaker: config::CircuitBreakerConfig::default(),
151
152
153
154
        })
    }
}

155
156
157
#[pymethods]
impl Router {
    #[new]
158
159
160
161
162
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
163
        worker_startup_timeout_secs = 300,
164
        worker_startup_check_interval = 10,
165
        cache_threshold = 0.50,
166
167
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
168
        eviction_interval_secs = 60,
169
        max_tree_size = 2usize.pow(24),
170
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
171
172
        dp_aware = false,
        api_key = None,
173
        log_dir = None,
174
        log_level = None,
175
176
177
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
178
        service_discovery_namespace = None,
179
180
181
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
182
        prometheus_port = None,
183
184
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
185
        request_id_headers = None,  // Custom request ID headers
186
        pd_disaggregation = false,  // New flag for PD mode
187
        prefill_urls = None,
188
189
        decode_urls = None,
        prefill_policy = None,
190
191
192
        decode_policy = None,
        max_concurrent_requests = 64,
        cors_allowed_origins = vec![]
193
194
195
196
197
198
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
199
        worker_startup_timeout_secs: u64,
200
        worker_startup_check_interval: u64,
201
        cache_threshold: f32,
202
203
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
204
205
        eviction_interval_secs: u64,
        max_tree_size: usize,
206
        max_payload_size: usize,
207
208
        dp_aware: bool,
        api_key: Option<String>,
209
        log_dir: Option<String>,
210
        log_level: Option<String>,
211
212
213
214
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
215
216
217
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
218
219
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
220
        request_timeout_secs: u64,
221
        request_id_headers: Option<Vec<String>>,
222
        pd_disaggregation: bool,
223
224
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
225
226
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
227
228
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
229
230
    ) -> PyResult<Self> {
        Ok(Router {
231
232
233
            host,
            port,
            worker_urls,
234
            policy,
235
            worker_startup_timeout_secs,
236
            worker_startup_check_interval,
237
            cache_threshold,
238
239
            balance_abs_threshold,
            balance_rel_threshold,
240
241
            eviction_interval_secs,
            max_tree_size,
242
            max_payload_size,
243
244
            dp_aware,
            api_key,
245
            log_dir,
246
            log_level,
247
248
249
250
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
251
252
253
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
254
255
            prometheus_port,
            prometheus_host,
256
            request_timeout_secs,
257
            request_id_headers,
258
            pd_disaggregation,
259
260
            prefill_urls,
            decode_urls,
261
262
            prefill_policy,
            decode_policy,
263
264
            max_concurrent_requests,
            cors_allowed_origins,
265
        })
266
267
268
    }

    fn start(&self) -> PyResult<()> {
269
270
271
272
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
273

274
275
276
277
278
279
280
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
281

282
283
284
285
286
287
288
289
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
290
291
292
293
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
294
295
296
297
298
            })
        } else {
            None
        };

299
300
301
302
303
304
305
306
307
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

308
309
310
311
312
313
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
314
315
316
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
317
                router_config,
318
                max_payload_size: self.max_payload_size,
319
                log_dir: self.log_dir.clone(),
320
                log_level: self.log_level.clone(),
321
                service_discovery_config,
322
                prometheus_config,
323
                request_timeout_secs: self.request_timeout_secs,
324
                request_id_headers: self.request_id_headers.clone(),
325
326
            })
            .await
327
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
328
        })
329
330
331
332
    }
}

#[pymodule]
333
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
334
    m.add_class::<PolicyType>()?;
335
336
    m.add_class::<Router>()?;
    Ok(())
337
}