lib.rs 10.6 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
pub mod metrics;
7
pub mod openai_api_types;
8
9
pub mod policies;
pub mod routers;
10
pub mod server;
11
pub mod service_discovery;
12
pub mod tree;
13
use crate::metrics::PrometheusConfig;
14

15
#[pyclass(eq)]
16
#[derive(Clone, PartialEq, Debug)]
17
18
19
pub enum PolicyType {
    Random,
    RoundRobin,
20
    CacheAware,
21
    PowerOfTwo, // Moved from PD-specific, now shared
22
23
}

24
#[pyclass]
25
#[derive(Debug, Clone, PartialEq)]
26
27
28
29
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
30
    policy: PolicyType,
31
    worker_startup_timeout_secs: u64,
32
    worker_startup_check_interval: u64,
33
    cache_threshold: f32,
34
35
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
36
37
    eviction_interval_secs: u64,
    max_tree_size: usize,
38
    max_payload_size: usize,
39
    log_dir: Option<String>,
40
    log_level: Option<String>,
41
42
43
44
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
45
46
47
48
    // PD service discovery fields
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
49
50
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
51
52
    request_timeout_secs: u64,
    // PD mode flag
53
54
    pd_disaggregation: bool,
    // PD-specific fields (only used when pd_disaggregation is true)
55
56
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
57
58
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
59
60
}

61
62
63
64
65
66
67
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

86
87
88
89
90
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
91
92
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
93
94
95
96
97
98
99
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

100
101
        // Convert main policy
        let policy = convert_policy(&self.policy);
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
140
            log_level: self.log_level.clone(),
141
142
143
144
        })
    }
}

145
146
147
#[pymethods]
impl Router {
    #[new]
148
149
150
151
152
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
153
        worker_startup_timeout_secs = 300,
154
        worker_startup_check_interval = 10,
155
        cache_threshold = 0.50,
156
157
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
158
        eviction_interval_secs = 60,
159
        max_tree_size = 2usize.pow(24),
160
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
161
        log_dir = None,
162
        log_level = None,
163
164
165
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
166
        service_discovery_namespace = None,
167
168
169
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
170
        prometheus_port = None,
171
172
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
173
        pd_disaggregation = false,  // New flag for PD mode
174
        prefill_urls = None,
175
176
177
        decode_urls = None,
        prefill_policy = None,
        decode_policy = None
178
179
180
181
182
183
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
184
        worker_startup_timeout_secs: u64,
185
        worker_startup_check_interval: u64,
186
        cache_threshold: f32,
187
188
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
189
190
        eviction_interval_secs: u64,
        max_tree_size: usize,
191
        max_payload_size: usize,
192
        log_dir: Option<String>,
193
        log_level: Option<String>,
194
195
196
197
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
198
199
200
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
201
202
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
203
        request_timeout_secs: u64,
204
        pd_disaggregation: bool,
205
206
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
207
208
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
209
210
    ) -> PyResult<Self> {
        Ok(Router {
211
212
213
            host,
            port,
            worker_urls,
214
            policy,
215
            worker_startup_timeout_secs,
216
            worker_startup_check_interval,
217
            cache_threshold,
218
219
            balance_abs_threshold,
            balance_rel_threshold,
220
221
            eviction_interval_secs,
            max_tree_size,
222
            max_payload_size,
223
            log_dir,
224
            log_level,
225
226
227
228
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
229
230
231
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
232
233
            prometheus_port,
            prometheus_host,
234
            request_timeout_secs,
235
            pd_disaggregation,
236
237
            prefill_urls,
            decode_urls,
238
239
            prefill_policy,
            decode_policy,
240
        })
241
242
243
    }

    fn start(&self) -> PyResult<()> {
244
245
246
247
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
248

249
250
251
252
253
254
255
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
256

257
258
259
260
261
262
263
264
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
265
266
267
268
269
                // PD mode configuration
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
270
271
272
273
274
            })
        } else {
            None
        };

275
276
277
278
279
280
281
282
283
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

284
285
286
287
288
289
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
290
291
292
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
293
                router_config,
294
                max_payload_size: self.max_payload_size,
295
                log_dir: self.log_dir.clone(),
296
                log_level: self.log_level.clone(),
297
                service_discovery_config,
298
                prometheus_config,
299
                request_timeout_secs: self.request_timeout_secs,
300
301
            })
            .await
302
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
303
        })
304
305
306
307
    }
}

#[pymodule]
308
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
309
    m.add_class::<PolicyType>()?;
310
311
    m.add_class::<Router>()?;
    Ok(())
312
}