use pyo3::prelude::*;
pub mod app_context;
pub mod config;
pub mod logging;
use std::collections::HashMap;

pub mod core;
pub mod data_connector;
#[cfg(feature = "grpc-client")]
pub mod grpc_client;
pub mod mcp;
pub mod metrics;
pub mod middleware;
pub mod policies;
pub mod protocols;
pub mod reasoning_parser;
pub mod routers;
pub mod server;
pub mod service_discovery;
pub mod tokenizer;
pub mod tool_parser;
pub mod tree;
use crate::metrics::PrometheusConfig;

#[pyclass(eq)]
26
#[derive(Clone, PartialEq, Debug)]
27
28
29
pub enum PolicyType {
    Random,
    RoundRobin,
30
    CacheAware,
31
    PowerOfTwo,
32
33
}

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/// Backend flavour the router talks to.
///
/// `Router::to_router_config` selects `RoutingMode::OpenAI` when this is
/// `Openai` (and IGW mode is off); `Sglang` falls through to the other routing modes.
#[pyclass(eq)]
#[derive(Clone, PartialEq, Debug)]
pub enum BackendType {
    /// Default backend used by `Router::new`.
    Sglang,
    /// Routes through `RoutingMode::OpenAI` using the configured worker URLs.
    Openai,
}

/// History storage backend selectable from Python.
///
/// Each variant is mapped one-to-one onto `config::HistoryBackend` in
/// `Router::to_router_config`.
#[pyclass(eq)]
#[derive(Clone, PartialEq, Debug)]
pub enum HistoryBackendType {
    /// Mapped to `config::HistoryBackend::Memory`; the default in `Router::new`.
    Memory,
    /// Mapped to `config::HistoryBackend::None`.
    None,
    /// Mapped to `config::HistoryBackend::Oracle`; only with this variant is the
    /// router's `oracle_config` forwarded to the router configuration.
    Oracle,
}

/// Python-facing Oracle connection settings.
///
/// Every field is readable and writable from Python. The value is converted to
/// `config::OracleConfig` by `to_config_oracle`. `Debug` is implemented by hand
/// so that `connect_descriptor` and `password` are never printed.
#[pyclass]
#[derive(Clone, PartialEq)]
pub struct PyOracleConfig {
    /// Optional wallet path, passed through unchanged.
    #[pyo3(get, set)]
    pub wallet_path: Option<String>,
    /// Connect descriptor; converted to an empty string when unset.
    #[pyo3(get, set)]
    pub connect_descriptor: Option<String>,
    /// User name; converted to an empty string when unset.
    #[pyo3(get, set)]
    pub username: Option<String>,
    /// Password; converted to an empty string when unset.
    #[pyo3(get, set)]
    pub password: Option<String>,
    /// Pool minimum; the constructor rejects 0.
    // NOTE(review): the generated setter is not validated, so this invariant
    // only holds at construction time.
    #[pyo3(get, set)]
    pub pool_min: usize,
    /// Pool maximum; the constructor rejects values below `pool_min`.
    #[pyo3(get, set)]
    pub pool_max: usize,
    /// Pool timeout, in seconds (per the field name).
    #[pyo3(get, set)]
    pub pool_timeout_secs: u64,
}

/// Hand-written `Debug` that hides the sensitive connection fields.
impl std::fmt::Debug for PyOracleConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed instead of the connect descriptor and the password.
        const REDACTED: &str = "<redacted>";

        let mut out = f.debug_struct("PyOracleConfig");
        out.field("wallet_path", &self.wallet_path);
        out.field("connect_descriptor", &REDACTED);
        out.field("username", &self.username);
        out.field("password", &REDACTED);
        out.field("pool_min", &self.pool_min);
        out.field("pool_max", &self.pool_max);
        out.field("pool_timeout_secs", &self.pool_timeout_secs);
        out.finish()
    }
}

#[pymethods]
impl PyOracleConfig {
    /// Creates an Oracle configuration from Python keyword arguments.
    ///
    /// All connection strings default to `None`; the pool defaults are
    /// `pool_min = 1`, `pool_max = 16`, `pool_timeout_secs = 30`.
    ///
    /// # Errors
    /// Raises `ValueError` when `pool_min` is 0 or when `pool_max < pool_min`.
    #[new]
    #[pyo3(signature = (
        password = None,
        username = None,
        connect_descriptor = None,
        wallet_path = None,
        pool_min = 1,
        pool_max = 16,
        pool_timeout_secs = 30,
    ))]
    fn new(
        password: Option<String>,
        username: Option<String>,
        connect_descriptor: Option<String>,
        wallet_path: Option<String>,
        pool_min: usize,
        pool_max: usize,
        pool_timeout_secs: u64,
    ) -> PyResult<Self> {
        // Checks run in order; the first violated bound decides the message.
        let violation = if pool_min == 0 {
            Some("pool_min must be at least 1")
        } else if pool_max < pool_min {
            Some("pool_max must be >= pool_min")
        } else {
            None
        };

        match violation {
            Some(message) => Err(pyo3::exceptions::PyValueError::new_err(message)),
            None => Ok(PyOracleConfig {
                wallet_path,
                connect_descriptor,
                username,
                password,
                pool_min,
                pool_max,
                pool_timeout_secs,
            }),
        }
    }
}

