lib.rs 21.2 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
pub mod data_connector;
8
#[cfg(feature = "grpc-client")]
9
pub mod grpc_client;
10
pub mod mcp;
11
pub mod metrics;
12
pub mod middleware;
13
pub mod policies;
14
pub mod protocols;
15
pub mod reasoning_parser;
16
pub mod routers;
17
pub mod server;
18
pub mod service_discovery;
19
pub mod tokenizer;
20
pub mod tool_parser;
21
pub mod tree;
22
use crate::metrics::PrometheusConfig;
23

24
#[pyclass(eq)]
25
#[derive(Clone, PartialEq, Debug)]
26
27
28
pub enum PolicyType {
    Random,
    RoundRobin,
29
    CacheAware,
30
    PowerOfTwo,
31
32
}

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
/// Backend selector exposed to Python.
///
/// `Router::to_router_config` switches to `RoutingMode::OpenAI` when this is
/// `Openai`, unless `enable_igw` is set (that flag is checked first).
#[pyclass(eq)]
#[derive(Clone, PartialEq, Debug)]
pub enum BackendType {
    /// Default value in `Router::new`'s Python signature.
    Sglang,
    /// Selects `RoutingMode::OpenAI` with the configured worker URLs.
    Openai,
}

/// History backend selector exposed to Python.
///
/// `Router::to_router_config` maps each variant one-to-one onto the
/// `config::HistoryBackend` variant of the same name.
#[pyclass(eq)]
#[derive(Clone, PartialEq, Debug)]
pub enum HistoryBackendType {
    /// Default value in `Router::new`'s Python signature.
    Memory,
    /// Maps to `config::HistoryBackend::None`.
    None,
    /// Maps to `config::HistoryBackend::Oracle`; the only variant for which
    /// `Router::to_router_config` forwards the router's `oracle_config`.
    Oracle,
}

/// Oracle connection settings exposed to Python; converted into
/// `config::OracleConfig` by `to_config_oracle`.
///
/// `Debug` is implemented by hand rather than derived so that `password` and
/// `connect_descriptor` are printed as `<redacted>`.
///
/// NOTE(review): every field carries `#[pyo3(get, set)]`, so Python code can
/// assign to the fields after construction; the pool-size checks in `new` are
/// not re-run on assignment — confirm this is acceptable.
#[pyclass]
#[derive(Clone, PartialEq)]
pub struct PyOracleConfig {
    /// Optional wallet location (presumably a filesystem path — confirm).
    /// Forwarded unchanged as an `Option`.
    #[pyo3(get, set)]
    pub wallet_path: Option<String>,
    /// Connect descriptor; converted to an empty string when `None`.
    #[pyo3(get, set)]
    pub connect_descriptor: Option<String>,
    /// User name; converted to an empty string when `None`.
    #[pyo3(get, set)]
    pub username: Option<String>,
    /// Password; converted to an empty string when `None`. Never printed by `Debug`.
    #[pyo3(get, set)]
    pub password: Option<String>,
    /// Lower pool bound; `new` rejects `0`.
    #[pyo3(get, set)]
    pub pool_min: usize,
    /// Upper pool bound; `new` rejects values below `pool_min`.
    #[pyo3(get, set)]
    pub pool_max: usize,
    /// Pool timeout in seconds (not validated in this file).
    #[pyo3(get, set)]
    pub pool_timeout_secs: u64,
}

/// Hand-written `Debug` so that secrets never reach logs.
///
/// `connect_descriptor` and `password` are always rendered as the literal
/// string `<redacted>`, regardless of whether they are set; every other field
/// is printed with its own `Debug` representation.
impl std::fmt::Debug for PyOracleConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder shown instead of any sensitive value.
        const REDACTED: &str = "<redacted>";

        let mut out = f.debug_struct("PyOracleConfig");
        out.field("wallet_path", &self.wallet_path);
        out.field("connect_descriptor", &REDACTED);
        out.field("username", &self.username);
        out.field("password", &REDACTED);
        out.field("pool_min", &self.pool_min);
        out.field("pool_max", &self.pool_max);
        out.field("pool_timeout_secs", &self.pool_timeout_secs);
        out.finish()
    }
}

