lib.rs 22.2 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
pub mod data_connector;
8
#[cfg(feature = "grpc-client")]
9
pub mod grpc_client;
10
pub mod mcp;
11
pub mod metrics;
12
pub mod middleware;
13
pub mod policies;
14
pub mod protocols;
15
pub mod reasoning_parser;
16
pub mod routers;
17
pub mod server;
18
pub mod service_discovery;
19
pub mod tokenizer;
20
pub mod tool_parser;
21
pub mod tree;
22
use crate::metrics::PrometheusConfig;
23

24
#[pyclass(eq)]
25
#[derive(Clone, PartialEq, Debug)]
26
27
28
pub enum PolicyType {
    Random,
    RoundRobin,
29
    CacheAware,
30
    PowerOfTwo,
31
32
}

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
// Kind of backend the router fronts. `Openai` makes `Router::to_router_config`
// choose `RoutingMode::OpenAI` (unless IGW mode is enabled).
// Do not reorder the variants: implicit discriminants follow declaration order.
#[pyclass(eq)]
#[derive(Debug, Clone, PartialEq)]
pub enum BackendType {
    Sglang,
    Openai,
}

// History storage backend selectable from Python; each variant maps 1:1 onto
// `config::HistoryBackend`. `Oracle` is the only variant for which `Router`
// forwards its `oracle_config`.
// Do not reorder the variants: implicit discriminants follow declaration order.
#[pyclass(eq)]
#[derive(Debug, Clone, PartialEq)]
pub enum HistoryBackendType {
    Memory,
    None,
    Oracle,
}

// Oracle connection settings as supplied from Python. Every field is readable
// and writable from Python. `connect_descriptor` and `password` are kept in
// clear here, so `Debug` for this type is implemented by hand to mask them.
#[pyclass]
#[derive(PartialEq, Clone)]
pub struct PyOracleConfig {
    // Wallet location, if any.
    #[pyo3(get, set)]
    pub wallet_path: Option<String>,
    // Connection descriptor (treated as sensitive by the Debug impl).
    #[pyo3(get, set)]
    pub connect_descriptor: Option<String>,
    #[pyo3(get, set)]
    pub username: Option<String>,
    // Treated as sensitive by the Debug impl.
    #[pyo3(get, set)]
    pub password: Option<String>,
    // Connection pool sizing; the constructor enforces 1 <= pool_min <= pool_max.
    #[pyo3(get, set)]
    pub pool_min: usize,
    #[pyo3(get, set)]
    pub pool_max: usize,
    #[pyo3(get, set)]
    pub pool_timeout_secs: u64,
}

impl std::fmt::Debug for PyOracleConfig {
    /// Debug-formats the config with `connect_descriptor` and `password` masked
    /// as `"<redacted>"`; every other field is printed as-is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";

        let mut out = f.debug_struct("PyOracleConfig");
        out.field("wallet_path", &self.wallet_path);
        out.field("connect_descriptor", &REDACTED);
        out.field("username", &self.username);
        out.field("password", &REDACTED);
        out.field("pool_min", &self.pool_min);
        out.field("pool_max", &self.pool_max);
        out.field("pool_timeout_secs", &self.pool_timeout_secs);
        out.finish()
    }
}

#[pymethods]
impl PyOracleConfig {
    // Python constructor. Only the pool bounds are checked here: at least one
    // connection, and a maximum no smaller than the minimum. Missing credentials
    // are accepted; per `to_config_oracle`, validation happens later in
    // `validate_oracle()`.
    //
    // Raises `ValueError` when the pool bounds are inconsistent.
    #[new]
    #[pyo3(signature = (
        password = None,
        username = None,
        connect_descriptor = None,
        wallet_path = None,
        pool_min = 1,
        pool_max = 16,
        pool_timeout_secs = 30,
    ))]
    fn new(
        password: Option<String>,
        username: Option<String>,
        connect_descriptor: Option<String>,
        wallet_path: Option<String>,
        pool_min: usize,
        pool_max: usize,
        pool_timeout_secs: u64,
    ) -> PyResult<Self> {
        // The first failing rule wins, mirroring the check order callers see.
        let pool_error = if pool_min == 0 {
            Some("pool_min must be at least 1")
        } else if pool_max < pool_min {
            Some("pool_max must be >= pool_min")
        } else {
            None
        };

        if let Some(message) = pool_error {
            return Err(pyo3::exceptions::PyValueError::new_err(message));
        }

        Ok(Self {
            wallet_path,
            connect_descriptor,
            username,
            password,
            pool_min,
            pool_max,
            pool_timeout_secs,
        })
    }
}

