lib.rs 22.9 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
pub mod data_connector;
8
#[cfg(feature = "grpc-client")]
9
pub mod grpc_client;
10
pub mod mcp;
11
pub mod metrics;
12
pub mod middleware;
13
pub mod policies;
14
pub mod protocols;
15
pub mod reasoning_parser;
16
pub mod routers;
17
pub mod server;
18
pub mod service_discovery;
19
pub mod tokenizer;
20
pub mod tool_parser;
21
pub mod tree;
22
use crate::metrics::PrometheusConfig;
23

24
#[pyclass(eq)]
25
#[derive(Clone, PartialEq, Debug)]
26
27
28
pub enum PolicyType {
    Random,
    RoundRobin,
29
    CacheAware,
30
    PowerOfTwo,
31
32
}

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
/// Kind of backend the router sits in front of.
///
/// In `Router::to_router_config`, `Openai` selects `RoutingMode::OpenAI`
/// unless IGW mode is enabled. `Sglang` falls through to the prefill/decode
/// or regular routing modes.
#[pyclass(eq)]
#[derive(Debug, Clone, PartialEq)]
pub enum BackendType {
    /// Default backend in the Python constructor.
    Sglang,
    /// Routed through `RoutingMode::OpenAI`.
    Openai,
}

/// Storage backend for conversation history, as chosen from Python.
///
/// Each variant maps one-to-one onto `config::HistoryBackend`. Only `Oracle`
/// causes the optional `PyOracleConfig` to be forwarded into the router config.
#[pyclass(eq)]
#[derive(Debug, Clone, PartialEq)]
pub enum HistoryBackendType {
    /// Default in the Python constructor.
    Memory,
    /// Maps to `config::HistoryBackend::None`.
    None,
    /// Maps to `config::HistoryBackend::Oracle`; pairs with `oracle_config`.
    Oracle,
}

/// Python-facing Oracle connection settings for the `Oracle` history backend.
///
/// Every field is exposed to Python as a read/write property. `Debug` is
/// implemented by hand so that secrets are not printed.
#[pyclass]
#[derive(PartialEq, Clone)]
pub struct PyOracleConfig {
    /// Optional wallet location, passed through unchanged.
    #[pyo3(get, set)]
    pub wallet_path: Option<String>,
    /// Connection descriptor; an unset value becomes an empty string on conversion.
    #[pyo3(get, set)]
    pub connect_descriptor: Option<String>,
    /// Database user; an unset value becomes an empty string on conversion.
    #[pyo3(get, set)]
    pub username: Option<String>,
    /// Database password; redacted in `Debug` output.
    #[pyo3(get, set)]
    pub password: Option<String>,
    /// Minimum pool size (constructor default 1; must be at least 1).
    #[pyo3(get, set)]
    pub pool_min: usize,
    /// Maximum pool size (constructor default 16; must be >= `pool_min`).
    #[pyo3(get, set)]
    pub pool_max: usize,
    /// Pool timeout in seconds (constructor default 30).
    #[pyo3(get, set)]
    pub pool_timeout_secs: u64,
}

impl std::fmt::Debug for PyOracleConfig {
    /// Formats the config with secrets masked.
    ///
    /// `connect_descriptor` and `password` always render as `"<redacted>"`,
    /// whether or not a value is set. All other fields print normally.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";

        let mut builder = f.debug_struct("PyOracleConfig");
        builder.field("wallet_path", &self.wallet_path);
        builder.field("connect_descriptor", &REDACTED);
        builder.field("username", &self.username);
        builder.field("password", &REDACTED);
        builder.field("pool_min", &self.pool_min);
        builder.field("pool_max", &self.pool_max);
        builder.field("pool_timeout_secs", &self.pool_timeout_secs);
        builder.finish()
    }
}

#[pymethods]
impl PyOracleConfig {
    /// Python constructor; every argument is optional.
    ///
    /// # Errors
    /// Raises `ValueError` when `pool_min` is 0 or when `pool_max < pool_min`.
    #[new]
    #[pyo3(signature = (
        password = None,
        username = None,
        connect_descriptor = None,
        wallet_path = None,
        pool_min = 1,
        pool_max = 16,
        pool_timeout_secs = 30,
    ))]
    fn new(
        password: Option<String>,
        username: Option<String>,
        connect_descriptor: Option<String>,
        wallet_path: Option<String>,
        pool_min: usize,
        pool_max: usize,
        pool_timeout_secs: u64,
    ) -> PyResult<Self> {
        use pyo3::exceptions::PyValueError;

        // Guard clauses: reject an empty pool, then an inverted min/max range.
        if pool_min < 1 {
            return Err(PyValueError::new_err("pool_min must be at least 1"));
        }
        if pool_min > pool_max {
            return Err(PyValueError::new_err("pool_max must be >= pool_min"));
        }

        Ok(Self {
            password,
            username,
            connect_descriptor,
            wallet_path,
            pool_min,
            pool_max,
            pool_timeout_secs,
        })
    }
}

