lib.rs 15.2 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
7
#[cfg(feature = "grpc-client")]
pub mod grpc;
8
pub mod metrics;
9
pub mod middleware;
10
pub mod policies;
11
pub mod protocols;
12
pub mod reasoning_parser;
13
pub mod routers;
14
pub mod server;
15
pub mod service_discovery;
16
pub mod tokenizer;
17
pub mod tree;
18
use crate::metrics::PrometheusConfig;
19

20
#[pyclass(eq)]
21
#[derive(Clone, PartialEq, Debug)]
22
23
24
pub enum PolicyType {
    Random,
    RoundRobin,
25
    CacheAware,
26
    PowerOfTwo,
27
28
}

29
#[pyclass]
30
#[derive(Debug, Clone, PartialEq)]
31
32
33
34
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
35
    policy: PolicyType,
36
    worker_startup_timeout_secs: u64,
37
    worker_startup_check_interval: u64,
38
    cache_threshold: f32,
39
40
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
41
42
    eviction_interval_secs: u64,
    max_tree_size: usize,
43
    max_payload_size: usize,
44
45
    dp_aware: bool,
    api_key: Option<String>,
46
    log_dir: Option<String>,
47
    log_level: Option<String>,
48
49
50
51
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
52
53
54
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
55
56
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
57
    request_timeout_secs: u64,
58
    request_id_headers: Option<Vec<String>>,
59
    pd_disaggregation: bool,
60
61
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
62
63
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
64
65
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
66
67
68
69
70
71
72
73
74
75
76
77
78
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
79
80
81
82
83
84
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
85
86
}

87
88
89
90
91
92
93
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

112
113
114
115
116
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
117
118
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
119
120
121
122
123
124
125
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

126
127
        // Convert main policy
        let policy = convert_policy(&self.policy);
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
163
164
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
165
166
167
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
168
            log_level: self.log_level.clone(),
169
            request_id_headers: self.request_id_headers.clone(),
170
171
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: false,
            disable_circuit_breaker: false,
187
188
189
190
191
192
193
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
194
195
196
197
        })
    }
}

198
199
200
#[pymethods]
impl Router {
    #[new]
201
202
203
204
205
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
206
207
208
209
210
211
212
213
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
214
215
        dp_aware = false,
        api_key = None,
216
        log_dir = None,
217
        log_level = None,
218
219
220
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
221
        service_discovery_namespace = None,
222
223
224
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
225
        prometheus_port = None,
226
        prometheus_host = None,
227
        request_timeout_secs = 1800,  // Add configurable request timeout
228
        request_id_headers = None,  // Custom request ID headers
229
        pd_disaggregation = false,  // New flag for PD mode
230
        prefill_urls = None,
231
232
        decode_urls = None,
        prefill_policy = None,
233
        decode_policy = None,
234
        max_concurrent_requests = 256,
235
236
        cors_allowed_origins = vec![],
        // Retry defaults
237
238
239
240
241
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
242
243
        disable_retries = false,
        // Circuit breaker defaults
244
245
246
247
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
248
        disable_circuit_breaker = false,
249
250
251
252
253
254
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
255
    ))]
256
    #[allow(clippy::too_many_arguments)]
257
258
259
260
261
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
262
        worker_startup_timeout_secs: u64,
263
        worker_startup_check_interval: u64,
264
        cache_threshold: f32,
265
266
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
267
268
        eviction_interval_secs: u64,
        max_tree_size: usize,
269
        max_payload_size: usize,
270
271
        dp_aware: bool,
        api_key: Option<String>,
272
        log_dir: Option<String>,
273
        log_level: Option<String>,
274
275
276
277
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
278
279
280
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
281
282
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
283
        request_timeout_secs: u64,
284
        request_id_headers: Option<Vec<String>>,
285
        pd_disaggregation: bool,
286
287
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
288
289
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
290
291
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
292
293
294
295
296
297
298
299
300
301
302
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
303
304
305
306
307
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
308
309
    ) -> PyResult<Self> {
        Ok(Router {
310
311
312
            host,
            port,
            worker_urls,
313
            policy,
314
            worker_startup_timeout_secs,
315
            worker_startup_check_interval,
316
            cache_threshold,
317
318
            balance_abs_threshold,
            balance_rel_threshold,
319
320
            eviction_interval_secs,
            max_tree_size,
321
            max_payload_size,
322
323
            dp_aware,
            api_key,
324
            log_dir,
325
            log_level,
326
327
328
329
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
330
331
332
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
333
334
            prometheus_port,
            prometheus_host,
335
            request_timeout_secs,
336
            request_id_headers,
337
            pd_disaggregation,
338
339
            prefill_urls,
            decode_urls,
340
341
            prefill_policy,
            decode_policy,
342
343
            max_concurrent_requests,
            cors_allowed_origins,
344
345
346
347
348
349
350
351
352
353
354
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
355
356
357
358
359
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
360
        })
361
362
363
    }

    fn start(&self) -> PyResult<()> {
364
365
366
367
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
368

369
370
371
372
373
374
375
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
376

377
378
379
380
381
382
383
384
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
385
386
387
388
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
389
390
391
392
393
            })
        } else {
            None
        };

394
395
396
397
398
399
400
401
402
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

403
404
405
406
407
408
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
409
410
411
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
412
                router_config,
413
                max_payload_size: self.max_payload_size,
414
                log_dir: self.log_dir.clone(),
415
                log_level: self.log_level.clone(),
416
                service_discovery_config,
417
                prometheus_config,
418
                request_timeout_secs: self.request_timeout_secs,
419
                request_id_headers: self.request_id_headers.clone(),
420
421
            })
            .await
422
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
423
        })
424
425
426
427
    }
}

#[pymodule]
428
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
429
    m.add_class::<PolicyType>()?;
430
431
    m.add_class::<Router>()?;
    Ok(())
432
}