lib.rs 16.2 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
7
#[cfg(feature = "grpc-client")]
pub mod grpc;
8
pub mod metrics;
9
pub mod middleware;
10
pub mod policies;
11
pub mod protocols;
12
pub mod reasoning_parser;
13
pub mod routers;
14
pub mod server;
15
pub mod service_discovery;
16
pub mod tokenizer;
17
pub mod tool_parser;
18
pub mod tree;
19
use crate::metrics::PrometheusConfig;
20

21
#[pyclass(eq)]
22
#[derive(Clone, PartialEq, Debug)]
23
24
25
pub enum PolicyType {
    Random,
    RoundRobin,
26
    CacheAware,
27
    PowerOfTwo,
28
29
}

30
#[pyclass]
31
#[derive(Debug, Clone, PartialEq)]
32
33
34
35
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
36
    policy: PolicyType,
37
    worker_startup_timeout_secs: u64,
38
    worker_startup_check_interval: u64,
39
    cache_threshold: f32,
40
41
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
42
43
    eviction_interval_secs: u64,
    max_tree_size: usize,
44
    max_payload_size: usize,
45
46
    dp_aware: bool,
    api_key: Option<String>,
47
    log_dir: Option<String>,
48
    log_level: Option<String>,
49
50
51
52
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
53
54
55
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
56
57
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
58
    request_timeout_secs: u64,
59
    request_id_headers: Option<Vec<String>>,
60
    pd_disaggregation: bool,
61
62
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
63
64
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
65
66
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
67
68
69
70
71
72
73
74
75
76
77
78
79
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
80
81
82
83
84
85
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
86
87
    // IGW (Inference Gateway) configuration
    enable_igw: bool,
88
89
90
    queue_size: usize,
    queue_timeout_secs: u64,
    rate_limit_tokens_per_second: Option<usize>,
91
92
}

93
94
95
96
97
98
99
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

118
        // Determine routing mode
119
120
121
122
123
124
        let mode = if self.enable_igw {
            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
125
126
127
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
128
129
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
130
131
132
133
134
135
136
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

137
138
        // Convert main policy
        let policy = convert_policy(&self.policy);
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
174
175
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
176
177
178
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
179
            log_level: self.log_level.clone(),
180
            request_id_headers: self.request_id_headers.clone(),
181
            max_concurrent_requests: self.max_concurrent_requests,
182
183
184
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
185
            cors_allowed_origins: self.cors_allowed_origins.clone(),
186
187
188
189
190
191
192
193
194
195
196
197
198
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
199
200
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
201
202
203
204
205
206
207
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
208
            enable_igw: self.enable_igw,
209
210
211
212
        })
    }
}

213
214
215
#[pymethods]
impl Router {
    #[new]
216
217
218
219
220
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
221
222
223
224
225
226
227
228
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
229
230
        dp_aware = false,
        api_key = None,
231
        log_dir = None,
232
        log_level = None,
233
234
235
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
236
        service_discovery_namespace = None,
237
238
239
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
240
        prometheus_port = None,
241
        prometheus_host = None,
242
        request_timeout_secs = 1800,  // Add configurable request timeout
243
        request_id_headers = None,  // Custom request ID headers
244
        pd_disaggregation = false,  // New flag for PD mode
245
        prefill_urls = None,
246
247
        decode_urls = None,
        prefill_policy = None,
248
        decode_policy = None,
249
        max_concurrent_requests = 256,
250
251
        cors_allowed_origins = vec![],
        // Retry defaults
252
253
254
255
256
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
257
258
        disable_retries = false,
        // Circuit breaker defaults
259
260
261
262
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
263
        disable_circuit_breaker = false,
264
265
266
267
268
269
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
270
271
        // IGW defaults
        enable_igw = false,
272
273
274
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
275
    ))]
276
    #[allow(clippy::too_many_arguments)]
277
278
279
280
281
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
282
        worker_startup_timeout_secs: u64,
283
        worker_startup_check_interval: u64,
284
        cache_threshold: f32,
285
286
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
287
288
        eviction_interval_secs: u64,
        max_tree_size: usize,
289
        max_payload_size: usize,
290
291
        dp_aware: bool,
        api_key: Option<String>,
292
        log_dir: Option<String>,
293
        log_level: Option<String>,
294
295
296
297
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
298
299
300
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
301
302
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
303
        request_timeout_secs: u64,
304
        request_id_headers: Option<Vec<String>>,
305
        pd_disaggregation: bool,
306
307
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
308
309
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
310
311
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
312
313
314
315
316
317
318
319
320
321
322
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
323
324
325
326
327
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
328
        enable_igw: bool,
329
330
331
        queue_size: usize,
        queue_timeout_secs: u64,
        rate_limit_tokens_per_second: Option<usize>,
332
333
    ) -> PyResult<Self> {
        Ok(Router {
334
335
336
            host,
            port,
            worker_urls,
337
            policy,
338
            worker_startup_timeout_secs,
339
            worker_startup_check_interval,
340
            cache_threshold,
341
342
            balance_abs_threshold,
            balance_rel_threshold,
343
344
            eviction_interval_secs,
            max_tree_size,
345
            max_payload_size,
346
347
            dp_aware,
            api_key,
348
            log_dir,
349
            log_level,
350
351
352
353
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
354
355
356
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
357
358
            prometheus_port,
            prometheus_host,
359
            request_timeout_secs,
360
            request_id_headers,
361
            pd_disaggregation,
362
363
            prefill_urls,
            decode_urls,
364
365
            prefill_policy,
            decode_policy,
366
367
            max_concurrent_requests,
            cors_allowed_origins,
368
369
370
371
372
373
374
375
376
377
378
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
379
380
381
382
383
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
384
            enable_igw,
385
386
387
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
388
        })
389
390
391
    }

    fn start(&self) -> PyResult<()> {
392
393
394
395
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
396

397
398
399
400
401
402
403
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
404

405
406
407
408
409
410
411
412
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
413
414
415
416
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
417
418
419
420
421
            })
        } else {
            None
        };

422
423
424
425
426
427
428
429
430
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

431
432
433
434
435
436
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
437
438
439
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
440
                router_config,
441
                max_payload_size: self.max_payload_size,
442
                log_dir: self.log_dir.clone(),
443
                log_level: self.log_level.clone(),
444
                service_discovery_config,
445
                prometheus_config,
446
                request_timeout_secs: self.request_timeout_secs,
447
                request_id_headers: self.request_id_headers.clone(),
448
449
            })
            .await
450
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
451
        })
452
453
454
455
    }
}

#[pymodule]
456
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
457
    m.add_class::<PolicyType>()?;
458
459
    m.add_class::<Router>()?;
    Ok(())
460
}