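//! lib.rs — Python bindings for the router crate: exposes `PolicyType` and
//! `Router` through the `sglang_router_rs` extension module.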
1
use pyo3::prelude::*;
2
pub mod logging;
3
use std::collections::HashMap;
4
5
6
pub mod openai_api_types;
pub mod pd_router;
pub mod pd_types;
7
pub mod prometheus;
8
pub mod request_adapter;
9
pub mod router;
10
pub mod server;
11
pub mod service_discovery;
12
pub mod tree;
13
use crate::prometheus::PrometheusConfig;
14

15
#[pyclass(eq)]
16
#[derive(Clone, PartialEq, Debug)]
17
18
19
pub enum PolicyType {
    Random,
    RoundRobin,
20
    CacheAware,
21
    PowerOfTwo, // Moved from PD-specific, now shared
22
23
}

24
#[pyclass]
25
#[derive(Debug, Clone, PartialEq)]
26
27
28
29
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
30
    policy: PolicyType,
31
    worker_startup_timeout_secs: u64,
32
    worker_startup_check_interval: u64,
33
    cache_threshold: f32,
34
35
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
36
37
    eviction_interval_secs: u64,
    max_tree_size: usize,
38
    max_payload_size: usize,
39
    verbose: bool,
40
    log_dir: Option<String>,
41
42
43
44
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
45
46
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
47
48
49
50
51
52
    request_timeout_secs: u64,
    // PD mode flag
    pd_disaggregated: bool,
    // PD-specific fields (only used when pd_disaggregated is true)
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
53
54
55
56
57
}

#[pymethods]
impl Router {
    #[new]
58
59
60
61
62
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
63
        worker_startup_timeout_secs = 300,
64
        worker_startup_check_interval = 10,
65
        cache_threshold = 0.50,
66
67
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
68
        eviction_interval_secs = 60,
69
        max_tree_size = 2usize.pow(24),
70
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
71
72
        verbose = false,
        log_dir = None,
73
74
75
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
76
77
        service_discovery_namespace = None,
        prometheus_port = None,
78
79
80
81
82
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
        pd_disaggregated = false,  // New flag for PD mode
        prefill_urls = None,
        decode_urls = None
83
84
85
86
87
88
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
89
        worker_startup_timeout_secs: u64,
90
        worker_startup_check_interval: u64,
91
        cache_threshold: f32,
92
93
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
94
95
        eviction_interval_secs: u64,
        max_tree_size: usize,
96
        max_payload_size: usize,
97
        verbose: bool,
98
        log_dir: Option<String>,
99
100
101
102
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
103
104
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
105
106
107
108
        request_timeout_secs: u64,
        pd_disaggregated: bool,
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
109
110
    ) -> PyResult<Self> {
        Ok(Router {
111
112
113
            host,
            port,
            worker_urls,
114
            policy,
115
            worker_startup_timeout_secs,
116
            worker_startup_check_interval,
117
            cache_threshold,
118
119
            balance_abs_threshold,
            balance_rel_threshold,
120
121
            eviction_interval_secs,
            max_tree_size,
122
            max_payload_size,
123
            verbose,
124
            log_dir,
125
126
127
128
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
129
130
            prometheus_port,
            prometheus_host,
131
132
133
134
            request_timeout_secs,
            pd_disaggregated,
            prefill_urls,
            decode_urls,
135
        })
136
137
138
    }

    fn start(&self) -> PyResult<()> {
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
        let policy_config = if self.pd_disaggregated {
            // PD mode - map PolicyType to PDSelectionPolicy
            let pd_selection_policy = match &self.policy {
                PolicyType::Random => pd_types::PDSelectionPolicy::Random,
                PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
                PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                },
                PolicyType::RoundRobin => {
                    return Err(pyo3::exceptions::PyValueError::new_err(
                        "RoundRobin policy is not supported in PD disaggregated mode",
                    ));
                }
            };

            let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
                pyo3::exceptions::PyValueError::new_err(
                    "PD disaggregated mode requires prefill_urls",
                )
            })?;
            let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
                pyo3::exceptions::PyValueError::new_err(
                    "PD disaggregated mode requires decode_urls",
                )
            })?;

            router::PolicyConfig::PrefillDecodeConfig {
                selection_policy: pd_selection_policy,
                prefill_urls: prefill_urls.clone(),
                decode_urls: decode_urls.clone(),
171
                timeout_secs: self.worker_startup_timeout_secs,
172
                interval_secs: self.worker_startup_check_interval,
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
            }
        } else {
            // Regular mode
            match &self.policy {
                PolicyType::Random => router::PolicyConfig::RandomConfig {
                    timeout_secs: self.worker_startup_timeout_secs,
                    interval_secs: self.worker_startup_check_interval,
                },
                PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
                    timeout_secs: self.worker_startup_timeout_secs,
                    interval_secs: self.worker_startup_check_interval,
                },
                PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
                    timeout_secs: self.worker_startup_timeout_secs,
                    interval_secs: self.worker_startup_check_interval,
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => {
                    return Err(pyo3::exceptions::PyValueError::new_err(
                        "PowerOfTwo policy is only supported in PD disaggregated mode",
                    ));
                }
            }
200
        };
201

202
203
204
205
206
207
208
209
210
211
212
213
214
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
            })
        } else {
            None
        };

215
216
217
218
219
220
221
222
223
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

224
        actix_web::rt::System::new().block_on(async move {
225
226
227
228
229
230
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
                worker_urls: self.worker_urls.clone(),
                policy_config,
                verbose: self.verbose,
231
                max_payload_size: self.max_payload_size,
232
                log_dir: self.log_dir.clone(),
233
                service_discovery_config,
234
                prometheus_config,
235
                request_timeout_secs: self.request_timeout_secs,
236
237
            })
            .await
238
239
240
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Ok(())
        })
241
242
243
244
    }
}

#[pymodule]
245
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
246
    m.add_class::<PolicyType>()?;
247
248
    m.add_class::<Router>()?;
    Ok(())
249
}