lib.rs 23.2 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod app_context;
3
pub mod config;
4
pub mod logging;
5
use std::collections::HashMap;
6

7
pub mod core;
8
pub mod data_connector;
9
#[cfg(feature = "grpc-client")]
10
pub mod grpc_client;
11
pub mod mcp;
12
pub mod metrics;
13
pub mod middleware;
14
pub mod policies;
15
pub mod protocols;
16
pub mod reasoning_parser;
17
pub mod routers;
18
pub mod server;
19
pub mod service_discovery;
20
pub mod tokenizer;
21
pub mod tool_parser;
22
pub mod tree;
23
use crate::metrics::PrometheusConfig;
24

25
#[pyclass(eq)]
26
#[derive(Clone, PartialEq, Debug)]
27
28
29
pub enum PolicyType {
    Random,
    RoundRobin,
30
    CacheAware,
31
    PowerOfTwo,
32
33
}

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#[pyclass(eq)]
#[derive(Clone, PartialEq, Debug)]
pub enum BackendType {
    Sglang,
    Openai,
}

/// Storage backend for conversation history, selected from Python.
///
/// Mapped one-to-one onto `config::HistoryBackend`. Only when `Oracle` is
/// selected is the router's `oracle_config` forwarded into the router config.
#[pyclass(eq)]
#[derive(Debug, Clone, PartialEq)]
pub enum HistoryBackendType {
    Memory,
    None,
    Oracle,
}

/// Oracle history-backend connection settings, readable and writable from Python.
///
/// `Debug` is implemented by hand for this type so that the password and the
/// connect descriptor are never printed.
#[pyclass]
#[derive(Clone, PartialEq)]
pub struct PyOracleConfig {
    /// Optional wallet location; forwarded unchanged to `config::OracleConfig`.
    #[pyo3(get, set)]
    pub wallet_path: Option<String>,
    /// Connect descriptor; `None` becomes an empty string on conversion.
    #[pyo3(get, set)]
    pub connect_descriptor: Option<String>,
    /// Database user; `None` becomes an empty string on conversion.
    #[pyo3(get, set)]
    pub username: Option<String>,
    /// Database password; `None` becomes an empty string on conversion.
    #[pyo3(get, set)]
    pub password: Option<String>,
    /// Lower pool bound; the Python constructor rejects 0.
    #[pyo3(get, set)]
    pub pool_min: usize,
    /// Upper pool bound; the Python constructor requires `pool_max >= pool_min`.
    #[pyo3(get, set)]
    pub pool_max: usize,
    /// Pool timeout in seconds (exact semantics defined by `config::OracleConfig` — TODO confirm).
    #[pyo3(get, set)]
    pub pool_timeout_secs: u64,
}

/// Secret-safe `Debug`.
///
/// `password` and `connect_descriptor` are always rendered as `"<redacted>"`,
/// whether or not they are set. Every other field is shown as-is.
impl std::fmt::Debug for PyOracleConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";

        let mut view = f.debug_struct("PyOracleConfig");
        view.field("wallet_path", &self.wallet_path);
        view.field("connect_descriptor", &REDACTED);
        view.field("username", &self.username);
        view.field("password", &REDACTED);
        view.field("pool_min", &self.pool_min);
        view.field("pool_max", &self.pool_max);
        view.field("pool_timeout_secs", &self.pool_timeout_secs);
        view.finish()
    }
}

#[pymethods]
impl PyOracleConfig {
    // Python constructor. All connection fields are optional at this point.
    // Only the pool bounds are checked here: raises ValueError when
    // `pool_min` is 0 or when `pool_max` is smaller than `pool_min`.
    #[new]
    #[pyo3(signature = (
        password = None,
        username = None,
        connect_descriptor = None,
        wallet_path = None,
        pool_min = 1,
        pool_max = 16,
        pool_timeout_secs = 30,
    ))]
    fn new(
        password: Option<String>,
        username: Option<String>,
        connect_descriptor: Option<String>,
        wallet_path: Option<String>,
        pool_min: usize,
        pool_max: usize,
        pool_timeout_secs: u64,
    ) -> PyResult<Self> {
        // The first violated bound, if any, decides the error message.
        let violation = if pool_min == 0 {
            Some("pool_min must be at least 1")
        } else if pool_max < pool_min {
            Some("pool_max must be >= pool_min")
        } else {
            None
        };

        if let Some(message) = violation {
            return Err(pyo3::exceptions::PyValueError::new_err(message));
        }

        Ok(Self {
            password,
            username,
            connect_descriptor,
            wallet_path,
            pool_min,
            pool_max,
            pool_timeout_secs,
        })
    }
}

