lib.rs 11.2 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
pub mod metrics;
7
pub mod middleware;
8
pub mod openai_api_types;
9
10
pub mod policies;
pub mod routers;
11
pub mod server;
12
pub mod service_discovery;
13
pub mod tree;
14
use crate::metrics::PrometheusConfig;
15

16
#[pyclass(eq)]
17
#[derive(Clone, PartialEq, Debug)]
18
19
20
pub enum PolicyType {
    Random,
    RoundRobin,
21
    CacheAware,
22
    PowerOfTwo, // Moved from PD-specific, now shared
23
24
}

25
#[pyclass]
26
#[derive(Debug, Clone, PartialEq)]
27
28
29
30
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
31
    policy: PolicyType,
32
    worker_startup_timeout_secs: u64,
33
    worker_startup_check_interval: u64,
34
    cache_threshold: f32,
35
36
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
37
38
    eviction_interval_secs: u64,
    max_tree_size: usize,
39
    max_payload_size: usize,
40
41
    dp_aware: bool,
    api_key: Option<String>,
42
    log_dir: Option<String>,
43
    log_level: Option<String>,
44
45
46
47
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
48
49
50
51
    // PD service discovery fields
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
52
53
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
54
    request_timeout_secs: u64,
55
    request_id_headers: Option<Vec<String>>,
56
    // PD mode flag
57
58
    pd_disaggregation: bool,
    // PD-specific fields (only used when pd_disaggregation is true)
59
60
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
61
62
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
63
64
}

65
66
67
68
69
70
71
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

90
91
92
93
94
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
95
96
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
97
98
99
100
101
102
103
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

104
105
        // Convert main policy
        let policy = convert_policy(&self.policy);
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
141
142
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
143
144
145
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
146
            log_level: self.log_level.clone(),
147
            request_id_headers: self.request_id_headers.clone(),
148
149
150
151
        })
    }
}

152
153
154
#[pymethods]
impl Router {
    #[new]
155
156
157
158
159
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
160
        worker_startup_timeout_secs = 300,
161
        worker_startup_check_interval = 10,
162
        cache_threshold = 0.50,
163
164
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
165
        eviction_interval_secs = 60,
166
        max_tree_size = 2usize.pow(24),
167
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
168
169
        dp_aware = false,
        api_key = None,
170
        log_dir = None,
171
        log_level = None,
172
173
174
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
175
        service_discovery_namespace = None,
176
177
178
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
179
        prometheus_port = None,
180
181
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
182
        request_id_headers = None,  // Custom request ID headers
183
        pd_disaggregation = false,  // New flag for PD mode
184
        prefill_urls = None,
185
186
187
        decode_urls = None,
        prefill_policy = None,
        decode_policy = None
188
189
190
191
192
193
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
194
        worker_startup_timeout_secs: u64,
195
        worker_startup_check_interval: u64,
196
        cache_threshold: f32,
197
198
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
199
200
        eviction_interval_secs: u64,
        max_tree_size: usize,
201
        max_payload_size: usize,
202
203
        dp_aware: bool,
        api_key: Option<String>,
204
        log_dir: Option<String>,
205
        log_level: Option<String>,
206
207
208
209
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
210
211
212
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
213
214
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
215
        request_timeout_secs: u64,
216
        request_id_headers: Option<Vec<String>>,
217
        pd_disaggregation: bool,
218
219
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
220
221
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
222
223
    ) -> PyResult<Self> {
        Ok(Router {
224
225
226
            host,
            port,
            worker_urls,
227
            policy,
228
            worker_startup_timeout_secs,
229
            worker_startup_check_interval,
230
            cache_threshold,
231
232
            balance_abs_threshold,
            balance_rel_threshold,
233
234
            eviction_interval_secs,
            max_tree_size,
235
            max_payload_size,
236
237
            dp_aware,
            api_key,
238
            log_dir,
239
            log_level,
240
241
242
243
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
244
245
246
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
247
248
            prometheus_port,
            prometheus_host,
249
            request_timeout_secs,
250
            request_id_headers,
251
            pd_disaggregation,
252
253
            prefill_urls,
            decode_urls,
254
255
            prefill_policy,
            decode_policy,
256
        })
257
258
259
    }

    fn start(&self) -> PyResult<()> {
260
261
262
263
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
264

265
266
267
268
269
270
271
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
272

273
274
275
276
277
278
279
280
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
281
282
283
284
285
                // PD mode configuration
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
286
287
288
289
290
            })
        } else {
            None
        };

291
292
293
294
295
296
297
298
299
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

300
301
302
303
304
305
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
306
307
308
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
309
                router_config,
310
                max_payload_size: self.max_payload_size,
311
                log_dir: self.log_dir.clone(),
312
                log_level: self.log_level.clone(),
313
                service_discovery_config,
314
                prometheus_config,
315
                request_timeout_secs: self.request_timeout_secs,
316
                request_id_headers: self.request_id_headers.clone(),
317
318
            })
            .await
319
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
320
        })
321
322
323
324
    }
}

#[pymodule]
325
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
326
    m.add_class::<PolicyType>()?;
327
328
    m.add_class::<Router>()?;
    Ok(())
329
}