lib.rs 15.1 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
pub mod metrics;
7
pub mod middleware;
8
pub mod openai_api_types;
9
10
pub mod policies;
pub mod routers;
11
pub mod server;
12
pub mod service_discovery;
13
pub mod tokenizer;
14
pub mod tree;
15
use crate::metrics::PrometheusConfig;
16

17
#[pyclass(eq)]
18
#[derive(Clone, PartialEq, Debug)]
19
20
21
pub enum PolicyType {
    Random,
    RoundRobin,
22
    CacheAware,
23
    PowerOfTwo,
24
25
}

26
#[pyclass]
27
#[derive(Debug, Clone, PartialEq)]
28
29
30
31
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
32
    policy: PolicyType,
33
    worker_startup_timeout_secs: u64,
34
    worker_startup_check_interval: u64,
35
    cache_threshold: f32,
36
37
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
38
39
    eviction_interval_secs: u64,
    max_tree_size: usize,
40
    max_payload_size: usize,
41
42
    dp_aware: bool,
    api_key: Option<String>,
43
    log_dir: Option<String>,
44
    log_level: Option<String>,
45
46
47
48
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
49
50
51
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
52
53
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
54
    request_timeout_secs: u64,
55
    request_id_headers: Option<Vec<String>>,
56
    pd_disaggregation: bool,
57
58
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
59
60
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
61
62
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
63
64
65
66
67
68
69
70
71
72
73
74
75
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
76
77
78
79
80
81
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
82
83
}

84
85
86
87
88
89
90
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

109
110
111
112
113
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
114
115
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
116
117
118
119
120
121
122
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

123
124
        // Convert main policy
        let policy = convert_policy(&self.policy);
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
160
161
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
162
163
164
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
165
            log_level: self.log_level.clone(),
166
            request_id_headers: self.request_id_headers.clone(),
167
168
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: false,
            disable_circuit_breaker: false,
184
185
186
187
188
189
190
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
191
192
193
194
        })
    }
}

195
196
197
#[pymethods]
impl Router {
    #[new]
198
199
200
201
202
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
203
204
205
206
207
208
209
210
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
211
212
        dp_aware = false,
        api_key = None,
213
        log_dir = None,
214
        log_level = None,
215
216
217
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
218
        service_discovery_namespace = None,
219
220
221
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
222
        prometheus_port = None,
223
        prometheus_host = None,
224
        request_timeout_secs = 1800,  // Add configurable request timeout
225
        request_id_headers = None,  // Custom request ID headers
226
        pd_disaggregation = false,  // New flag for PD mode
227
        prefill_urls = None,
228
229
        decode_urls = None,
        prefill_policy = None,
230
        decode_policy = None,
231
        max_concurrent_requests = 256,
232
233
        cors_allowed_origins = vec![],
        // Retry defaults
234
235
236
237
238
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
239
240
        disable_retries = false,
        // Circuit breaker defaults
241
242
243
244
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
245
        disable_circuit_breaker = false,
246
247
248
249
250
251
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
252
    ))]
253
    #[allow(clippy::too_many_arguments)]
254
255
256
257
258
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
259
        worker_startup_timeout_secs: u64,
260
        worker_startup_check_interval: u64,
261
        cache_threshold: f32,
262
263
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
264
265
        eviction_interval_secs: u64,
        max_tree_size: usize,
266
        max_payload_size: usize,
267
268
        dp_aware: bool,
        api_key: Option<String>,
269
        log_dir: Option<String>,
270
        log_level: Option<String>,
271
272
273
274
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
275
276
277
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
278
279
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
280
        request_timeout_secs: u64,
281
        request_id_headers: Option<Vec<String>>,
282
        pd_disaggregation: bool,
283
284
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
285
286
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
287
288
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
289
290
291
292
293
294
295
296
297
298
299
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
300
301
302
303
304
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
305
306
    ) -> PyResult<Self> {
        Ok(Router {
307
308
309
            host,
            port,
            worker_urls,
310
            policy,
311
            worker_startup_timeout_secs,
312
            worker_startup_check_interval,
313
            cache_threshold,
314
315
            balance_abs_threshold,
            balance_rel_threshold,
316
317
            eviction_interval_secs,
            max_tree_size,
318
            max_payload_size,
319
320
            dp_aware,
            api_key,
321
            log_dir,
322
            log_level,
323
324
325
326
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
327
328
329
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
330
331
            prometheus_port,
            prometheus_host,
332
            request_timeout_secs,
333
            request_id_headers,
334
            pd_disaggregation,
335
336
            prefill_urls,
            decode_urls,
337
338
            prefill_policy,
            decode_policy,
339
340
            max_concurrent_requests,
            cors_allowed_origins,
341
342
343
344
345
346
347
348
349
350
351
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
352
353
354
355
356
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
357
        })
358
359
360
    }

    fn start(&self) -> PyResult<()> {
361
362
363
364
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
365

366
367
368
369
370
371
372
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
373

374
375
376
377
378
379
380
381
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
382
383
384
385
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
386
387
388
389
390
            })
        } else {
            None
        };

391
392
393
394
395
396
397
398
399
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

400
401
402
403
404
405
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
406
407
408
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
409
                router_config,
410
                max_payload_size: self.max_payload_size,
411
                log_dir: self.log_dir.clone(),
412
                log_level: self.log_level.clone(),
413
                service_discovery_config,
414
                prometheus_config,
415
                request_timeout_secs: self.request_timeout_secs,
416
                request_id_headers: self.request_id_headers.clone(),
417
418
            })
            .await
419
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
420
        })
421
422
423
424
    }
}

#[pymodule]
425
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
426
    m.add_class::<PolicyType>()?;
427
428
    m.add_class::<Router>()?;
    Ok(())
429
}