impl PyOracleConfig {
    /// Converts the Python-facing settings into `config::OracleConfig`.
    ///
    /// `wallet_path` stays optional. The other string fields fall back to an
    /// empty string when unset. No validation happens here; it is expected
    /// later, in `validate_oracle()`.
    fn to_config_oracle(&self) -> config::OracleConfig {
        let or_empty = |value: &Option<String>| -> String { value.clone().unwrap_or_default() };

        config::OracleConfig {
            wallet_path: self.wallet_path.clone(),
            connect_descriptor: or_empty(&self.connect_descriptor),
            username: or_empty(&self.username),
            password: or_empty(&self.password),
            pool_min: self.pool_min,
            pool_max: self.pool_max,
            pool_timeout_secs: self.pool_timeout_secs,
        }
    }
}

141
#[pyclass]
142
#[derive(Debug, Clone, PartialEq)]
143
144
145
146
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
147
    policy: PolicyType,
148
    worker_startup_timeout_secs: u64,
149
    worker_startup_check_interval: u64,
150
    cache_threshold: f32,
151
152
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
153
154
    eviction_interval_secs: u64,
    max_tree_size: usize,
155
    max_payload_size: usize,
156
157
    dp_aware: bool,
    api_key: Option<String>,
158
    log_dir: Option<String>,
159
    log_level: Option<String>,
160
161
162
163
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
164
165
166
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
167
168
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
169
    request_timeout_secs: u64,
170
    request_id_headers: Option<Vec<String>>,
171
    pd_disaggregation: bool,
172
173
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
174
175
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
176
    max_concurrent_requests: i32,
177
    cors_allowed_origins: Vec<String>,
178
179
180
181
182
183
184
185
186
187
188
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
189
190
191
192
193
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
194
    enable_igw: bool,
195
196
    queue_size: usize,
    queue_timeout_secs: u64,
197
    rate_limit_tokens_per_second: Option<i32>,
198
    connection_mode: core::ConnectionMode,
199
200
    model_path: Option<String>,
    tokenizer_path: Option<String>,
201
    chat_template: Option<String>,
202
203
204
205
    tokenizer_cache_enable_l0: bool,
    tokenizer_cache_l0_max_entries: usize,
    tokenizer_cache_enable_l1: bool,
    tokenizer_cache_l1_max_memory: usize,
206
207
    reasoning_parser: Option<String>,
    tool_call_parser: Option<String>,
208
    mcp_config_path: Option<String>,
209
210
211
    backend: BackendType,
    history_backend: HistoryBackendType,
    oracle_config: Option<PyOracleConfig>,
212
213
214
    client_cert_path: Option<String>,
    client_key_path: Option<String>,
    ca_cert_paths: Vec<String>,
215
216
}

217
impl Router {
218
    /// Determine connection mode from worker URLs
219
    fn determine_connection_mode(worker_urls: &[String]) -> core::ConnectionMode {
220
221
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
222
                return core::ConnectionMode::Grpc { port: None };
223
224
            }
        }
225
        core::ConnectionMode::Http
226
227
    }

228
229
230
231
232
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

233
234
235
236
237
238
239
240
241
242
243
244
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
245
                    load_check_interval_secs: 5,
246
247
248
249
                },
            }
        };

