lib.rs 10.3 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
6
7
pub mod openai_api_types;
pub mod pd_router;
pub mod pd_types;
8
pub mod prometheus;
9
pub mod request_adapter;
10
pub mod router;
11
pub mod server;
12
pub mod service_discovery;
13
pub mod tree;
14
use crate::prometheus::PrometheusConfig;
15

16
#[pyclass(eq)]
17
#[derive(Clone, PartialEq, Debug)]
18
19
20
pub enum PolicyType {
    Random,
    RoundRobin,
21
    CacheAware,
22
    PowerOfTwo, // Moved from PD-specific, now shared
23
24
}

25
#[pyclass]
26
#[derive(Debug, Clone, PartialEq)]
27
28
29
30
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
31
    policy: PolicyType,
32
    worker_startup_timeout_secs: u64,
33
    worker_startup_check_interval: u64,
34
    cache_threshold: f32,
35
36
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
37
38
    eviction_interval_secs: u64,
    max_tree_size: usize,
39
    max_payload_size: usize,
40
    log_dir: Option<String>,
41
    log_level: Option<String>,
42
43
44
45
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
46
47
48
49
    // PD service discovery fields
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
50
51
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
52
53
    request_timeout_secs: u64,
    // PD mode flag
54
55
    pd_disaggregation: bool,
    // PD-specific fields (only used when pd_disaggregation is true)
56
57
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
58
59
}

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
impl Router {
    /// Converts the Python-facing `Router` settings into a [`config::RouterConfig`].
    ///
    /// `pd_disaggregation` picks the routing mode:
    /// - When set, the prefill/decode URL lists are used, and a missing list becomes empty.
    /// - Otherwise, `worker_urls` is used.
    ///
    /// Policy tuning fields are copied only into the policy variant that consumes them.
    ///
    /// # Errors
    ///
    /// The conversion itself currently always succeeds. Semantic checks happen afterwards,
    /// when `start` calls `validate()` on the returned config.
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

        let policy = match self.policy {
            PolicyType::Random => ConfigPolicyConfig::Random,
            PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
            PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                cache_threshold: self.cache_threshold,
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
            },
            PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                // Not exposed to Python; fixed default.
                load_check_interval_secs: 5,
            },
        };

        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // `start` serves Prometheus on 127.0.0.1:29000 unless overridden. If the caller set
        // either half of the endpoint, fill the other half with that same default. Previously,
        // a lone port or host caused the metrics section to be dropped entirely. The returned
        // config now describes the endpoint `start` will actually use.
        let metrics = if self.prometheus_port.is_some() || self.prometheus_host.is_some() {
            Some(MetricsConfig {
                port: self.prometheus_port.unwrap_or(29000),
                host: self
                    .prometheus_host
                    .clone()
                    .unwrap_or_else(|| "127.0.0.1".to_string()),
            })
        } else {
            None
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
            log_level: self.log_level.clone(),
        })
    }
}

137
138
139
#[pymethods]
impl Router {
    #[new]
140
141
142
143
144
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
145
        worker_startup_timeout_secs = 300,
146
        worker_startup_check_interval = 10,
147
        cache_threshold = 0.50,
148
149
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
150
        eviction_interval_secs = 60,
151
        max_tree_size = 2usize.pow(24),
152
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
153
        log_dir = None,
154
        log_level = None,
155
156
157
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
158
        service_discovery_namespace = None,
159
160
161
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
162
        prometheus_port = None,
163
164
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
165
        pd_disaggregation = false,  // New flag for PD mode
166
167
        prefill_urls = None,
        decode_urls = None
168
169
170
171
172
173
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
174
        worker_startup_timeout_secs: u64,
175
        worker_startup_check_interval: u64,
176
        cache_threshold: f32,
177
178
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
179
180
        eviction_interval_secs: u64,
        max_tree_size: usize,
181
        max_payload_size: usize,
182
        log_dir: Option<String>,
183
        log_level: Option<String>,
184
185
186
187
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
188
189
190
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
191
192
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
193
        request_timeout_secs: u64,
194
        pd_disaggregation: bool,
195
196
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
197
198
    ) -> PyResult<Self> {
        Ok(Router {
199
200
201
            host,
            port,
            worker_urls,
202
            policy,
203
            worker_startup_timeout_secs,
204
            worker_startup_check_interval,
205
            cache_threshold,
206
207
            balance_abs_threshold,
            balance_rel_threshold,
208
209
            eviction_interval_secs,
            max_tree_size,
210
            max_payload_size,
211
            log_dir,
212
            log_level,
213
214
215
216
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
217
218
219
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
220
221
            prometheus_port,
            prometheus_host,
222
            request_timeout_secs,
223
            pd_disaggregation,
224
225
            prefill_urls,
            decode_urls,
226
        })
227
228
229
    }

    fn start(&self) -> PyResult<()> {
230
231
232
233
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
234

235
236
237
238
239
240
241
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
242

243
244
245
246
        // Convert to internal policy config
        let policy_config = router_config
            .to_routing_policy_config()
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
247

248
249
250
251
252
253
254
255
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
256
257
258
259
260
                // PD mode configuration
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
261
262
263
264
265
            })
        } else {
            None
        };

266
267
268
269
270
271
272
273
274
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

275
276
277
278
279
280
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
281
282
283
284
285
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
                worker_urls: self.worker_urls.clone(),
                policy_config,
286
                max_payload_size: self.max_payload_size,
287
                log_dir: self.log_dir.clone(),
288
                log_level: self.log_level.clone(),
289
                service_discovery_config,
290
                prometheus_config,
291
                request_timeout_secs: self.request_timeout_secs,
292
293
            })
            .await
294
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
295
        })
296
297
298
299
    }
}

#[pymodule]
300
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
301
    m.add_class::<PolicyType>()?;
302
303
    m.add_class::<Router>()?;
    Ok(())
304
}