//! Crate root for the `sglang_router_rs` Python extension: declares the router
//! submodules and exposes `PolicyType` and `Router` to Python through PyO3.
use pyo3::prelude::*;
pub mod config;
pub mod logging;
use std::collections::HashMap;
pub mod core;
pub mod metrics;
pub mod middleware;
pub mod openai_api_types;
pub mod policies;
pub mod routers;
pub mod server;
pub mod service_discovery;
pub mod tree;
use crate::metrics::PrometheusConfig;

#[pyclass(eq)]
17
#[derive(Clone, PartialEq, Debug)]
18
19
20
pub enum PolicyType {
    Random,
    RoundRobin,
21
    CacheAware,
22
    PowerOfTwo,
23
24
}

25
#[pyclass]
26
#[derive(Debug, Clone, PartialEq)]
27
28
29
30
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
31
    policy: PolicyType,
32
    worker_startup_timeout_secs: u64,
33
    worker_startup_check_interval: u64,
34
    cache_threshold: f32,
35
36
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
37
38
    eviction_interval_secs: u64,
    max_tree_size: usize,
39
    max_payload_size: usize,
40
41
    dp_aware: bool,
    api_key: Option<String>,
42
    log_dir: Option<String>,
43
    log_level: Option<String>,
44
45
46
47
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
48
49
50
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
51
52
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
53
    request_timeout_secs: u64,
54
    request_id_headers: Option<Vec<String>>,
55
    pd_disaggregation: bool,
56
57
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
58
59
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
60
61
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
62
63
64
65
66
67
68
69
70
71
72
73
74
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
75
76
77
78
79
80
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
81
82
}

83
84
85
86
87
88
89
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

108
109
110
111
112
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
113
114
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
115
116
117
118
119
120
121
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

122
123
        // Convert main policy
        let policy = convert_policy(&self.policy);
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
159
160
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
161
162
163
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
164
            log_level: self.log_level.clone(),
165
            request_id_headers: self.request_id_headers.clone(),
166
167
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: false,
            disable_circuit_breaker: false,
183
184
185
186
187
188
189
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
190
191
192
193
        })
    }
}

194
195
196
#[pymethods]
impl Router {
    #[new]
197
198
199
200
201
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
202
203
204
205
206
207
208
209
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
210
211
        dp_aware = false,
        api_key = None,
212
        log_dir = None,
213
        log_level = None,
214
215
216
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
217
        service_discovery_namespace = None,
218
219
220
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
221
        prometheus_port = None,
222
        prometheus_host = None,
223
        request_timeout_secs = 1800,  // Add configurable request timeout
224
        request_id_headers = None,  // Custom request ID headers
225
        pd_disaggregation = false,  // New flag for PD mode
226
        prefill_urls = None,
227
228
        decode_urls = None,
        prefill_policy = None,
229
        decode_policy = None,
230
        max_concurrent_requests = 256,
231
232
        cors_allowed_origins = vec![],
        // Retry defaults
233
234
235
236
237
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
238
239
        disable_retries = false,
        // Circuit breaker defaults
240
241
242
243
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
244
        disable_circuit_breaker = false,
245
246
247
248
249
250
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
251
    ))]
252
    #[allow(clippy::too_many_arguments)]
253
254
255
256
257
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
258
        worker_startup_timeout_secs: u64,
259
        worker_startup_check_interval: u64,
260
        cache_threshold: f32,
261
262
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
263
264
        eviction_interval_secs: u64,
        max_tree_size: usize,
265
        max_payload_size: usize,
266
267
        dp_aware: bool,
        api_key: Option<String>,
268
        log_dir: Option<String>,
269
        log_level: Option<String>,
270
271
272
273
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
274
275
276
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
277
278
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
279
        request_timeout_secs: u64,
280
        request_id_headers: Option<Vec<String>>,
281
        pd_disaggregation: bool,
282
283
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
284
285
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
286
287
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
288
289
290
291
292
293
294
295
296
297
298
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
299
300
301
302
303
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
304
305
    ) -> PyResult<Self> {
        Ok(Router {
306
307
308
            host,
            port,
            worker_urls,
309
            policy,
310
            worker_startup_timeout_secs,
311
            worker_startup_check_interval,
312
            cache_threshold,
313
314
            balance_abs_threshold,
            balance_rel_threshold,
315
316
            eviction_interval_secs,
            max_tree_size,
317
            max_payload_size,
318
319
            dp_aware,
            api_key,
320
            log_dir,
321
            log_level,
322
323
324
325
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
326
327
328
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
329
330
            prometheus_port,
            prometheus_host,
331
            request_timeout_secs,
332
            request_id_headers,
333
            pd_disaggregation,
334
335
            prefill_urls,
            decode_urls,
336
337
            prefill_policy,
            decode_policy,
338
339
            max_concurrent_requests,
            cors_allowed_origins,
340
341
342
343
344
345
346
347
348
349
350
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
351
352
353
354
355
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
356
        })
357
358
359
    }

    fn start(&self) -> PyResult<()> {
360
361
362
363
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
364

365
366
367
368
369
370
371
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
372

373
374
375
376
377
378
379
380
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
381
382
383
384
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
385
386
387
388
389
            })
        } else {
            None
        };

390
391
392
393
394
395
396
397
398
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

399
400
401
402
403
404
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
405
406
407
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
408
                router_config,
409
                max_payload_size: self.max_payload_size,
410
                log_dir: self.log_dir.clone(),
411
                log_level: self.log_level.clone(),
412
                service_discovery_config,
413
                prometheus_config,
414
                request_timeout_secs: self.request_timeout_secs,
415
                request_id_headers: self.request_id_headers.clone(),
416
417
            })
            .await
418
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
419
        })
420
421
422
423
    }
}

#[pymodule]
424
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
425
    m.add_class::<PolicyType>()?;
426
427
    m.add_class::<Router>()?;
    Ok(())
428
}