250
251
252
253
        let mode = if self.enable_igw {
            RoutingMode::Regular {
                worker_urls: vec![],
            }
254
255
256
257
        } else if matches!(self.backend, BackendType::Openai) {
            RoutingMode::OpenAI {
                worker_urls: self.worker_urls.clone(),
            }
258
        } else if self.pd_disaggregation {
259
260
261
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
262
263
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
264
265
266
267
268
269
270
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

271
        let policy = convert_policy(&self.policy);
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295

        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

296
297
298
299
300
301
302
303
304
305
306
307
308
309
        let history_backend = match self.history_backend {
            HistoryBackendType::Memory => config::HistoryBackend::Memory,
            HistoryBackendType::None => config::HistoryBackend::None,
            HistoryBackendType::Oracle => config::HistoryBackend::Oracle,
        };

        let oracle = if matches!(self.history_backend, HistoryBackendType::Oracle) {
            self.oracle_config
                .as_ref()
                .map(|cfg| cfg.to_config_oracle())
        } else {
            None
        };

310
        config::RouterConfig::builder()
311
312
313
314
315
316
317
318
319
320
321
322
323
324
            .mode(mode)
            .policy(policy)
            .host(&self.host)
            .port(self.port)
            .connection_mode(self.connection_mode.clone())
            .max_payload_size(self.max_payload_size)
            .request_timeout_secs(self.request_timeout_secs)
            .worker_startup_timeout_secs(self.worker_startup_timeout_secs)
            .worker_startup_check_interval_secs(self.worker_startup_check_interval)
            .max_concurrent_requests(self.max_concurrent_requests)
            .queue_size(self.queue_size)
            .queue_timeout_secs(self.queue_timeout_secs)
            .cors_allowed_origins(self.cors_allowed_origins.clone())
            .retry_config(config::RetryConfig {
325
326
327
328
329
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
330
331
            })
            .circuit_breaker_config(config::CircuitBreakerConfig {
332
333
334
335
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
336
337
            })
            .health_check_config(config::HealthCheckConfig {
338
339
340
341
342
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
343
344
            })
            .tokenizer_cache(config::TokenizerCacheConfig {
345
346
347
348
                enable_l0: self.tokenizer_cache_enable_l0,
                l0_max_entries: self.tokenizer_cache_l0_max_entries,
                enable_l1: self.tokenizer_cache_enable_l1,
                l1_max_memory: self.tokenizer_cache_l1_max_memory,
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
            })
            .history_backend(history_backend)
            .maybe_api_key(self.api_key.as_ref())
            .maybe_discovery(discovery)
            .maybe_metrics(metrics)
            .maybe_log_dir(self.log_dir.as_ref())
            .maybe_log_level(self.log_level.as_ref())
            .maybe_request_id_headers(self.request_id_headers.clone())
            .maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
            .maybe_model_path(self.model_path.as_ref())
            .maybe_tokenizer_path(self.tokenizer_path.as_ref())
            .maybe_chat_template(self.chat_template.as_ref())
            .maybe_oracle(oracle)
            .maybe_reasoning_parser(self.reasoning_parser.as_ref())
            .maybe_tool_call_parser(self.tool_call_parser.as_ref())
364
            .maybe_mcp_config_path(self.mcp_config_path.as_ref())
365
366
367
            .dp_aware(self.dp_aware)
            .retries(!self.disable_retries)
            .circuit_breaker(!self.disable_circuit_breaker)
368
369
370
371
372
373
374
            .igw(self.enable_igw)
            .maybe_client_cert_and_key(
                self.client_cert_path.as_ref(),
                self.client_key_path.as_ref(),
            )
            .add_ca_certificates(self.ca_cert_paths.clone())
            .build()
375
376
377
    }
}

378
379
380
#[pymethods]
impl Router {
    #[new]
381
382
383
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
384
        host = String::from("0.0.0.0"),
385
        port = 3001,
386
387
388
389
390
391
392
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
393
        max_payload_size = 512 * 1024 * 1024,
394
395
        dp_aware = false,
        api_key = None,
396
        log_dir = None,
397
        log_level = None,
398
399
400
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
401
        service_discovery_namespace = None,
402
403
404
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
405
        prometheus_port = None,
406
        prometheus_host = None,
407
408
409
        request_timeout_secs = 1800,
        request_id_headers = None,
        pd_disaggregation = false,
410
        prefill_urls = None,
411
412
        decode_urls = None,
        prefill_policy = None,
413
        decode_policy = None,
414
        max_concurrent_requests = -1,
415
        cors_allowed_origins = vec![],
416
417
418
419
420
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
421
        disable_retries = false,
422
423
424
425
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
426
        disable_circuit_breaker = false,
427
428
429
430
431
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
432
        enable_igw = false,
433
434
435
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
436
437
        model_path = None,
        tokenizer_path = None,
438
        chat_template = None,
439
440
441
442
        tokenizer_cache_enable_l0 = false,
        tokenizer_cache_l0_max_entries = 10000,
        tokenizer_cache_enable_l1 = false,
        tokenizer_cache_l1_max_memory = 52428800,
443
444
        reasoning_parser = None,
        tool_call_parser = None,
445
        mcp_config_path = None,
446
447
448
        backend = BackendType::Sglang,
        history_backend = HistoryBackendType::Memory,
        oracle_config = None,
449
450
451
        client_cert_path = None,
        client_key_path = None,
        ca_cert_paths = vec![],
452
    ))]
453
    #[allow(clippy::too_many_arguments)]
454
455
456
457
458
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
459
        worker_startup_timeout_secs: u64,
