use pyo3::prelude::*;
pub mod config;
pub mod logging;
use std::collections::HashMap;
pub mod core;
pub mod metrics;
pub mod middleware;
pub mod openai_api_types;
pub mod policies;
pub mod routers;
pub mod server;
pub mod service_discovery;
pub mod tree;
use crate::metrics::PrometheusConfig;

#[pyclass(eq)]
17
#[derive(Clone, PartialEq, Debug)]
18
19
20
pub enum PolicyType {
    Random,
    RoundRobin,
21
    CacheAware,
22
    PowerOfTwo, // Moved from PD-specific, now shared
23
24
}

25
#[pyclass]
26
#[derive(Debug, Clone, PartialEq)]
27
28
29
30
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
31
    policy: PolicyType,
32
    worker_startup_timeout_secs: u64,
33
    worker_startup_check_interval: u64,
34
    cache_threshold: f32,
35
36
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
37
38
    eviction_interval_secs: u64,
    max_tree_size: usize,
39
    max_payload_size: usize,
40
    log_dir: Option<String>,
41
    log_level: Option<String>,
42
43
44
45
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
46
47
48
49
    // PD service discovery fields
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
50
51
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
52
    request_timeout_secs: u64,
53
    request_id_headers: Option<Vec<String>>,
54
    // PD mode flag
55
56
    pd_disaggregation: bool,
    // PD-specific fields (only used when pd_disaggregation is true)
57
58
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
59
60
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
61
62
}

63
64
65
66
67
68
69
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

88
89
90
91
92
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
93
94
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
95
96
97
98
99
100
101
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

102
103
        // Convert main policy
        let policy = convert_policy(&self.policy);
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
142
            log_level: self.log_level.clone(),
143
            request_id_headers: self.request_id_headers.clone(),
144
145
146
147
        })
    }
}

148
149
150
#[pymethods]
impl Router {
    #[new]
151
152
153
154
155
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
156
        worker_startup_timeout_secs = 300,
157
        worker_startup_check_interval = 10,
158
        cache_threshold = 0.50,
159
160
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
161
        eviction_interval_secs = 60,
162
        max_tree_size = 2usize.pow(24),
163
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
164
        log_dir = None,
165
        log_level = None,
166
167
168
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
169
        service_discovery_namespace = None,
170
171
172
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
173
        prometheus_port = None,
174
175
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
176
        request_id_headers = None,  // Custom request ID headers
177
        pd_disaggregation = false,  // New flag for PD mode
178
        prefill_urls = None,
179
180
181
        decode_urls = None,
        prefill_policy = None,
        decode_policy = None
182
183
184
185
186
187
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
188
        worker_startup_timeout_secs: u64,
189
        worker_startup_check_interval: u64,
190
        cache_threshold: f32,
191
192
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
193
194
        eviction_interval_secs: u64,
        max_tree_size: usize,
195
        max_payload_size: usize,
196
        log_dir: Option<String>,
197
        log_level: Option<String>,
198
199
200
201
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
202
203
204
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
205
206
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
207
        request_timeout_secs: u64,
208
        request_id_headers: Option<Vec<String>>,
209
        pd_disaggregation: bool,
210
211
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
212
213
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
214
215
    ) -> PyResult<Self> {
        Ok(Router {
216
217
218
            host,
            port,
            worker_urls,
219
            policy,
220
            worker_startup_timeout_secs,
221
            worker_startup_check_interval,
222
            cache_threshold,
223
224
            balance_abs_threshold,
            balance_rel_threshold,
225
226
            eviction_interval_secs,
            max_tree_size,
227
            max_payload_size,
228
            log_dir,
229
            log_level,
230
231
232
233
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
234
235
236
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
237
238
            prometheus_port,
            prometheus_host,
239
            request_timeout_secs,
240
            request_id_headers,
241
            pd_disaggregation,
242
243
            prefill_urls,
            decode_urls,
244
245
            prefill_policy,
            decode_policy,
246
        })
247
248
249
    }

    fn start(&self) -> PyResult<()> {
250
251
252
253
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
254

255
256
257
258
259
260
261
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
262

263
264
265
266
267
268
269
270
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
271
272
273
274
275
                // PD mode configuration
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
276
277
278
279
280
            })
        } else {
            None
        };

281
282
283
284
285
286
287
288
289
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

290
291
292
293
294
295
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
296
297
298
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
299
                router_config,
300
                max_payload_size: self.max_payload_size,
301
                log_dir: self.log_dir.clone(),
302
                log_level: self.log_level.clone(),
303
                service_discovery_config,
304
                prometheus_config,
305
                request_timeout_secs: self.request_timeout_secs,
306
                request_id_headers: self.request_id_headers.clone(),
307
308
            })
            .await
309
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
310
        })
311
312
313
314
    }
}

#[pymodule]
315
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
316
    m.add_class::<PolicyType>()?;
317
318
    m.add_class::<Router>()?;
    Ok(())
319
}