lib.rs 15.6 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
7
#[cfg(feature = "grpc-client")]
pub mod grpc;
8
pub mod metrics;
9
pub mod middleware;
10
pub mod policies;
11
pub mod protocols;
12
pub mod reasoning_parser;
13
pub mod routers;
14
pub mod server;
15
pub mod service_discovery;
16
pub mod tokenizer;
17
pub mod tool_parser;
18
pub mod tree;
19
use crate::metrics::PrometheusConfig;
20

21
#[pyclass(eq)]
22
#[derive(Clone, PartialEq, Debug)]
23
24
25
pub enum PolicyType {
    Random,
    RoundRobin,
26
    CacheAware,
27
    PowerOfTwo,
28
29
}

30
#[pyclass]
31
#[derive(Debug, Clone, PartialEq)]
32
33
34
35
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
36
    policy: PolicyType,
37
    worker_startup_timeout_secs: u64,
38
    worker_startup_check_interval: u64,
39
    cache_threshold: f32,
40
41
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
42
43
    eviction_interval_secs: u64,
    max_tree_size: usize,
44
    max_payload_size: usize,
45
46
    dp_aware: bool,
    api_key: Option<String>,
47
    log_dir: Option<String>,
48
    log_level: Option<String>,
49
50
51
52
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
53
54
55
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
56
57
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
58
    request_timeout_secs: u64,
59
    request_id_headers: Option<Vec<String>>,
60
    pd_disaggregation: bool,
61
62
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
63
64
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
65
66
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
67
68
69
70
71
72
73
74
75
76
77
78
79
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
80
81
82
83
84
85
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
86
87
    // IGW (Inference Gateway) configuration
    enable_igw: bool,
88
89
}

90
91
92
93
94
95
96
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

115
        // Determine routing mode
116
117
118
119
120
121
        let mode = if self.enable_igw {
            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
122
123
124
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
125
126
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
127
128
129
130
131
132
133
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

134
135
        // Convert main policy
        let policy = convert_policy(&self.policy);
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
171
172
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
173
174
175
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
176
            log_level: self.log_level.clone(),
177
            request_id_headers: self.request_id_headers.clone(),
178
179
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: false,
            disable_circuit_breaker: false,
195
196
197
198
199
200
201
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
202
            enable_igw: self.enable_igw,
203
204
205
206
        })
    }
}

207
208
209
#[pymethods]
impl Router {
    #[new]
210
211
212
213
214
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
215
216
217
218
219
220
221
222
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
223
224
        dp_aware = false,
        api_key = None,
225
        log_dir = None,
226
        log_level = None,
227
228
229
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
230
        service_discovery_namespace = None,
231
232
233
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
234
        prometheus_port = None,
235
        prometheus_host = None,
236
        request_timeout_secs = 1800,  // Add configurable request timeout
237
        request_id_headers = None,  // Custom request ID headers
238
        pd_disaggregation = false,  // New flag for PD mode
239
        prefill_urls = None,
240
241
        decode_urls = None,
        prefill_policy = None,
242
        decode_policy = None,
243
        max_concurrent_requests = 256,
244
245
        cors_allowed_origins = vec![],
        // Retry defaults
246
247
248
249
250
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
251
252
        disable_retries = false,
        // Circuit breaker defaults
253
254
255
256
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
257
        disable_circuit_breaker = false,
258
259
260
261
262
263
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
264
265
        // IGW defaults
        enable_igw = false,
266
    ))]
267
    #[allow(clippy::too_many_arguments)]
268
269
270
271
272
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
273
        worker_startup_timeout_secs: u64,
274
        worker_startup_check_interval: u64,
275
        cache_threshold: f32,
276
277
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
278
279
        eviction_interval_secs: u64,
        max_tree_size: usize,
280
        max_payload_size: usize,
281
282
        dp_aware: bool,
        api_key: Option<String>,
283
        log_dir: Option<String>,
284
        log_level: Option<String>,
285
286
287
288
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
289
290
291
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
292
293
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
294
        request_timeout_secs: u64,
295
        request_id_headers: Option<Vec<String>>,
296
        pd_disaggregation: bool,
297
298
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
299
300
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
301
302
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
303
304
305
306
307
308
309
310
311
312
313
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
314
315
316
317
318
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
319
        enable_igw: bool,
320
321
    ) -> PyResult<Self> {
        Ok(Router {
322
323
324
            host,
            port,
            worker_urls,
325
            policy,
326
            worker_startup_timeout_secs,
327
            worker_startup_check_interval,
328
            cache_threshold,
329
330
            balance_abs_threshold,
            balance_rel_threshold,
331
332
            eviction_interval_secs,
            max_tree_size,
333
            max_payload_size,
334
335
            dp_aware,
            api_key,
336
            log_dir,
337
            log_level,
338
339
340
341
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
342
343
344
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
345
346
            prometheus_port,
            prometheus_host,
347
            request_timeout_secs,
348
            request_id_headers,
349
            pd_disaggregation,
350
351
            prefill_urls,
            decode_urls,
352
353
            prefill_policy,
            decode_policy,
354
355
            max_concurrent_requests,
            cors_allowed_origins,
356
357
358
359
360
361
362
363
364
365
366
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
367
368
369
370
371
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
372
            enable_igw,
373
        })
374
375
376
    }

    fn start(&self) -> PyResult<()> {
377
378
379
380
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
381

382
383
384
385
386
387
388
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
389

390
391
392
393
394
395
396
397
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
398
399
400
401
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
402
403
404
405
406
            })
        } else {
            None
        };

407
408
409
410
411
412
413
414
415
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

416
417
418
419
420
421
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
422
423
424
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
425
                router_config,
426
                max_payload_size: self.max_payload_size,
427
                log_dir: self.log_dir.clone(),
428
                log_level: self.log_level.clone(),
429
                service_discovery_config,
430
                prometheus_config,
431
                request_timeout_secs: self.request_timeout_secs,
432
                request_id_headers: self.request_id_headers.clone(),
433
434
            })
            .await
435
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
436
        })
437
438
439
440
    }
}

#[pymodule]
441
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
442
    m.add_class::<PolicyType>()?;
443
444
    m.add_class::<Router>()?;
    Ok(())
445
}