#[pymethods]
impl PyOracleConfig {
    /// Python constructor for the Oracle settings.
    ///
    /// All connection fields default to `None`; the pool defaults are
    /// `pool_min = 1`, `pool_max = 16`, `pool_timeout_secs = 30`.
    ///
    /// # Errors
    /// Raises `ValueError` when `pool_min` is `0` or when `pool_max` is
    /// smaller than `pool_min`. The `pool_min` check takes precedence.
    #[new]
    #[pyo3(signature = (
        password = None,
        username = None,
        connect_descriptor = None,
        wallet_path = None,
        pool_min = 1,
        pool_max = 16,
        pool_timeout_secs = 30,
    ))]
    fn new(
        password: Option<String>,
        username: Option<String>,
        connect_descriptor: Option<String>,
        wallet_path: Option<String>,
        pool_min: usize,
        pool_max: usize,
        pool_timeout_secs: u64,
    ) -> PyResult<Self> {
        // Work out the first violated constraint, if any, then fail once.
        let violation = if pool_min == 0 {
            Some("pool_min must be at least 1")
        } else if pool_max < pool_min {
            Some("pool_max must be >= pool_min")
        } else {
            None
        };

        match violation {
            Some(message) => Err(pyo3::exceptions::PyValueError::new_err(message)),
            None => Ok(Self {
                wallet_path,
                connect_descriptor,
                username,
                password,
                pool_min,
                pool_max,
                pool_timeout_secs,
            }),
        }
    }
}

impl PyOracleConfig {
    /// Produce the internal `config::OracleConfig` equivalent of this object.
    ///
    /// `wallet_path` stays optional; the other optional strings collapse to an
    /// empty `String` when unset. No validation is performed here — per the
    /// original author's note it happens later, in `validate_oracle()`.
    fn to_config_oracle(&self) -> config::OracleConfig {
        let Self {
            wallet_path,
            connect_descriptor,
            username,
            password,
            pool_min,
            pool_max,
            pool_timeout_secs,
        } = self;

        // Missing values become "" rather than an error at this stage.
        let owned_or_empty = |value: &Option<String>| value.clone().unwrap_or_default();

        config::OracleConfig {
            wallet_path: wallet_path.clone(),
            connect_descriptor: owned_or_empty(connect_descriptor),
            username: owned_or_empty(username),
            password: owned_or_empty(password),
            pool_min: *pool_min,
            pool_max: *pool_max,
            pool_timeout_secs: *pool_timeout_secs,
        }
    }
}

140
#[pyclass]
141
#[derive(Debug, Clone, PartialEq)]
142
143
144
145
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
146
    policy: PolicyType,
147
    worker_startup_timeout_secs: u64,
148
    worker_startup_check_interval: u64,
149
    cache_threshold: f32,
150
151
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
152
153
    eviction_interval_secs: u64,
    max_tree_size: usize,
154
    max_payload_size: usize,
155
156
    dp_aware: bool,
    api_key: Option<String>,
157
    log_dir: Option<String>,
158
    log_level: Option<String>,
159
160
161
162
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
163
164
165
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
166
167
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
168
    request_timeout_secs: u64,
169
    request_id_headers: Option<Vec<String>>,
170
    pd_disaggregation: bool,
171
172
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
173
174
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
175
    max_concurrent_requests: i32,
176
    cors_allowed_origins: Vec<String>,
177
178
179
180
181
182
183
184
185
186
187
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
188
189
190
191
192
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
193
    enable_igw: bool,
194
195
    queue_size: usize,
    queue_timeout_secs: u64,
196
    rate_limit_tokens_per_second: Option<i32>,
197
198
199
    connection_mode: config::ConnectionMode,
    model_path: Option<String>,
    tokenizer_path: Option<String>,
200
    chat_template: Option<String>,
201
202
    reasoning_parser: Option<String>,
    tool_call_parser: Option<String>,
203
204
205
    backend: BackendType,
    history_backend: HistoryBackendType,
    oracle_config: Option<PyOracleConfig>,
206
207
}