impl PyOracleConfig {
    /// Builds the crate-level `config::OracleConfig` from this Python-facing value.
    ///
    /// An unset connect descriptor, user name or password becomes an empty
    /// string; the remaining fields are copied as-is. Nothing is validated here.
    // NOTE(review): the original comment says validation happens later in
    // `validate_oracle()`, which is not defined in this file — confirm.
    fn to_config_oracle(&self) -> config::OracleConfig {
        let Self {
            wallet_path,
            connect_descriptor,
            username,
            password,
            pool_min,
            pool_max,
            pool_timeout_secs,
        } = self.clone();

        config::OracleConfig {
            wallet_path,
            connect_descriptor: connect_descriptor.unwrap_or_default(),
            username: username.unwrap_or_default(),
            password: password.unwrap_or_default(),
            pool_min,
            pool_max,
            pool_timeout_secs,
        }
    }
}

141
#[pyclass]
142
#[derive(Debug, Clone, PartialEq)]
143
144
145
146
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
147
    policy: PolicyType,
148
    worker_startup_timeout_secs: u64,
149
    worker_startup_check_interval: u64,
150
    cache_threshold: f32,
151
152
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
153
154
    eviction_interval_secs: u64,
    max_tree_size: usize,
155
    max_payload_size: usize,
156
157
    dp_aware: bool,
    api_key: Option<String>,
158
    log_dir: Option<String>,
159
    log_level: Option<String>,
160
161
162
163
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
164
165
166
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
167
168
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
169
    request_timeout_secs: u64,
170
    request_id_headers: Option<Vec<String>>,
171
    pd_disaggregation: bool,
172
173
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
174
175
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
176
    max_concurrent_requests: i32,
177
    cors_allowed_origins: Vec<String>,
178
179
180
181
182
183
184
185
186
187
188
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
189
190
191
192
193
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
194
    enable_igw: bool,
195
196
    queue_size: usize,
    queue_timeout_secs: u64,
197
    rate_limit_tokens_per_second: Option<i32>,
198
    connection_mode: core::ConnectionMode,
199
200
    model_path: Option<String>,
    tokenizer_path: Option<String>,
201
    chat_template: Option<String>,
202
203
204
205
    tokenizer_cache_enable_l0: bool,
    tokenizer_cache_l0_max_entries: usize,
    tokenizer_cache_enable_l1: bool,
    tokenizer_cache_l1_max_memory: usize,
206
207
    reasoning_parser: Option<String>,
    tool_call_parser: Option<String>,
208
209
210
    backend: BackendType,
    history_backend: HistoryBackendType,
    oracle_config: Option<PyOracleConfig>,
211
212
213
    client_cert_path: Option<String>,
    client_key_path: Option<String>,
    ca_cert_paths: Vec<String>,
214
215
}

216
impl Router {
217
    /// Determine connection mode from worker URLs
218
    fn determine_connection_mode(worker_urls: &[String]) -> core::ConnectionMode {
219
220
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
221
                return core::ConnectionMode::Grpc { port: None };
222
223
            }
        }
224
        core::ConnectionMode::Http
225
226
    }

227
228
229
230
231
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

232
233
234
235
236
237
238
239
240
241
242
243
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
244
                    load_check_interval_secs: 5,
245
246
247
248
                },
            }
        };

