lib.rs 15.1 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
pub mod metrics;
7
pub mod middleware;
8
pub mod openai_api_types;
9
10
pub mod policies;
pub mod routers;
11
pub mod server;
12
pub mod service_discovery;
13
pub mod tree;
14
use crate::metrics::PrometheusConfig;
15

16
#[pyclass(eq)]
17
#[derive(Clone, PartialEq, Debug)]
18
19
20
pub enum PolicyType {
    Random,
    RoundRobin,
21
    CacheAware,
22
    PowerOfTwo,
23
24
}

25
#[pyclass]
26
#[derive(Debug, Clone, PartialEq)]
27
28
29
30
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
31
    policy: PolicyType,
32
    worker_startup_timeout_secs: u64,
33
    worker_startup_check_interval: u64,
34
    cache_threshold: f32,
35
36
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
37
38
    eviction_interval_secs: u64,
    max_tree_size: usize,
39
    max_payload_size: usize,
40
41
    dp_aware: bool,
    api_key: Option<String>,
42
    log_dir: Option<String>,
43
    log_level: Option<String>,
44
45
46
47
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
48
49
50
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
51
52
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
53
    request_timeout_secs: u64,
54
    request_id_headers: Option<Vec<String>>,
55
    pd_disaggregation: bool,
56
57
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
58
59
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
60
61
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
62
63
64
65
66
67
68
69
70
71
72
73
74
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
75
76
77
78
79
80
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
81
82
}

83
84
85
86
87
88
89
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

108
109
110
111
112
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
113
114
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
115
116
117
118
119
120
121
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

122
123
        // Convert main policy
        let policy = convert_policy(&self.policy);
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
159
160
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
161
162
163
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
164
            log_level: self.log_level.clone(),
165
            request_id_headers: self.request_id_headers.clone(),
166
167
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: false,
            disable_circuit_breaker: false,
183
184
185
186
187
188
189
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
190
191
192
193
        })
    }
}

194
195
196
#[pymethods]
impl Router {
    #[new]
197
198
199
200
201
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
202
203
204
205
206
207
208
209
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
210
211
        dp_aware = false,
        api_key = None,
212
        log_dir = None,
213
        log_level = None,
214
215
216
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
217
        service_discovery_namespace = None,
218
219
220
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
221
        prometheus_port = None,
222
        prometheus_host = None,
223
        request_timeout_secs = 1800,  // Add configurable request timeout
224
        request_id_headers = None,  // Custom request ID headers
225
        pd_disaggregation = false,  // New flag for PD mode
226
        prefill_urls = None,
227
228
        decode_urls = None,
        prefill_policy = None,
229
        decode_policy = None,
230
        max_concurrent_requests = 256,
231
232
        cors_allowed_origins = vec![],
        // Retry defaults
233
234
235
236
237
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
238
239
        disable_retries = false,
        // Circuit breaker defaults
240
241
242
243
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
244
        disable_circuit_breaker = false,
245
246
247
248
249
250
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
251
252
253
254
255
256
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
257
        worker_startup_timeout_secs: u64,
258
        worker_startup_check_interval: u64,
259
        cache_threshold: f32,
260
261
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
262
263
        eviction_interval_secs: u64,
        max_tree_size: usize,
264
        max_payload_size: usize,
265
266
        dp_aware: bool,
        api_key: Option<String>,
267
        log_dir: Option<String>,
268
        log_level: Option<String>,
269
270
271
272
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
273
274
275
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
276
277
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
278
        request_timeout_secs: u64,
279
        request_id_headers: Option<Vec<String>>,
280
        pd_disaggregation: bool,
281
282
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
283
284
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
285
286
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
287
288
289
290
291
292
293
294
295
296
297
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
298
299
300
301
302
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
303
304
    ) -> PyResult<Self> {
        Ok(Router {
305
306
307
            host,
            port,
            worker_urls,
308
            policy,
309
            worker_startup_timeout_secs,
310
            worker_startup_check_interval,
311
            cache_threshold,
312
313
            balance_abs_threshold,
            balance_rel_threshold,
314
315
            eviction_interval_secs,
            max_tree_size,
316
            max_payload_size,
317
318
            dp_aware,
            api_key,
319
            log_dir,
320
            log_level,
321
322
323
324
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
325
326
327
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
328
329
            prometheus_port,
            prometheus_host,
330
            request_timeout_secs,
331
            request_id_headers,
332
            pd_disaggregation,
333
334
            prefill_urls,
            decode_urls,
335
336
            prefill_policy,
            decode_policy,
337
338
            max_concurrent_requests,
            cors_allowed_origins,
339
340
341
342
343
344
345
346
347
348
349
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
350
351
352
353
354
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
355
        })
356
357
358
    }

    fn start(&self) -> PyResult<()> {
359
360
361
362
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
363

364
365
366
367
368
369
370
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
371

372
373
374
375
376
377
378
379
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
380
381
382
383
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
384
385
386
387
388
            })
        } else {
            None
        };

389
390
391
392
393
394
395
396
397
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

398
399
400
401
402
403
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
404
405
406
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
407
                router_config,
408
                max_payload_size: self.max_payload_size,
409
                log_dir: self.log_dir.clone(),
410
                log_level: self.log_level.clone(),
411
                service_discovery_config,
412
                prometheus_config,
413
                request_timeout_secs: self.request_timeout_secs,
414
                request_id_headers: self.request_id_headers.clone(),
415
416
            })
            .await
417
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
418
        })
419
420
421
422
    }
}

#[pymodule]
423
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
424
    m.add_class::<PolicyType>()?;
425
426
    m.add_class::<Router>()?;
    Ok(())
427
}