impl PyOracleConfig {
    /// Converts into the config-layer `OracleConfig`.
    ///
    /// Absent descriptor/username/password become empty strings. No checks are
    /// done here; validation is deferred to `validate_oracle()` in the config
    /// module (presumably where empty values are rejected — confirm there).
    fn to_config_oracle(&self) -> config::OracleConfig {
        let Self {
            wallet_path,
            connect_descriptor,
            username,
            password,
            pool_min,
            pool_max,
            pool_timeout_secs,
        } = self;

        config::OracleConfig {
            wallet_path: wallet_path.clone(),
            connect_descriptor: connect_descriptor.clone().unwrap_or_default(),
            username: username.clone().unwrap_or_default(),
            password: password.clone().unwrap_or_default(),
            pool_min: *pool_min,
            pool_max: *pool_max,
            pool_timeout_secs: *pool_timeout_secs,
        }
    }
}

140
#[pyclass]
141
#[derive(Debug, Clone, PartialEq)]
142
143
144
145
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
146
    policy: PolicyType,
147
    worker_startup_timeout_secs: u64,
148
    worker_startup_check_interval: u64,
149
    cache_threshold: f32,
150
151
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
152
153
    eviction_interval_secs: u64,
    max_tree_size: usize,
154
    max_payload_size: usize,
155
156
    dp_aware: bool,
    api_key: Option<String>,
157
    log_dir: Option<String>,
158
    log_level: Option<String>,
159
160
161
162
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
163
164
165
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
166
167
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
168
    request_timeout_secs: u64,
169
    request_id_headers: Option<Vec<String>>,
170
    pd_disaggregation: bool,
171
172
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
173
174
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
175
    max_concurrent_requests: i32,
176
    cors_allowed_origins: Vec<String>,
177
178
179
180
181
182
183
184
185
186
187
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
188
189
190
191
192
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
193
    enable_igw: bool,
194
195
    queue_size: usize,
    queue_timeout_secs: u64,
196
    rate_limit_tokens_per_second: Option<i32>,
197
198
199
    connection_mode: config::ConnectionMode,
    model_path: Option<String>,
    tokenizer_path: Option<String>,
200
    chat_template: Option<String>,
201
202
203
204
    tokenizer_cache_enable_l0: bool,
    tokenizer_cache_l0_max_entries: usize,
    tokenizer_cache_enable_l1: bool,
    tokenizer_cache_l1_max_memory: usize,
205
206
    reasoning_parser: Option<String>,
    tool_call_parser: Option<String>,
207
208
209
    backend: BackendType,
    history_backend: HistoryBackendType,
    oracle_config: Option<PyOracleConfig>,
210
211
}

212
impl Router {
213
214
215
216
217
218
219
220
221
222
    /// Determine connection mode from worker URLs
    fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                return config::ConnectionMode::Grpc;
            }
        }
        config::ConnectionMode::Http
    }

223
224
225
226
227
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

228
229
230
231
232
233
234
235
236
237
238
239
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
240
                    load_check_interval_secs: 5,
241
242
243
244
                },
            }
        };

