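//! Crate root (`lib.rs`): declares the router's modules and exposes the
//! `PolicyType` and `Router` classes to Python through the `sglang_router_rs` module.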
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
pub mod metrics;
7
pub mod middleware;
8
pub mod openai_api_types;
9
pub mod policies;
10
pub mod reasoning_parser;
11
pub mod routers;
12
pub mod server;
13
pub mod service_discovery;
14
pub mod tokenizer;
15
pub mod tree;
16
use crate::metrics::PrometheusConfig;
17

18
#[pyclass(eq)]
19
#[derive(Clone, PartialEq, Debug)]
20
21
22
pub enum PolicyType {
    Random,
    RoundRobin,
23
    CacheAware,
24
    PowerOfTwo,
25
26
}

27
#[pyclass]
28
#[derive(Debug, Clone, PartialEq)]
29
30
31
32
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
33
    policy: PolicyType,
34
    worker_startup_timeout_secs: u64,
35
    worker_startup_check_interval: u64,
36
    cache_threshold: f32,
37
38
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
39
40
    eviction_interval_secs: u64,
    max_tree_size: usize,
41
    max_payload_size: usize,
42
43
    dp_aware: bool,
    api_key: Option<String>,
44
    log_dir: Option<String>,
45
    log_level: Option<String>,
46
47
48
49
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
50
51
52
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
53
54
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
55
    request_timeout_secs: u64,
56
    request_id_headers: Option<Vec<String>>,
57
    pd_disaggregation: bool,
58
59
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
60
61
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
62
63
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
64
65
66
67
68
69
70
71
72
73
74
75
76
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
77
78
79
80
81
82
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
83
84
}

85
86
87
88
89
90
91
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

110
111
112
113
114
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
115
116
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
117
118
119
120
121
122
123
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

124
125
        // Convert main policy
        let policy = convert_policy(&self.policy);
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
161
162
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
163
164
165
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
166
            log_level: self.log_level.clone(),
167
            request_id_headers: self.request_id_headers.clone(),
168
169
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: false,
            disable_circuit_breaker: false,
185
186
187
188
189
190
191
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
192
193
194
195
        })
    }
}

196
197
198
#[pymethods]
impl Router {
    #[new]
199
200
201
202
203
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
204
205
206
207
208
209
210
211
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
212
213
        dp_aware = false,
        api_key = None,
214
        log_dir = None,
215
        log_level = None,
216
217
218
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
219
        service_discovery_namespace = None,
220
221
222
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
223
        prometheus_port = None,
224
        prometheus_host = None,
225
        request_timeout_secs = 1800,  // Add configurable request timeout
226
        request_id_headers = None,  // Custom request ID headers
227
        pd_disaggregation = false,  // New flag for PD mode
228
        prefill_urls = None,
229
230
        decode_urls = None,
        prefill_policy = None,
231
        decode_policy = None,
232
        max_concurrent_requests = 256,
233
234
        cors_allowed_origins = vec![],
        // Retry defaults
235
236
237
238
239
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
240
241
        disable_retries = false,
        // Circuit breaker defaults
242
243
244
245
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
246
        disable_circuit_breaker = false,
247
248
249
250
251
252
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
253
    ))]
254
    #[allow(clippy::too_many_arguments)]
255
256
257
258
259
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
260
        worker_startup_timeout_secs: u64,
261
        worker_startup_check_interval: u64,
262
        cache_threshold: f32,
263
264
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
265
266
        eviction_interval_secs: u64,
        max_tree_size: usize,
267
        max_payload_size: usize,
268
269
        dp_aware: bool,
        api_key: Option<String>,
270
        log_dir: Option<String>,
271
        log_level: Option<String>,
272
273
274
275
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
276
277
278
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
279
280
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
281
        request_timeout_secs: u64,
282
        request_id_headers: Option<Vec<String>>,
283
        pd_disaggregation: bool,
284
285
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
286
287
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
288
289
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
290
291
292
293
294
295
296
297
298
299
300
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
301
302
303
304
305
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
306
307
    ) -> PyResult<Self> {
        Ok(Router {
308
309
310
            host,
            port,
            worker_urls,
311
            policy,
312
            worker_startup_timeout_secs,
313
            worker_startup_check_interval,
314
            cache_threshold,
315
316
            balance_abs_threshold,
            balance_rel_threshold,
317
318
            eviction_interval_secs,
            max_tree_size,
319
            max_payload_size,
320
321
            dp_aware,
            api_key,
322
            log_dir,
323
            log_level,
324
325
326
327
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
328
329
330
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
331
332
            prometheus_port,
            prometheus_host,
333
            request_timeout_secs,
334
            request_id_headers,
335
            pd_disaggregation,
336
337
            prefill_urls,
            decode_urls,
338
339
            prefill_policy,
            decode_policy,
340
341
            max_concurrent_requests,
            cors_allowed_origins,
342
343
344
345
346
347
348
349
350
351
352
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
353
354
355
356
357
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
358
        })
359
360
361
    }

    fn start(&self) -> PyResult<()> {
362
363
364
365
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
366

367
368
369
370
371
372
373
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
374

375
376
377
378
379
380
381
382
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
383
384
385
386
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
387
388
389
390
391
            })
        } else {
            None
        };

392
393
394
395
396
397
398
399
400
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

401
402
403
404
405
406
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
407
408
409
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
410
                router_config,
411
                max_payload_size: self.max_payload_size,
412
                log_dir: self.log_dir.clone(),
413
                log_level: self.log_level.clone(),
414
                service_discovery_config,
415
                prometheus_config,
416
                request_timeout_secs: self.request_timeout_secs,
417
                request_id_headers: self.request_id_headers.clone(),
418
419
            })
            .await
420
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
421
        })
422
423
424
425
    }
}

#[pymodule]
426
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
427
    m.add_class::<PolicyType>()?;
428
429
    m.add_class::<Router>()?;
    Ok(())
430
}