lib.rs 11.5 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
pub mod metrics;
7
pub mod middleware;
8
pub mod openai_api_types;
9
10
pub mod policies;
pub mod routers;
11
pub mod server;
12
pub mod service_discovery;
13
pub mod tree;
14
use crate::metrics::PrometheusConfig;
15

16
#[pyclass(eq)]
17
#[derive(Clone, PartialEq, Debug)]
18
19
20
pub enum PolicyType {
    Random,
    RoundRobin,
21
    CacheAware,
22
    PowerOfTwo,
23
24
}

25
#[pyclass]
26
#[derive(Debug, Clone, PartialEq)]
27
28
29
30
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
31
    policy: PolicyType,
32
    worker_startup_timeout_secs: u64,
33
    worker_startup_check_interval: u64,
34
    cache_threshold: f32,
35
36
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
37
38
    eviction_interval_secs: u64,
    max_tree_size: usize,
39
    max_payload_size: usize,
40
41
    dp_aware: bool,
    api_key: Option<String>,
42
    log_dir: Option<String>,
43
    log_level: Option<String>,
44
45
46
47
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
48
49
50
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
51
52
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
53
    request_timeout_secs: u64,
54
    request_id_headers: Option<Vec<String>>,
55
    pd_disaggregation: bool,
56
57
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
58
59
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
60
61
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
62
63
}

64
65
66
67
68
69
70
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

89
90
91
92
93
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
94
95
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
96
97
98
99
100
101
102
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

103
104
        // Convert main policy
        let policy = convert_policy(&self.policy);
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
140
141
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
142
143
144
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
145
            log_level: self.log_level.clone(),
146
            request_id_headers: self.request_id_headers.clone(),
147
148
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
149
            retry: config::RetryConfig::default(),
150
151
152
153
        })
    }
}

154
155
156
#[pymethods]
impl Router {
    #[new]
157
158
159
160
161
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
162
        worker_startup_timeout_secs = 300,
163
        worker_startup_check_interval = 10,
164
        cache_threshold = 0.50,
165
166
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
167
        eviction_interval_secs = 60,
168
        max_tree_size = 2usize.pow(24),
169
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
170
171
        dp_aware = false,
        api_key = None,
172
        log_dir = None,
173
        log_level = None,
174
175
176
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
177
        service_discovery_namespace = None,
178
179
180
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
181
        prometheus_port = None,
182
183
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
184
        request_id_headers = None,  // Custom request ID headers
185
        pd_disaggregation = false,  // New flag for PD mode
186
        prefill_urls = None,
187
188
        decode_urls = None,
        prefill_policy = None,
189
190
191
        decode_policy = None,
        max_concurrent_requests = 64,
        cors_allowed_origins = vec![]
192
193
194
195
196
197
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
198
        worker_startup_timeout_secs: u64,
199
        worker_startup_check_interval: u64,
200
        cache_threshold: f32,
201
202
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
203
204
        eviction_interval_secs: u64,
        max_tree_size: usize,
205
        max_payload_size: usize,
206
207
        dp_aware: bool,
        api_key: Option<String>,
208
        log_dir: Option<String>,
209
        log_level: Option<String>,
210
211
212
213
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
214
215
216
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
217
218
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
219
        request_timeout_secs: u64,
220
        request_id_headers: Option<Vec<String>>,
221
        pd_disaggregation: bool,
222
223
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
224
225
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
226
227
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
228
229
    ) -> PyResult<Self> {
        Ok(Router {
230
231
232
            host,
            port,
            worker_urls,
233
            policy,
234
            worker_startup_timeout_secs,
235
            worker_startup_check_interval,
236
            cache_threshold,
237
238
            balance_abs_threshold,
            balance_rel_threshold,
239
240
            eviction_interval_secs,
            max_tree_size,
241
            max_payload_size,
242
243
            dp_aware,
            api_key,
244
            log_dir,
245
            log_level,
246
247
248
249
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
250
251
252
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
253
254
            prometheus_port,
            prometheus_host,
255
            request_timeout_secs,
256
            request_id_headers,
257
            pd_disaggregation,
258
259
            prefill_urls,
            decode_urls,
260
261
            prefill_policy,
            decode_policy,
262
263
            max_concurrent_requests,
            cors_allowed_origins,
264
        })
265
266
267
    }

    fn start(&self) -> PyResult<()> {
268
269
270
271
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
272

273
274
275
276
277
278
279
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
280

281
282
283
284
285
286
287
288
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
289
290
291
292
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
293
294
295
296
297
            })
        } else {
            None
        };

298
299
300
301
302
303
304
305
306
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

307
308
309
310
311
312
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
313
314
315
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
316
                router_config,
317
                max_payload_size: self.max_payload_size,
318
                log_dir: self.log_dir.clone(),
319
                log_level: self.log_level.clone(),
320
                service_discovery_config,
321
                prometheus_config,
322
                request_timeout_secs: self.request_timeout_secs,
323
                request_id_headers: self.request_id_headers.clone(),
324
325
            })
            .await
326
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
327
        })
328
329
330
331
    }
}

#[pymodule]
332
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
333
    m.add_class::<PolicyType>()?;
334
335
    m.add_class::<Router>()?;
    Ok(())
336
}