245
246
247
248
        let mode = if self.enable_igw {
            RoutingMode::Regular {
                worker_urls: vec![],
            }
249
250
251
252
        } else if matches!(self.backend, BackendType::Openai) {
            RoutingMode::OpenAI {
                worker_urls: self.worker_urls.clone(),
            }
253
        } else if self.pd_disaggregation {
254
255
256
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
257
258
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
259
260
261
262
263
264
265
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

266
        let policy = convert_policy(&self.policy);
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290

        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

291
292
293
294
295
296
297
298
299
300
301
302
303
304
        let history_backend = match self.history_backend {
            HistoryBackendType::Memory => config::HistoryBackend::Memory,
            HistoryBackendType::None => config::HistoryBackend::None,
            HistoryBackendType::Oracle => config::HistoryBackend::Oracle,
        };

        let oracle = if matches!(self.history_backend, HistoryBackendType::Oracle) {
            self.oracle_config
                .as_ref()
                .map(|cfg| cfg.to_config_oracle())
        } else {
            None
        };

305
306
307
308
309
        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
310
            connection_mode: self.connection_mode.clone(),
311
312
313
314
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
315
316
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
317
318
319
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
320
            log_level: self.log_level.clone(),
321
            request_id_headers: self.request_id_headers.clone(),
322
            max_concurrent_requests: self.max_concurrent_requests,
323
324
325
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
326
            cors_allowed_origins: self.cors_allowed_origins.clone(),
327
328
329
330
331
332
333
334
335
336
337
338
339
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
340
341
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
342
343
344
345
346
347
348
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
349
            enable_igw: self.enable_igw,
350
351
            model_path: self.model_path.clone(),
            tokenizer_path: self.tokenizer_path.clone(),
352
            chat_template: self.chat_template.clone(),
353
354
            history_backend,
            oracle,
355
356
            reasoning_parser: self.reasoning_parser.clone(),
            tool_call_parser: self.tool_call_parser.clone(),
357
358
359
360
361
362
            tokenizer_cache: config::TokenizerCacheConfig {
                enable_l0: self.tokenizer_cache_enable_l0,
                l0_max_entries: self.tokenizer_cache_l0_max_entries,
                enable_l1: self.tokenizer_cache_enable_l1,
                l1_max_memory: self.tokenizer_cache_l1_max_memory,
            },
363
364
365
366
        })
    }
}

367
368
369
#[pymethods]
impl Router {
    #[new]
370
371
372
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
373
        host = String::from("0.0.0.0"),
374
        port = 3001,
375
376
377
378
379
380
381
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
382
        max_payload_size = 512 * 1024 * 1024,
383
384
        dp_aware = false,
        api_key = None,
385
        log_dir = None,
386
        log_level = None,
387
388
389
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
390
        service_discovery_namespace = None,
391
392
393
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
394
        prometheus_port = None,
395
        prometheus_host = None,
396
397
398
        request_timeout_secs = 1800,
        request_id_headers = None,
        pd_disaggregation = false,
399
        prefill_urls = None,
400
401
        decode_urls = None,
        prefill_policy = None,
402
        decode_policy = None,
403
        max_concurrent_requests = -1,
404
        cors_allowed_origins = vec![],
405
406
407
408
409
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
410
        disable_retries = false,
411
412
413
414
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
415
        disable_circuit_breaker = false,
416
417
418
419
420
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
421
        enable_igw = false,
422
423
424
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
425
426
        model_path = None,
        tokenizer_path = None,
427
        chat_template = None,
428
429
430
431
        tokenizer_cache_enable_l0 = false,
        tokenizer_cache_l0_max_entries = 10000,
        tokenizer_cache_enable_l1 = false,
        tokenizer_cache_l1_max_memory = 52428800,
432
433
        reasoning_parser = None,
        tool_call_parser = None,
434
435
436
        backend = BackendType::Sglang,
        history_backend = HistoryBackendType::Memory,
        oracle_config = None,
437
    ))]
438
    #[allow(clippy::too_many_arguments)]
439
440
441
442
443
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
444
        worker_startup_timeout_secs: u64,
445
        worker_startup_check_interval: u64,
446
        cache_threshold: f32,
447
448
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
449
450
        eviction_interval_secs: u64,
        max_tree_size: usize,
451
        max_payload_size: usize,