impl PyOracleConfig {
    /// Copies the Python-facing settings into the internal `config::OracleConfig`.
    ///
    /// If `connect_descriptor`, `username` or `password` is unset, it becomes an
    /// empty string. This is a plain field mapping; validation happens later
    /// in `validate_oracle()`.
    fn to_config_oracle(&self) -> config::OracleConfig {
        let Self {
            wallet_path,
            connect_descriptor,
            username,
            password,
            pool_min,
            pool_max,
            pool_timeout_secs,
        } = self;

        config::OracleConfig {
            wallet_path: wallet_path.clone(),
            connect_descriptor: connect_descriptor.clone().unwrap_or_default(),
            username: username.clone().unwrap_or_default(),
            password: password.clone().unwrap_or_default(),
            pool_min: *pool_min,
            pool_max: *pool_max,
            pool_timeout_secs: *pool_timeout_secs,
        }
    }
}

140
#[pyclass]
141
#[derive(Debug, Clone, PartialEq)]
142
143
144
145
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
146
    policy: PolicyType,
147
    worker_startup_timeout_secs: u64,
148
    worker_startup_check_interval: u64,
149
    cache_threshold: f32,
150
151
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
152
153
    eviction_interval_secs: u64,
    max_tree_size: usize,
154
    max_payload_size: usize,
155
156
    dp_aware: bool,
    api_key: Option<String>,
157
    log_dir: Option<String>,
158
    log_level: Option<String>,
159
160
161
162
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
163
164
165
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
166
167
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
168
    request_timeout_secs: u64,
169
    request_id_headers: Option<Vec<String>>,
170
    pd_disaggregation: bool,
171
172
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
173
174
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
175
    max_concurrent_requests: i32,
176
    cors_allowed_origins: Vec<String>,
177
178
179
180
181
182
183
184
185
186
187
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
188
189
190
191
192
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
193
    enable_igw: bool,
194
195
    queue_size: usize,
    queue_timeout_secs: u64,
196
    rate_limit_tokens_per_second: Option<i32>,
197
    connection_mode: core::ConnectionMode,
198
199
    model_path: Option<String>,
    tokenizer_path: Option<String>,
200
    chat_template: Option<String>,
201
202
203
204
    tokenizer_cache_enable_l0: bool,
    tokenizer_cache_l0_max_entries: usize,
    tokenizer_cache_enable_l1: bool,
    tokenizer_cache_l1_max_memory: usize,
205
206
    reasoning_parser: Option<String>,
    tool_call_parser: Option<String>,
207
208
209
    backend: BackendType,
    history_backend: HistoryBackendType,
    oracle_config: Option<PyOracleConfig>,
210
211
212
    client_cert_path: Option<String>,
    client_key_path: Option<String>,
    ca_cert_paths: Vec<String>,
213
214
}

215
impl Router {
216
    /// Determine connection mode from worker URLs
217
    fn determine_connection_mode(worker_urls: &[String]) -> core::ConnectionMode {
218
219
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
220
                return core::ConnectionMode::Grpc { port: None };
221
222
            }
        }
223
        core::ConnectionMode::Http
224
225
    }

226
227
228
229
230
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

231
232
233
234
235
236
237
238
239
240
241
242
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
243
                    load_check_interval_secs: 5,
244
245
246
247
                },
            }
        };