460
        worker_startup_check_interval: u64,
461
        cache_threshold: f32,
462
463
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
464
465
        eviction_interval_secs: u64,
        max_tree_size: usize,
466
        max_payload_size: usize,
467
468
        dp_aware: bool,
        api_key: Option<String>,
469
        log_dir: Option<String>,
470
        log_level: Option<String>,
471
472
473
474
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
475
476
477
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
478
479
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
480
        request_timeout_secs: u64,
481
        request_id_headers: Option<Vec<String>>,
482
        pd_disaggregation: bool,
483
484
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
485
486
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
487
        max_concurrent_requests: i32,
488
        cors_allowed_origins: Vec<String>,
489
490
491
492
493
494
495
496
497
498
499
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
500
501
502
503
504
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
505
        enable_igw: bool,
506
507
        queue_size: usize,
        queue_timeout_secs: u64,
508
        rate_limit_tokens_per_second: Option<i32>,
509
510
        model_path: Option<String>,
        tokenizer_path: Option<String>,
511
        chat_template: Option<String>,
512
513
514
515
        tokenizer_cache_enable_l0: bool,
        tokenizer_cache_l0_max_entries: usize,
        tokenizer_cache_enable_l1: bool,
        tokenizer_cache_l1_max_memory: usize,
516
517
        reasoning_parser: Option<String>,
        tool_call_parser: Option<String>,
518
        mcp_config_path: Option<String>,
519
520
521
        backend: BackendType,
        history_backend: HistoryBackendType,
        oracle_config: Option<PyOracleConfig>,
522
523
524
        client_cert_path: Option<String>,
        client_key_path: Option<String>,
        ca_cert_paths: Vec<String>,
525
    ) -> PyResult<Self> {
526
527
528
529
530
531
532
533
534
535
536
537
538
539
        let mut all_urls = worker_urls.clone();

        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

540
        Ok(Router {
541
542
543
            host,
            port,
            worker_urls,
544
            policy,
545
            worker_startup_timeout_secs,
546
            worker_startup_check_interval,
547
            cache_threshold,
548
549
            balance_abs_threshold,
            balance_rel_threshold,
550
551
            eviction_interval_secs,
            max_tree_size,
552
            max_payload_size,
553
554
            dp_aware,
            api_key,
555
            log_dir,
556
            log_level,
557
558
559
560
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
561
562
563
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
564
565
            prometheus_port,
            prometheus_host,
566
            request_timeout_secs,
567
            request_id_headers,
568
            pd_disaggregation,
569
570
            prefill_urls,
            decode_urls,
571
572
            prefill_policy,
            decode_policy,
573
574
            max_concurrent_requests,
            cors_allowed_origins,
575
576
577
578
579
580
581
582
583
584
585
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
586
587
588
589
590
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
591
            enable_igw,
592
593
594
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
595
596
597
            connection_mode,
            model_path,
            tokenizer_path,
598
            chat_template,
599
600
601
602
            tokenizer_cache_enable_l0,
            tokenizer_cache_l0_max_entries,
            tokenizer_cache_enable_l1,
            tokenizer_cache_l1_max_memory,
603
604
            reasoning_parser,
            tool_call_parser,
605
            mcp_config_path,
606
607
608
            backend,
            history_backend,
            oracle_config,
609
610
611
            client_cert_path,
            client_key_path,
            ca_cert_paths,
612
        })
613
614
615
    }

    fn start(&self) -> PyResult<()> {
616
617
618
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
619

620
621
622
623
624
625
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
626

627
628
629
630
631
632
633
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
634
635
636
637
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
638
639
640
641
642
            })
        } else {
            None
        };

643
644
645
646
647
648
649
650
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

651
652
653
654
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        runtime.block_on(async move {
655
656
657
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
658
                router_config,
659
                max_payload_size: self.max_payload_size,
660
                log_dir: self.log_dir.clone(),
661
                log_level: self.log_level.clone(),
662
                service_discovery_config,
663
                prometheus_config,
664
                request_timeout_secs: self.request_timeout_secs,
665
                request_id_headers: self.request_id_headers.clone(),
666
667
            })
            .await
668
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
669
        })
670
671
672
673
    }
}

#[pymodule]
674
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
675
    m.add_class::<PolicyType>()?;
676
677
678
    m.add_class::<BackendType>()?;
    m.add_class::<HistoryBackendType>()?;
    m.add_class::<PyOracleConfig>()?;
679
680
    m.add_class::<Router>()?;
    Ok(())
681
}