208
impl Router {
209
210
211
212
213
214
215
216
217
218
    /// Pick the connection mode implied by a set of worker URLs.
    ///
    /// Returns `ConnectionMode::Grpc` as soon as any URL begins with `grpc://`
    /// or `grpcs://` (case-sensitive prefix match); otherwise — including for
    /// an empty slice — returns `ConnectionMode::Http`.
    fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
        const GRPC_SCHEMES: [&str; 2] = ["grpc://", "grpcs://"];

        let uses_grpc = worker_urls
            .iter()
            .any(|url| GRPC_SCHEMES.iter().any(|scheme| url.starts_with(*scheme)));

        if uses_grpc {
            config::ConnectionMode::Grpc
        } else {
            config::ConnectionMode::Http
        }
    }

219
220
221
222
223
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

224
225
226
227
228
229
230
231
232
233
234
235
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
236
                    load_check_interval_secs: 5,
237
238
239
240
                },
            }
        };

241
242
243
244
        let mode = if self.enable_igw {
            RoutingMode::Regular {
                worker_urls: vec![],
            }
245
246
247
248
        } else if matches!(self.backend, BackendType::Openai) {
            RoutingMode::OpenAI {
                worker_urls: self.worker_urls.clone(),
            }
249
        } else if self.pd_disaggregation {
250
251
252
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
253
254
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
255
256
257
258
259
260
261
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

262
        let policy = convert_policy(&self.policy);
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286

        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

287
288
289
290
291
292
293
294
295
296
297
298
299
300
        let history_backend = match self.history_backend {
            HistoryBackendType::Memory => config::HistoryBackend::Memory,
            HistoryBackendType::None => config::HistoryBackend::None,
            HistoryBackendType::Oracle => config::HistoryBackend::Oracle,
        };

        let oracle = if matches!(self.history_backend, HistoryBackendType::Oracle) {
            self.oracle_config
                .as_ref()
                .map(|cfg| cfg.to_config_oracle())
        } else {
            None
        };

301
302
303
304
305
        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
306
            connection_mode: self.connection_mode.clone(),
307
308
309
310
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
311
312
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
313
314
315
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
316
            log_level: self.log_level.clone(),
317
            request_id_headers: self.request_id_headers.clone(),
318
            max_concurrent_requests: self.max_concurrent_requests,
319
320
321
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
322
            cors_allowed_origins: self.cors_allowed_origins.clone(),
323
324
325
326
327
328
329
330
331
332
333
334
335
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
336
337
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
338
339
340
341
342
343
344
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
345
            enable_igw: self.enable_igw,
346
347
            model_path: self.model_path.clone(),
            tokenizer_path: self.tokenizer_path.clone(),
348
            chat_template: self.chat_template.clone(),
349
350
            history_backend,
            oracle,
351
352
            reasoning_parser: self.reasoning_parser.clone(),
            tool_call_parser: self.tool_call_parser.clone(),
353
354
355
356
        })
    }
}

357
358
359
#[pymethods]
impl Router {
    #[new]
360
361
362
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
363
        host = String::from("0.0.0.0"),
364
        port = 3001,
365
366
367
368
369
370
371
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
372
        max_payload_size = 512 * 1024 * 1024,
373
374
        dp_aware = false,
        api_key = None,
375
        log_dir = None,
376
        log_level = None,
377
378
379
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
380
        service_discovery_namespace = None,
381
382
383
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
384
        prometheus_port = None,
385
        prometheus_host = None,
386
387
388
        request_timeout_secs = 1800,
        request_id_headers = None,
        pd_disaggregation = false,
389
        prefill_urls = None,
390
391
        decode_urls = None,
        prefill_policy = None,
392
        decode_policy = None,
393
        max_concurrent_requests = -1,
394
        cors_allowed_origins = vec![],
395
396
397
398
399
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
400
        disable_retries = false,
401
402
403
404
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
405
        disable_circuit_breaker = false,
406
407
408
409
410
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
411
        enable_igw = false,
412
413
414
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
415
416
        model_path = None,
        tokenizer_path = None,
417
        chat_template = None,
418
419
        reasoning_parser = None,
        tool_call_parser = None,
420
421
422
        backend = BackendType::Sglang,
        history_backend = HistoryBackendType::Memory,
        oracle_config = None,
423
    ))]