452
453
        dp_aware: bool,
        api_key: Option<String>,
454
        log_dir: Option<String>,
455
        log_level: Option<String>,
456
457
458
459
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
460
461
462
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
463
464
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
465
        request_timeout_secs: u64,
466
        request_id_headers: Option<Vec<String>>,
467
        pd_disaggregation: bool,
468
469
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
470
471
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
472
        max_concurrent_requests: i32,
473
        cors_allowed_origins: Vec<String>,
474
475
476
477
478
479
480
481
482
483
484
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
485
486
487
488
489
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
490
        enable_igw: bool,
491
492
        queue_size: usize,
        queue_timeout_secs: u64,
493
        rate_limit_tokens_per_second: Option<i32>,
494
495
        model_path: Option<String>,
        tokenizer_path: Option<String>,
496
        chat_template: Option<String>,
497
498
499
500
        tokenizer_cache_enable_l0: bool,
        tokenizer_cache_l0_max_entries: usize,
        tokenizer_cache_enable_l1: bool,
        tokenizer_cache_l1_max_memory: usize,
501
502
        reasoning_parser: Option<String>,
        tool_call_parser: Option<String>,
503
504
505
        backend: BackendType,
        history_backend: HistoryBackendType,
        oracle_config: Option<PyOracleConfig>,
506
    ) -> PyResult<Self> {
507
508
509
510
511
512
513
514
515
516
517
518
519
520
        let mut all_urls = worker_urls.clone();

        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

521
        Ok(Router {
522
523
524
            host,
            port,
            worker_urls,
525
            policy,
526
            worker_startup_timeout_secs,
527
            worker_startup_check_interval,
528
            cache_threshold,
529
530
            balance_abs_threshold,
            balance_rel_threshold,
531
532
            eviction_interval_secs,
            max_tree_size,
533
            max_payload_size,
534
535
            dp_aware,
            api_key,
536
            log_dir,
537
            log_level,
538
539
540
541
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
542
543
544
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
545
546
            prometheus_port,
            prometheus_host,
547
            request_timeout_secs,
548
            request_id_headers,
549
            pd_disaggregation,
550
551
            prefill_urls,
            decode_urls,
552
553
            prefill_policy,
            decode_policy,
554
555
            max_concurrent_requests,
            cors_allowed_origins,
556
557
558
559
560
561
562
563
564
565
566
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
567
568
569
570
571
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
572
            enable_igw,
573
574
575
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
576
577
578
            connection_mode,
            model_path,
            tokenizer_path,
579
            chat_template,
580
581
582
583
            tokenizer_cache_enable_l0,
            tokenizer_cache_l0_max_entries,
            tokenizer_cache_enable_l1,
            tokenizer_cache_l1_max_memory,
584
585
            reasoning_parser,
            tool_call_parser,
586
587
588
            backend,
            history_backend,
            oracle_config,
589
        })
590
591
592
    }

    fn start(&self) -> PyResult<()> {
593
594
595
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
596

597
598
599
600
601
602
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
603

604
605
606
607
608
609
610
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
611
612
613
614
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
615
616
617
618
619
            })
        } else {
            None
        };

620
621
622
623
624
625
626
627
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

628
629
630
631
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        runtime.block_on(async move {
632
633
634
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
635
                router_config,
636
                max_payload_size: self.max_payload_size,
637
                log_dir: self.log_dir.clone(),
638
                log_level: self.log_level.clone(),
639
                service_discovery_config,
640
                prometheus_config,
641
                request_timeout_secs: self.request_timeout_secs,
642
                request_id_headers: self.request_id_headers.clone(),
643
644
            })
            .await
645
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
646
        })
647
648
649
650
    }
}

#[pymodule]
651
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
652
    m.add_class::<PolicyType>()?;
653
654
655
    m.add_class::<BackendType>()?;
    m.add_class::<HistoryBackendType>()?;
    m.add_class::<PyOracleConfig>()?;
656
657
    m.add_class::<Router>()?;
    Ok(())
658
}