lib.rs 10.3 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
7
8
pub mod openai_api_types;
pub mod pd_router;
pub mod pd_types;
9
pub mod prometheus;
10
pub mod request_adapter;
11
pub mod router;
12
pub mod server;
13
pub mod service_discovery;
14
pub mod tree;
15
use crate::prometheus::PrometheusConfig;
16

17
#[pyclass(eq)]
18
#[derive(Clone, PartialEq, Debug)]
19
20
21
pub enum PolicyType {
    Random,
    RoundRobin,
22
    CacheAware,
23
    PowerOfTwo, // Moved from PD-specific, now shared
24
25
}

26
#[pyclass]
27
#[derive(Debug, Clone, PartialEq)]
28
29
30
31
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
32
    policy: PolicyType,
33
    worker_startup_timeout_secs: u64,
34
    worker_startup_check_interval: u64,
35
    cache_threshold: f32,
36
37
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
38
39
    eviction_interval_secs: u64,
    max_tree_size: usize,
40
    max_payload_size: usize,
41
    log_dir: Option<String>,
42
    log_level: Option<String>,
43
44
45
46
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
47
48
49
50
    // PD service discovery fields
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
51
52
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
53
54
    request_timeout_secs: u64,
    // PD mode flag
55
56
    pd_disaggregation: bool,
    // PD-specific fields (only used when pd_disaggregation is true)
57
58
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
59
60
}

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

        // Convert policy
        let policy = match self.policy {
            PolicyType::Random => ConfigPolicyConfig::Random,
            PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
            PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                cache_threshold: self.cache_threshold,
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
            },
            PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                load_check_interval_secs: 5, // Default value
            },
        };

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
133
            log_level: self.log_level.clone(),
134
135
136
137
        })
    }
}

138
139
140
#[pymethods]
impl Router {
    #[new]
141
142
143
144
145
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
146
        worker_startup_timeout_secs = 300,
147
        worker_startup_check_interval = 10,
148
        cache_threshold = 0.50,
149
150
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
151
        eviction_interval_secs = 60,
152
        max_tree_size = 2usize.pow(24),
153
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
154
        log_dir = None,
155
        log_level = None,
156
157
158
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
159
        service_discovery_namespace = None,
160
161
162
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
163
        prometheus_port = None,
164
165
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
166
        pd_disaggregation = false,  // New flag for PD mode
167
168
        prefill_urls = None,
        decode_urls = None
169
170
171
172
173
174
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
175
        worker_startup_timeout_secs: u64,
176
        worker_startup_check_interval: u64,
177
        cache_threshold: f32,
178
179
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
180
181
        eviction_interval_secs: u64,
        max_tree_size: usize,
182
        max_payload_size: usize,
183
        log_dir: Option<String>,
184
        log_level: Option<String>,
185
186
187
188
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
189
190
191
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
192
193
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
194
        request_timeout_secs: u64,
195
        pd_disaggregation: bool,
196
197
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
198
199
    ) -> PyResult<Self> {
        Ok(Router {
200
201
202
            host,
            port,
            worker_urls,
203
            policy,
204
            worker_startup_timeout_secs,
205
            worker_startup_check_interval,
206
            cache_threshold,
207
208
            balance_abs_threshold,
            balance_rel_threshold,
209
210
            eviction_interval_secs,
            max_tree_size,
211
            max_payload_size,
212
            log_dir,
213
            log_level,
214
215
216
217
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
218
219
220
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
221
222
            prometheus_port,
            prometheus_host,
223
            request_timeout_secs,
224
            pd_disaggregation,
225
226
            prefill_urls,
            decode_urls,
227
        })
228
229
230
    }

    fn start(&self) -> PyResult<()> {
231
232
233
234
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
235

236
237
238
239
240
241
242
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
243

244
245
246
247
        // Convert to internal policy config
        let policy_config = router_config
            .to_routing_policy_config()
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
248

249
250
251
252
253
254
255
256
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
257
258
259
260
261
                // PD mode configuration
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
262
263
264
265
266
            })
        } else {
            None
        };

267
268
269
270
271
272
273
274
275
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

276
277
278
279
280
281
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
282
283
284
285
286
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
                worker_urls: self.worker_urls.clone(),
                policy_config,
287
                max_payload_size: self.max_payload_size,
288
                log_dir: self.log_dir.clone(),
289
                log_level: self.log_level.clone(),
290
                service_discovery_config,
291
                prometheus_config,
292
                request_timeout_secs: self.request_timeout_secs,
293
294
            })
            .await
295
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
296
        })
297
298
299
300
    }
}

#[pymodule]
301
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
302
    m.add_class::<PolicyType>()?;
303
304
    m.add_class::<Router>()?;
    Ok(())
305
}