424
    #[allow(clippy::too_many_arguments)]
425
426
427
428
429
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
430
        worker_startup_timeout_secs: u64,
431
        worker_startup_check_interval: u64,
432
        cache_threshold: f32,
433
434
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
435
436
        eviction_interval_secs: u64,
        max_tree_size: usize,
437
        max_payload_size: usize,
438
439
        dp_aware: bool,
        api_key: Option<String>,
440
        log_dir: Option<String>,
441
        log_level: Option<String>,
442
443
444
445
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
446
447
448
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
449
450
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
451
        request_timeout_secs: u64,
452
        request_id_headers: Option<Vec<String>>,
453
        pd_disaggregation: bool,
454
455
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
456
457
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
458
        max_concurrent_requests: i32,
459
        cors_allowed_origins: Vec<String>,
460
461
462
463
464
465
466
467
468
469
470
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
471
472
473
474
475
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
476
        enable_igw: bool,
477
478
        queue_size: usize,
        queue_timeout_secs: u64,
479
        rate_limit_tokens_per_second: Option<i32>,
480
481
        model_path: Option<String>,
        tokenizer_path: Option<String>,
482
        chat_template: Option<String>,
483
484
        reasoning_parser: Option<String>,
        tool_call_parser: Option<String>,
485
486
487
        backend: BackendType,
        history_backend: HistoryBackendType,
        oracle_config: Option<PyOracleConfig>,
488
    ) -> PyResult<Self> {
489
490
491
492
493
494
495
496
497
498
499
500
501
502
        let mut all_urls = worker_urls.clone();

        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

503
        Ok(Router {
504
505
506
            host,
            port,
            worker_urls,
507
            policy,
508
            worker_startup_timeout_secs,
509
            worker_startup_check_interval,
510
            cache_threshold,
511
512
            balance_abs_threshold,
            balance_rel_threshold,
513
514
            eviction_interval_secs,
            max_tree_size,
515
            max_payload_size,
516
517
            dp_aware,
            api_key,
518
            log_dir,
519
            log_level,
520
521
522
523
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
524
525
526
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
527
528
            prometheus_port,
            prometheus_host,
529
            request_timeout_secs,
530
            request_id_headers,
531
            pd_disaggregation,
532
533
            prefill_urls,
            decode_urls,
534
535
            prefill_policy,
            decode_policy,
536
537
            max_concurrent_requests,
            cors_allowed_origins,
538
539
540
541
542
543
544
545
546
547
548
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
549
550
551
552
553
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
554
            enable_igw,
555
556
557
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
558
559
560
            connection_mode,
            model_path,
            tokenizer_path,
561
            chat_template,
562
563
            reasoning_parser,
            tool_call_parser,
564
565
566
            backend,
            history_backend,
            oracle_config,
567
        })
568
569
570
    }

    fn start(&self) -> PyResult<()> {
571
572
573
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
574

575
576
577
578
579
580
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
581

582
583
584
585
586
587
588
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
589
590
591
592
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
593
594
595
596
597
            })
        } else {
            None
        };

598
599
600
601
602
603
604
605
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

606
607
608
609
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        runtime.block_on(async move {
610
611
612
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
613
                router_config,
614
                max_payload_size: self.max_payload_size,
615
                log_dir: self.log_dir.clone(),
616
                log_level: self.log_level.clone(),
617
                service_discovery_config,
618
                prometheus_config,
619
                request_timeout_secs: self.request_timeout_secs,
620
                request_id_headers: self.request_id_headers.clone(),
621
622
            })
            .await
623
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
624
        })
625
626
627
628
    }
}

#[pymodule]
629
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
630
    m.add_class::<PolicyType>()?;
631
632
633
    m.add_class::<BackendType>()?;
    m.add_class::<HistoryBackendType>()?;
    m.add_class::<PyOracleConfig>()?;
634
635
    m.add_class::<Router>()?;
    Ok(())
636
}