//! Python bindings for the router: exposes `PolicyType` and `Router` through the
//! `sglang_router_rs` extension module.
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
pub mod openai_api_types;
7
pub mod policies;
8
pub mod prometheus;
9
pub mod routers;
10
pub mod server;
11
pub mod service_discovery;
12
pub mod tree;
13
use crate::prometheus::PrometheusConfig;
14

15
#[pyclass(eq)]
16
#[derive(Clone, PartialEq, Debug)]
17
18
19
pub enum PolicyType {
    Random,
    RoundRobin,
20
    CacheAware,
21
    PowerOfTwo, // Moved from PD-specific, now shared
22
23
}

24
#[pyclass]
25
#[derive(Debug, Clone, PartialEq)]
26
27
28
29
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
30
    policy: PolicyType,
31
    worker_startup_timeout_secs: u64,
32
    worker_startup_check_interval: u64,
33
    cache_threshold: f32,
34
35
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
36
37
    eviction_interval_secs: u64,
    max_tree_size: usize,
38
    max_payload_size: usize,
39
    log_dir: Option<String>,
40
    log_level: Option<String>,
41
42
43
44
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
45
46
47
48
    // PD service discovery fields
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
49
50
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
51
52
    request_timeout_secs: u64,
    // PD mode flag
53
54
    pd_disaggregation: bool,
    // PD-specific fields (only used when pd_disaggregation is true)
55
56
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
57
58
}

// Conversion from the Python-facing `Router` into the internal `RouterConfig`.
impl Router {
    /// Builds a `config::RouterConfig` from the fields supplied through Python.
    ///
    /// Mapping rules:
    /// * Routing mode is `PrefillDecode` when `pd_disaggregation` is set (missing URL
    ///   lists become empty vectors), otherwise `Regular` with `worker_urls`.
    /// * `CacheAware` carries the cache/balance/eviction settings; `PowerOfTwo` uses a
    ///   fixed load-check interval of 5 seconds.
    /// * Discovery is `Some` only when `service_discovery` is enabled, with a fixed
    ///   60-second check interval.
    /// * Metrics are `Some` only when **both** `prometheus_port` and `prometheus_host`
    ///   are provided.
    ///
    /// The conversion itself has no failure path and always returns `Ok`; the result
    /// type is kept so callers can treat it uniformly with validation errors.
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

        // Determine routing mode.
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

        // Convert the Python-visible policy enum into the internal policy config.
        let policy = match self.policy {
            PolicyType::Random => ConfigPolicyConfig::Random,
            PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
            PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                cache_threshold: self.cache_threshold,
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
            },
            PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                // Not exposed to Python; fixed default.
                load_check_interval_secs: 5,
            },
        };

        // Service discovery configuration (absent unless explicitly enabled).
        let discovery = self.service_discovery.then(|| DiscoveryConfig {
            enabled: true,
            namespace: self.service_discovery_namespace.clone(),
            port: self.service_discovery_port,
            check_interval_secs: 60,
            selector: self.selector.clone(),
            prefill_selector: self.prefill_selector.clone(),
            decode_selector: self.decode_selector.clone(),
            bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
        });

        // Metrics configuration: requires both the port and the host to be set.
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
            log_level: self.log_level.clone(),
        })
    }
}

136
137
138
#[pymethods]
impl Router {
    #[new]
139
140
141
142
143
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
144
        worker_startup_timeout_secs = 300,
145
        worker_startup_check_interval = 10,
146
        cache_threshold = 0.50,
147
148
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
149
        eviction_interval_secs = 60,
150
        max_tree_size = 2usize.pow(24),
151
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
152
        log_dir = None,
153
        log_level = None,
154
155
156
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
157
        service_discovery_namespace = None,
158
159
160
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
161
        prometheus_port = None,
162
163
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
164
        pd_disaggregation = false,  // New flag for PD mode
165
166
        prefill_urls = None,
        decode_urls = None
167
168
169
170
171
172
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
173
        worker_startup_timeout_secs: u64,
174
        worker_startup_check_interval: u64,
175
        cache_threshold: f32,
176
177
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
178
179
        eviction_interval_secs: u64,
        max_tree_size: usize,
180
        max_payload_size: usize,
181
        log_dir: Option<String>,
182
        log_level: Option<String>,
183
184
185
186
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
187
188
189
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
190
191
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
192
        request_timeout_secs: u64,
193
        pd_disaggregation: bool,
194
195
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
196
197
    ) -> PyResult<Self> {
        Ok(Router {
198
199
200
            host,
            port,
            worker_urls,
201
            policy,
202
            worker_startup_timeout_secs,
203
            worker_startup_check_interval,
204
            cache_threshold,
205
206
            balance_abs_threshold,
            balance_rel_threshold,
207
208
            eviction_interval_secs,
            max_tree_size,
209
            max_payload_size,
210
            log_dir,
211
            log_level,
212
213
214
215
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
216
217
218
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
219
220
            prometheus_port,
            prometheus_host,
221
            request_timeout_secs,
222
            pd_disaggregation,
223
224
            prefill_urls,
            decode_urls,
225
        })
226
227
228
    }

    fn start(&self) -> PyResult<()> {
229
230
231
232
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
233

234
235
236
237
238
239
240
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
241

242
243
244
245
246
247
248
249
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
250
251
252
253
254
                // PD mode configuration
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
255
256
257
258
259
            })
        } else {
            None
        };

260
261
262
263
264
265
266
267
268
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

269
270
271
272
273
274
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
275
276
277
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
278
                router_config,
279
                max_payload_size: self.max_payload_size,
280
                log_dir: self.log_dir.clone(),
281
                log_level: self.log_level.clone(),
282
                service_discovery_config,
283
                prometheus_config,
284
                request_timeout_secs: self.request_timeout_secs,
285
286
            })
            .await
287
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
288
        })
289
290
291
292
    }
}

#[pymodule]
293
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
294
    m.add_class::<PolicyType>()?;
295
296
    m.add_class::<Router>()?;
    Ok(())
297
}