248
249
250
251
        let mode = if self.enable_igw {
            RoutingMode::Regular {
                worker_urls: vec![],
            }
252
253
254
255
        } else if matches!(self.backend, BackendType::Openai) {
            RoutingMode::OpenAI {
                worker_urls: self.worker_urls.clone(),
            }
256
        } else if self.pd_disaggregation {
257
258
259
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
260
261
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
262
263
264
265
266
267
268
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

269
        let policy = convert_policy(&self.policy);
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293

        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

294
295
296
297
298
299
300
301
302
303
304
305
306
307
        let history_backend = match self.history_backend {
            HistoryBackendType::Memory => config::HistoryBackend::Memory,
            HistoryBackendType::None => config::HistoryBackend::None,
            HistoryBackendType::Oracle => config::HistoryBackend::Oracle,
        };

        let oracle = if matches!(self.history_backend, HistoryBackendType::Oracle) {
            self.oracle_config
                .as_ref()
                .map(|cfg| cfg.to_config_oracle())
        } else {
            None
        };

308
        config::RouterConfig::builder()
309
310
311
312
313
314
315
316
317
318
319
320
321
322
            .mode(mode)
            .policy(policy)
            .host(&self.host)
            .port(self.port)
            .connection_mode(self.connection_mode.clone())
            .max_payload_size(self.max_payload_size)
            .request_timeout_secs(self.request_timeout_secs)
            .worker_startup_timeout_secs(self.worker_startup_timeout_secs)
            .worker_startup_check_interval_secs(self.worker_startup_check_interval)
            .max_concurrent_requests(self.max_concurrent_requests)
            .queue_size(self.queue_size)
            .queue_timeout_secs(self.queue_timeout_secs)
            .cors_allowed_origins(self.cors_allowed_origins.clone())
            .retry_config(config::RetryConfig {
323
324
325
326
327
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
328
329
            })
            .circuit_breaker_config(config::CircuitBreakerConfig {
330
331
332
333
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
334
335
            })
            .health_check_config(config::HealthCheckConfig {
336
337
338
339
340
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
341
342
            })
            .tokenizer_cache(config::TokenizerCacheConfig {
343
344
345
346
                enable_l0: self.tokenizer_cache_enable_l0,
                l0_max_entries: self.tokenizer_cache_l0_max_entries,
                enable_l1: self.tokenizer_cache_enable_l1,
                l1_max_memory: self.tokenizer_cache_l1_max_memory,
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
            })
            .history_backend(history_backend)
            .maybe_api_key(self.api_key.as_ref())
            .maybe_discovery(discovery)
            .maybe_metrics(metrics)
            .maybe_log_dir(self.log_dir.as_ref())
            .maybe_log_level(self.log_level.as_ref())
            .maybe_request_id_headers(self.request_id_headers.clone())
            .maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second)
            .maybe_model_path(self.model_path.as_ref())
            .maybe_tokenizer_path(self.tokenizer_path.as_ref())
            .maybe_chat_template(self.chat_template.as_ref())
            .maybe_oracle(oracle)
            .maybe_reasoning_parser(self.reasoning_parser.as_ref())
            .maybe_tool_call_parser(self.tool_call_parser.as_ref())
            .dp_aware(self.dp_aware)
            .retries(!self.disable_retries)
            .circuit_breaker(!self.disable_circuit_breaker)
365
366
367
368
369
370
371
            .igw(self.enable_igw)
            .maybe_client_cert_and_key(
                self.client_cert_path.as_ref(),
                self.client_key_path.as_ref(),
            )
            .add_ca_certificates(self.ca_cert_paths.clone())
            .build()
372
373
374
    }
}

375
376
377
#[pymethods]
impl Router {
    #[new]
378
379
380
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
381
        host = String::from("0.0.0.0"),
382
        port = 3001,
383
384
385
386
387
388
389
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
390
        max_payload_size = 512 * 1024 * 1024,
391
392
        dp_aware = false,
        api_key = None,
393
        log_dir = None,
394
        log_level = None,
395
396
397
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
398
        service_discovery_namespace = None,
399
400
401
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
402
        prometheus_port = None,
403
        prometheus_host = None,
404
405
406
        request_timeout_secs = 1800,
        request_id_headers = None,
        pd_disaggregation = false,
407
        prefill_urls = None,
408
409
        decode_urls = None,
        prefill_policy = None,
410
        decode_policy = None,
411
        max_concurrent_requests = -1,
412
        cors_allowed_origins = vec![],
413
414
415
416
417
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
418
        disable_retries = false,
419
420
421
422
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
423
        disable_circuit_breaker = false,
424
425
426
427
428
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
429
        enable_igw = false,
430
431
432
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
433
434
        model_path = None,
        tokenizer_path = None,
435
        chat_template = None,
436
437
438
439
        tokenizer_cache_enable_l0 = false,
        tokenizer_cache_l0_max_entries = 10000,
        tokenizer_cache_enable_l1 = false,
        tokenizer_cache_l1_max_memory = 52428800,
440
441
        reasoning_parser = None,
        tool_call_parser = None,
442
443
444
        backend = BackendType::Sglang,
        history_backend = HistoryBackendType::Memory,
        oracle_config = None,
445
446
447
        client_cert_path = None,
        client_key_path = None,
        ca_cert_paths = vec![],
448
    ))]
