lib.rs 16.3 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
7
#[cfg(feature = "grpc-client")]
pub mod grpc;
8
pub mod mcp;
9
pub mod metrics;
10
pub mod middleware;
11
pub mod policies;
12
pub mod protocols;
13
pub mod reasoning_parser;
14
pub mod routers;
15
pub mod server;
16
pub mod service_discovery;
17
pub mod tokenizer;
18
pub mod tool_parser;
19
pub mod tree;
20
use crate::metrics::PrometheusConfig;
21

22
#[pyclass(eq)]
23
#[derive(Clone, PartialEq, Debug)]
24
25
26
pub enum PolicyType {
    Random,
    RoundRobin,
27
    CacheAware,
28
    PowerOfTwo,
29
30
}

31
#[pyclass]
32
#[derive(Debug, Clone, PartialEq)]
33
34
35
36
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
37
    policy: PolicyType,
38
    worker_startup_timeout_secs: u64,
39
    worker_startup_check_interval: u64,
40
    cache_threshold: f32,
41
42
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
43
44
    eviction_interval_secs: u64,
    max_tree_size: usize,
45
    max_payload_size: usize,
46
47
    dp_aware: bool,
    api_key: Option<String>,
48
    log_dir: Option<String>,
49
    log_level: Option<String>,
50
51
52
53
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
54
55
56
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
57
58
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
59
    request_timeout_secs: u64,
60
    request_id_headers: Option<Vec<String>>,
61
    pd_disaggregation: bool,
62
63
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
64
65
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
66
67
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
68
69
70
71
72
73
74
75
76
77
78
79
80
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
81
82
83
84
85
86
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
87
88
    // IGW (Inference Gateway) configuration
    enable_igw: bool,
89
90
91
    queue_size: usize,
    queue_timeout_secs: u64,
    rate_limit_tokens_per_second: Option<usize>,
92
93
}

94
95
96
97
98
99
100
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

119
        // Determine routing mode
120
121
122
123
124
125
        let mode = if self.enable_igw {
            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
126
127
128
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
129
130
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
131
132
133
134
135
136
137
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

138
139
        // Convert main policy
        let policy = convert_policy(&self.policy);
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
175
176
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
177
178
179
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
180
            log_level: self.log_level.clone(),
181
            request_id_headers: self.request_id_headers.clone(),
182
            max_concurrent_requests: self.max_concurrent_requests,
183
184
185
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
186
            cors_allowed_origins: self.cors_allowed_origins.clone(),
187
188
189
190
191
192
193
194
195
196
197
198
199
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
200
201
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
202
203
204
205
206
207
208
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
209
            enable_igw: self.enable_igw,
210
211
212
213
        })
    }
}

214
215
216
#[pymethods]
impl Router {
    #[new]
217
218
219
220
221
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
222
223
224
225
226
227
228
229
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
230
231
        dp_aware = false,
        api_key = None,
232
        log_dir = None,
233
        log_level = None,
234
235
236
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
237
        service_discovery_namespace = None,
238
239
240
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
241
        prometheus_port = None,
242
        prometheus_host = None,
243
        request_timeout_secs = 1800,  // Add configurable request timeout
244
        request_id_headers = None,  // Custom request ID headers
245
        pd_disaggregation = false,  // New flag for PD mode
246
        prefill_urls = None,
247
248
        decode_urls = None,
        prefill_policy = None,
249
        decode_policy = None,
250
        max_concurrent_requests = 256,
251
252
        cors_allowed_origins = vec![],
        // Retry defaults
253
254
255
256
257
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
258
259
        disable_retries = false,
        // Circuit breaker defaults
260
261
262
263
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
264
        disable_circuit_breaker = false,
265
266
267
268
269
270
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
271
272
        // IGW defaults
        enable_igw = false,
273
274
275
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
276
    ))]
277
    #[allow(clippy::too_many_arguments)]
278
279
280
281
282
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
283
        worker_startup_timeout_secs: u64,
284
        worker_startup_check_interval: u64,
285
        cache_threshold: f32,
286
287
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
288
289
        eviction_interval_secs: u64,
        max_tree_size: usize,
290
        max_payload_size: usize,
291
292
        dp_aware: bool,
        api_key: Option<String>,
293
        log_dir: Option<String>,
294
        log_level: Option<String>,
295
296
297
298
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
299
300
301
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
302
303
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
304
        request_timeout_secs: u64,
305
        request_id_headers: Option<Vec<String>>,
306
        pd_disaggregation: bool,
307
308
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
309
310
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
311
312
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
313
314
315
316
317
318
319
320
321
322
323
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
324
325
326
327
328
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
329
        enable_igw: bool,
330
331
332
        queue_size: usize,
        queue_timeout_secs: u64,
        rate_limit_tokens_per_second: Option<usize>,
333
334
    ) -> PyResult<Self> {
        Ok(Router {
335
336
337
            host,
            port,
            worker_urls,
338
            policy,
339
            worker_startup_timeout_secs,
340
            worker_startup_check_interval,
341
            cache_threshold,
342
343
            balance_abs_threshold,
            balance_rel_threshold,
344
345
            eviction_interval_secs,
            max_tree_size,
346
            max_payload_size,
347
348
            dp_aware,
            api_key,
349
            log_dir,
350
            log_level,
351
352
353
354
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
355
356
357
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
358
359
            prometheus_port,
            prometheus_host,
360
            request_timeout_secs,
361
            request_id_headers,
362
            pd_disaggregation,
363
364
            prefill_urls,
            decode_urls,
365
366
            prefill_policy,
            decode_policy,
367
368
            max_concurrent_requests,
            cors_allowed_origins,
369
370
371
372
373
374
375
376
377
378
379
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
380
381
382
383
384
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
385
            enable_igw,
386
387
388
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
389
        })
390
391
392
    }

    fn start(&self) -> PyResult<()> {
393
394
395
396
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
397

398
399
400
401
402
403
404
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
405

406
407
408
409
410
411
412
413
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
414
415
416
417
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
418
419
420
421
422
            })
        } else {
            None
        };

423
424
425
426
427
428
429
430
431
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

432
433
434
435
436
437
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
438
439
440
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
441
                router_config,
442
                max_payload_size: self.max_payload_size,
443
                log_dir: self.log_dir.clone(),
444
                log_level: self.log_level.clone(),
445
                service_discovery_config,
446
                prometheus_config,
447
                request_timeout_secs: self.request_timeout_secs,
448
                request_id_headers: self.request_id_headers.clone(),
449
450
            })
            .await
451
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
452
        })
453
454
455
456
    }
}

#[pymodule]
457
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
458
    m.add_class::<PolicyType>()?;
459
460
    m.add_class::<Router>()?;
    Ok(())
461
}