249
250
251
252
        let mode = if self.enable_igw {
            RoutingMode::Regular {
                worker_urls: vec![],
            }
253
254
255
256
        } else if matches!(self.backend, BackendType::Openai) {
            RoutingMode::OpenAI {
                worker_urls: self.worker_urls.clone(),
            }
257
        } else if self.pd_disaggregation {
258
259
260
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
261
262
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
263
264
265
266
267
268
269
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

270
        let policy = convert_policy(&self.policy);
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294

        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

295
296
297
298
299
300
301
302
303
304
305
306
307
308
        let history_backend = match self.history_backend {
            HistoryBackendType::Memory => config::HistoryBackend::Memory,
            HistoryBackendType::None => config::HistoryBackend::None,
            HistoryBackendType::Oracle => config::HistoryBackend::Oracle,
        };

        let oracle = if matches!(self.history_backend, HistoryBackendType::Oracle) {
            self.oracle_config
                .as_ref()
                .map(|cfg| cfg.to_config_oracle())
        } else {
            None
        };

309
        config::RouterConfig::builder()
310
311
312
313
314
315
316
317
318
319
320
321
322
323
            .mode(mode)
            .policy(policy)
            .host(&self.host)
            .port(self.port)
            .connection_mode(self.connection_mode.clone())
            .max_payload_size(self.max_payload_size)
            .request_timeout_secs(self.request_timeout_secs)
            .worker_startup_timeout_secs(self.worker_startup_timeout_secs)
            .worker_startup_check_interval_secs(self.worker_startup_check_interval)
            .max_concurrent_requests(self.max_concurrent_requests)
            .queue_size(self.queue_size)
            .queue_timeout_secs(self.queue_timeout_secs)
            .cors_allowed_origins(self.cors_allowed_origins.clone())
            .retry_config(config::RetryConfig {
324
325
326
327
328
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
329
330
            })
            .circuit_breaker_config(config::CircuitBreakerConfig {
331
332
333
334
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
335
336
            })
            .health_check_config(config::HealthCheckConfig {
337
338
339
340
341
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
342
343
            })
            .tokenizer_cache(config::TokenizerCacheConfig {
344
345
346
347
                enable_l0: self.tokenizer_cache_enable_l0,
                l0_max_entries: self.tokenizer_cache_l0_max_entries,
                enable_l1: self.tokenizer_cache_enable_l1,
                l1_max_memory: self.tokenizer_cache_l1_max_memory,
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
            })
            .history_backend(history_backend)
            .maybe_api_key(self.api_key.as_ref())
            .maybe_discovery(discovery)
            .maybe_metrics(metrics)
            .maybe_log_dir(self.log_dir.as_ref())
            .maybe_log_level(self.log_level.as_ref())
            .maybe_request_id_headers(self.request_id_headers.clone())
            .maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
            .maybe_model_path(self.model_path.as_ref())
            .maybe_tokenizer_path(self.tokenizer_path.as_ref())
            .maybe_chat_template(self.chat_template.as_ref())
            .maybe_oracle(oracle)
            .maybe_reasoning_parser(self.reasoning_parser.as_ref())
            .maybe_tool_call_parser(self.tool_call_parser.as_ref())
            .dp_aware(self.dp_aware)
            .retries(!self.disable_retries)
            .circuit_breaker(!self.disable_circuit_breaker)
366
367
368
369
370
371
372
            .igw(self.enable_igw)
            .maybe_client_cert_and_key(
                self.client_cert_path.as_ref(),
                self.client_key_path.as_ref(),
            )
            .add_ca_certificates(self.ca_cert_paths.clone())
            .build()
373
374
375
    }
}

376
377
378
#[pymethods]
impl Router {
    #[new]
379
380
381
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
382
        host = String::from("0.0.0.0"),
383
        port = 3001,
384
385
386
387
388
389
390
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
391
        max_payload_size = 512 * 1024 * 1024,
392
393
        dp_aware = false,
        api_key = None,
394
        log_dir = None,
395
        log_level = None,
396
397
398
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
399
        service_discovery_namespace = None,
400
401
402
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
403
        prometheus_port = None,
404
        prometheus_host = None,
405
406
407
        request_timeout_secs = 1800,
        request_id_headers = None,
        pd_disaggregation = false,
408
        prefill_urls = None,
409
410
        decode_urls = None,
        prefill_policy = None,
411
        decode_policy = None,
412
        max_concurrent_requests = -1,
413
        cors_allowed_origins = vec![],
414
415
416
417
418
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
419
        disable_retries = false,
420
421
422
423
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
424
        disable_circuit_breaker = false,
425
426
427
428
429
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
430
        enable_igw = false,
431
432
433
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
434
435
        model_path = None,
        tokenizer_path = None,
436
        chat_template = None,
437
438
439
440
        tokenizer_cache_enable_l0 = false,
        tokenizer_cache_l0_max_entries = 10000,
        tokenizer_cache_enable_l1 = false,
        tokenizer_cache_l1_max_memory = 52428800,
441
442
        reasoning_parser = None,
        tool_call_parser = None,
443
444
445
        backend = BackendType::Sglang,
        history_backend = HistoryBackendType::Memory,
        oracle_config = None,
446
447
448
        client_cert_path = None,
        client_key_path = None,
        ca_cert_paths = vec![],
449
    ))]
