lib.rs 13.9 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
pub mod metrics;
7
pub mod middleware;
8
pub mod openai_api_types;
9
10
pub mod policies;
pub mod routers;
11
pub mod server;
12
pub mod service_discovery;
13
pub mod tree;
14
use crate::metrics::PrometheusConfig;
15

16
#[pyclass(eq)]
17
#[derive(Clone, PartialEq, Debug)]
18
19
20
pub enum PolicyType {
    Random,
    RoundRobin,
21
    CacheAware,
22
    PowerOfTwo,
23
24
}

25
#[pyclass]
26
#[derive(Debug, Clone, PartialEq)]
27
28
29
30
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
31
    policy: PolicyType,
32
    worker_startup_timeout_secs: u64,
33
    worker_startup_check_interval: u64,
34
    cache_threshold: f32,
35
36
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
37
38
    eviction_interval_secs: u64,
    max_tree_size: usize,
39
    max_payload_size: usize,
40
41
    dp_aware: bool,
    api_key: Option<String>,
42
    log_dir: Option<String>,
43
    log_level: Option<String>,
44
45
46
47
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
48
49
50
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
51
52
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
53
    request_timeout_secs: u64,
54
    request_id_headers: Option<Vec<String>>,
55
    pd_disaggregation: bool,
56
57
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
58
59
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
60
61
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
62
63
64
65
66
67
68
69
70
71
72
73
74
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
75
76
}

77
78
79
80
81
82
83
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

102
103
104
105
106
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
107
108
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
109
110
111
112
113
114
115
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

116
117
        // Convert main policy
        let policy = convert_policy(&self.policy);
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
153
154
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
155
156
157
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
158
            log_level: self.log_level.clone(),
159
            request_id_headers: self.request_id_headers.clone(),
160
161
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: false,
            disable_circuit_breaker: false,
177
178
179
180
        })
    }
}

181
182
183
#[pymethods]
impl Router {
    #[new]
184
185
186
187
188
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
189
        worker_startup_timeout_secs = 300,
190
        worker_startup_check_interval = 10,
191
        cache_threshold = 0.50,
192
193
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
194
        eviction_interval_secs = 60,
195
        max_tree_size = 2usize.pow(24),
196
        max_payload_size = 256 * 1024 * 1024,  // 256MB default for large batches
197
198
        dp_aware = false,
        api_key = None,
199
        log_dir = None,
200
        log_level = None,
201
202
203
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
204
        service_discovery_namespace = None,
205
206
207
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
208
        prometheus_port = None,
209
210
        prometheus_host = None,
        request_timeout_secs = 600,  // Add configurable request timeout
211
        request_id_headers = None,  // Custom request ID headers
212
        pd_disaggregation = false,  // New flag for PD mode
213
        prefill_urls = None,
214
215
        decode_urls = None,
        prefill_policy = None,
216
217
        decode_policy = None,
        max_concurrent_requests = 64,
218
219
220
221
222
223
224
225
226
227
228
229
230
231
        cors_allowed_origins = vec![],
        // Retry defaults
        retry_max_retries = 3,
        retry_initial_backoff_ms = 100,
        retry_max_backoff_ms = 10_000,
        retry_backoff_multiplier = 2.0,
        retry_jitter_factor = 0.1,
        disable_retries = false,
        // Circuit breaker defaults
        cb_failure_threshold = 5,
        cb_success_threshold = 2,
        cb_timeout_duration_secs = 30,
        cb_window_duration_secs = 60,
        disable_circuit_breaker = false,
232
233
234
235
236
237
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
238
        worker_startup_timeout_secs: u64,
239
        worker_startup_check_interval: u64,
240
        cache_threshold: f32,
241
242
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
243
244
        eviction_interval_secs: u64,
        max_tree_size: usize,
245
        max_payload_size: usize,
246
247
        dp_aware: bool,
        api_key: Option<String>,
248
        log_dir: Option<String>,
249
        log_level: Option<String>,
250
251
252
253
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
254
255
256
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
257
258
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
259
        request_timeout_secs: u64,
260
        request_id_headers: Option<Vec<String>>,
261
        pd_disaggregation: bool,
262
263
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
264
265
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
266
267
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
268
269
270
271
272
273
274
275
276
277
278
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
279
280
    ) -> PyResult<Self> {
        Ok(Router {
281
282
283
            host,
            port,
            worker_urls,
284
            policy,
285
            worker_startup_timeout_secs,
286
            worker_startup_check_interval,
287
            cache_threshold,
288
289
            balance_abs_threshold,
            balance_rel_threshold,
290
291
            eviction_interval_secs,
            max_tree_size,
292
            max_payload_size,
293
294
            dp_aware,
            api_key,
295
            log_dir,
296
            log_level,
297
298
299
300
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
301
302
303
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
304
305
            prometheus_port,
            prometheus_host,
306
            request_timeout_secs,
307
            request_id_headers,
308
            pd_disaggregation,
309
310
            prefill_urls,
            decode_urls,
311
312
            prefill_policy,
            decode_policy,
313
314
            max_concurrent_requests,
            cors_allowed_origins,
315
316
317
318
319
320
321
322
323
324
325
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
326
        })
327
328
329
    }

    fn start(&self) -> PyResult<()> {
330
331
332
333
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
334

335
336
337
338
339
340
341
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
342

343
344
345
346
347
348
349
350
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
351
352
353
354
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
355
356
357
358
359
            })
        } else {
            None
        };

360
361
362
363
364
365
366
367
368
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

369
370
371
372
373
374
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
375
376
377
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
378
                router_config,
379
                max_payload_size: self.max_payload_size,
380
                log_dir: self.log_dir.clone(),
381
                log_level: self.log_level.clone(),
382
                service_discovery_config,
383
                prometheus_config,
384
                request_timeout_secs: self.request_timeout_secs,
385
                request_id_headers: self.request_id_headers.clone(),
386
387
            })
            .await
388
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
389
        })
390
391
392
393
    }
}

#[pymodule]
394
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
395
    m.add_class::<PolicyType>()?;
396
397
    m.add_class::<Router>()?;
    Ok(())
398
}