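//! PyO3 bindings for the router: exposes `PolicyType` and `Router` to Python
//! through the `sglang_router_rs` extension module.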
1
use pyo3::prelude::*;
2
pub mod logging;
3
use std::collections::HashMap;
4
pub mod prometheus;
5
pub mod router;
6
pub mod server;
7
pub mod service_discovery;
8
pub mod tree;
9
use crate::prometheus::PrometheusConfig;
10

11
#[pyclass(eq)]
12
#[derive(Clone, PartialEq, Debug)]
13
14
15
pub enum PolicyType {
    Random,
    RoundRobin,
16
    CacheAware,
17
18
}

19
#[pyclass]
20
#[derive(Debug, Clone, PartialEq)]
21
22
23
24
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
25
    policy: PolicyType,
26
    worker_startup_timeout_secs: u64,
27
    worker_startup_check_interval: u64,
28
    cache_threshold: f32,
29
30
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
31
32
    eviction_interval_secs: u64,
    max_tree_size: usize,
33
    max_payload_size: usize,
34
    verbose: bool,
35
    log_dir: Option<String>,
36
37
38
39
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
40
41
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
42
43
44
45
46
}

#[pymethods]
impl Router {
    #[new]
47
48
49
50
51
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
52
        worker_startup_timeout_secs = 300,
53
        worker_startup_check_interval = 10,
54
        cache_threshold = 0.50,
55
56
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
57
        eviction_interval_secs = 60,
58
        max_tree_size = 2usize.pow(24),
59
        max_payload_size = 4 * 1024 * 1024,
60
61
        verbose = false,
        log_dir = None,
62
63
64
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
65
66
67
        service_discovery_namespace = None,
        prometheus_port = None,
        prometheus_host = None
68
69
70
71
72
73
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
74
        worker_startup_timeout_secs: u64,
75
        worker_startup_check_interval: u64,
76
        cache_threshold: f32,
77
78
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
79
80
        eviction_interval_secs: u64,
        max_tree_size: usize,
81
        max_payload_size: usize,
82
        verbose: bool,
83
        log_dir: Option<String>,
84
85
86
87
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
88
89
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
90
91
    ) -> PyResult<Self> {
        Ok(Router {
92
93
94
            host,
            port,
            worker_urls,
95
            policy,
96
            worker_startup_timeout_secs,
97
            worker_startup_check_interval,
98
            cache_threshold,
99
100
            balance_abs_threshold,
            balance_rel_threshold,
101
102
            eviction_interval_secs,
            max_tree_size,
103
            max_payload_size,
104
            verbose,
105
            log_dir,
106
107
108
109
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
110
111
            prometheus_port,
            prometheus_host,
112
        })
113
114
115
    }

    fn start(&self) -> PyResult<()> {
116
        let policy_config = match &self.policy {
117
118
            PolicyType::Random => router::PolicyConfig::RandomConfig {
                timeout_secs: self.worker_startup_timeout_secs,
119
                interval_secs: self.worker_startup_check_interval,
120
121
122
            },
            PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
                timeout_secs: self.worker_startup_timeout_secs,
123
                interval_secs: self.worker_startup_check_interval,
124
            },
125
            PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
126
                timeout_secs: self.worker_startup_timeout_secs,
127
                interval_secs: self.worker_startup_check_interval,
128
                cache_threshold: self.cache_threshold,
129
130
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
131
132
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
133
134
            },
        };
135

136
137
138
139
140
141
142
143
144
145
146
147
148
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
            })
        } else {
            None
        };

149
150
151
152
153
154
155
156
157
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

158
        actix_web::rt::System::new().block_on(async move {
159
160
161
162
163
164
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
                worker_urls: self.worker_urls.clone(),
                policy_config,
                verbose: self.verbose,
165
                max_payload_size: self.max_payload_size,
166
                log_dir: self.log_dir.clone(),
167
                service_discovery_config,
168
                prometheus_config,
169
170
            })
            .await
171
172
173
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Ok(())
        })
174
175
176
177
    }
}

#[pymodule]
178
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
179
    m.add_class::<PolicyType>()?;
180
181
    m.add_class::<Router>()?;
    Ok(())
182
}