450
    #[allow(clippy::too_many_arguments)]
451
452
453
454
455
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
456
        worker_startup_timeout_secs: u64,
457
        worker_startup_check_interval: u64,
458
        cache_threshold: f32,
459
460
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
461
462
        eviction_interval_secs: u64,
        max_tree_size: usize,
463
        max_payload_size: usize,
464
465
        dp_aware: bool,
        api_key: Option<String>,
466
        log_dir: Option<String>,
467
        log_level: Option<String>,
468
469
470
471
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
472
473
474
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
475
476
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
477
        request_timeout_secs: u64,
478
        request_id_headers: Option<Vec<String>>,
479
        pd_disaggregation: bool,
480
481
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
482
483
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
484
        max_concurrent_requests: i32,
485
        cors_allowed_origins: Vec<String>,
486
487
488
489
490
491
492
493
494
495
496
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
497
498
499
500
501
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
502
        enable_igw: bool,
503
504
        queue_size: usize,
        queue_timeout_secs: u64,
505
        rate_limit_tokens_per_second: Option<i32>,
506
507
        model_path: Option<String>,
        tokenizer_path: Option<String>,
508
        chat_template: Option<String>,
509
510
511
512
        tokenizer_cache_enable_l0: bool,
        tokenizer_cache_l0_max_entries: usize,
        tokenizer_cache_enable_l1: bool,
        tokenizer_cache_l1_max_memory: usize,
513
514
        reasoning_parser: Option<String>,
        tool_call_parser: Option<String>,
515
516
517
        backend: BackendType,
        history_backend: HistoryBackendType,
        oracle_config: Option<PyOracleConfig>,
518
519
520
        client_cert_path: Option<String>,
        client_key_path: Option<String>,
        ca_cert_paths: Vec<String>,
521
    ) -> PyResult<Self> {
522
523
524
525
526
527
528
529
530
531
532
533
534
535
        let mut all_urls = worker_urls.clone();

        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

536
        Ok(Router {
537
538
539
            host,
            port,
            worker_urls,
540
            policy,
541
            worker_startup_timeout_secs,
542
            worker_startup_check_interval,
543
            cache_threshold,
544
545
            balance_abs_threshold,
            balance_rel_threshold,
546
547
            eviction_interval_secs,
            max_tree_size,
548
            max_payload_size,
549
550
            dp_aware,
            api_key,
551
            log_dir,
552
            log_level,
553
554
555
556
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
557
558
559
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
560
561
            prometheus_port,
            prometheus_host,
562
            request_timeout_secs,
563
            request_id_headers,
564
            pd_disaggregation,
565
566
            prefill_urls,
            decode_urls,
567
568
            prefill_policy,
            decode_policy,
569
570
            max_concurrent_requests,
            cors_allowed_origins,
571
572
573
574
575
576
577
578
579
580
581
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
582
583
584
585
586
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
587
            enable_igw,
588
589
590
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
591
592
593
            connection_mode,
            model_path,
            tokenizer_path,
594
            chat_template,
595
596
597
598
            tokenizer_cache_enable_l0,
            tokenizer_cache_l0_max_entries,
            tokenizer_cache_enable_l1,
            tokenizer_cache_l1_max_memory,
599
600
            reasoning_parser,
            tool_call_parser,
601
602
603
            backend,
            history_backend,
            oracle_config,
604
605
606
            client_cert_path,
            client_key_path,
            ca_cert_paths,
607
        })
608
609
610
    }

    fn start(&self) -> PyResult<()> {
611
612
613
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
614

615
616
617
618
619
620
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
621

622
623
624
625
626
627
628
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
629
630
631
632
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
633
634
635
636
637
            })
        } else {
            None
        };

638
639
640
641
642
643
644
645
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

646
647
648
649
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        runtime.block_on(async move {
650
651
652
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
653
                router_config,
654
                max_payload_size: self.max_payload_size,
655
                log_dir: self.log_dir.clone(),
656
                log_level: self.log_level.clone(),
657
                service_discovery_config,
658
                prometheus_config,
659
                request_timeout_secs: self.request_timeout_secs,
660
                request_id_headers: self.request_id_headers.clone(),
661
662
            })
            .await
663
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
664
        })
665
666
667
668
    }
}

#[pymodule]
669
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
670
    m.add_class::<PolicyType>()?;
671
672
673
    m.add_class::<BackendType>()?;
    m.add_class::<HistoryBackendType>()?;
    m.add_class::<PyOracleConfig>()?;
674
675
    m.add_class::<Router>()?;
    Ok(())
676
}