449
    #[allow(clippy::too_many_arguments)]
450
451
452
453
454
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
455
        worker_startup_timeout_secs: u64,
456
        worker_startup_check_interval: u64,
457
        cache_threshold: f32,
458
459
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
460
461
        eviction_interval_secs: u64,
        max_tree_size: usize,
462
        max_payload_size: usize,
463
464
        dp_aware: bool,
        api_key: Option<String>,
465
        log_dir: Option<String>,
466
        log_level: Option<String>,
467
468
469
470
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
471
472
473
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
474
475
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
476
        request_timeout_secs: u64,
477
        request_id_headers: Option<Vec<String>>,
478
        pd_disaggregation: bool,
479
480
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
481
482
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
483
        max_concurrent_requests: i32,
484
        cors_allowed_origins: Vec<String>,
485
486
487
488
489
490
491
492
493
494
495
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
496
497
498
499
500
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
501
        enable_igw: bool,
502
503
        queue_size: usize,
        queue_timeout_secs: u64,
504
        rate_limit_tokens_per_second: Option<i32>,
505
506
        model_path: Option<String>,
        tokenizer_path: Option<String>,
507
        chat_template: Option<String>,
508
509
510
511
        tokenizer_cache_enable_l0: bool,
        tokenizer_cache_l0_max_entries: usize,
        tokenizer_cache_enable_l1: bool,
        tokenizer_cache_l1_max_memory: usize,
512
513
        reasoning_parser: Option<String>,
        tool_call_parser: Option<String>,
514
515
516
        backend: BackendType,
        history_backend: HistoryBackendType,
        oracle_config: Option<PyOracleConfig>,
517
518
519
        client_cert_path: Option<String>,
        client_key_path: Option<String>,
        ca_cert_paths: Vec<String>,
520
    ) -> PyResult<Self> {
521
522
523
524
525
526
527
528
529
530
531
532
533
534
        let mut all_urls = worker_urls.clone();

        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

535
        Ok(Router {
536
537
538
            host,
            port,
            worker_urls,
539
            policy,
540
            worker_startup_timeout_secs,
541
            worker_startup_check_interval,
542
            cache_threshold,
543
544
            balance_abs_threshold,
            balance_rel_threshold,
545
546
            eviction_interval_secs,
            max_tree_size,
547
            max_payload_size,
548
549
            dp_aware,
            api_key,
550
            log_dir,
551
            log_level,
552
553
554
555
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
556
557
558
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
559
560
            prometheus_port,
            prometheus_host,
561
            request_timeout_secs,
562
            request_id_headers,
563
            pd_disaggregation,
564
565
            prefill_urls,
            decode_urls,
566
567
            prefill_policy,
            decode_policy,
568
569
            max_concurrent_requests,
            cors_allowed_origins,
570
571
572
573
574
575
576
577
578
579
580
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
581
582
583
584
585
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
586
            enable_igw,
587
588
589
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
590
591
592
            connection_mode,
            model_path,
            tokenizer_path,
593
            chat_template,
594
595
596
597
            tokenizer_cache_enable_l0,
            tokenizer_cache_l0_max_entries,
            tokenizer_cache_enable_l1,
            tokenizer_cache_l1_max_memory,
598
599
            reasoning_parser,
            tool_call_parser,
600
601
602
            backend,
            history_backend,
            oracle_config,
603
604
605
            client_cert_path,
            client_key_path,
            ca_cert_paths,
606
        })
607
608
609
    }

    fn start(&self) -> PyResult<()> {
610
611
612
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
613

614
615
616
617
618
619
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
620

621
622
623
624
625
626
627
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
628
629
630
631
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
632
633
634
635
636
            })
        } else {
            None
        };

637
638
639
640
641
642
643
644
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

645
646
647
648
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        runtime.block_on(async move {
649
650
651
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
652
                router_config,
653
                max_payload_size: self.max_payload_size,
654
                log_dir: self.log_dir.clone(),
655
                log_level: self.log_level.clone(),
656
                service_discovery_config,
657
                prometheus_config,
658
                request_timeout_secs: self.request_timeout_secs,
659
                request_id_headers: self.request_id_headers.clone(),
660
661
            })
            .await
662
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
663
        })
664
665
666
667
    }
}

#[pymodule]
668
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
669
    m.add_class::<PolicyType>()?;
670
671
672
    m.add_class::<BackendType>()?;
    m.add_class::<HistoryBackendType>()?;
    m.add_class::<PyOracleConfig>()?;
673
674
    m.add_class::<Router>()?;
    Ok(())
675
}