lib.rs 15.6 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5
pub mod core;
6
7
#[cfg(feature = "grpc-client")]
pub mod grpc;
8
pub mod metrics;
9
pub mod middleware;
10
pub mod policies;
11
pub mod protocols;
12
pub mod reasoning_parser;
13
pub mod routers;
14
pub mod server;
15
pub mod service_discovery;
16
pub mod tokenizer;
17
pub mod tree;
18
use crate::metrics::PrometheusConfig;
19

20
#[pyclass(eq)]
21
#[derive(Clone, PartialEq, Debug)]
22
23
24
pub enum PolicyType {
    Random,
    RoundRobin,
25
    CacheAware,
26
    PowerOfTwo,
27
28
}

29
#[pyclass]
30
#[derive(Debug, Clone, PartialEq)]
31
32
33
34
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
35
    policy: PolicyType,
36
    worker_startup_timeout_secs: u64,
37
    worker_startup_check_interval: u64,
38
    cache_threshold: f32,
39
40
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
41
42
    eviction_interval_secs: u64,
    max_tree_size: usize,
43
    max_payload_size: usize,
44
45
    dp_aware: bool,
    api_key: Option<String>,
46
    log_dir: Option<String>,
47
    log_level: Option<String>,
48
49
50
51
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
52
53
54
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
55
56
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
57
    request_timeout_secs: u64,
58
    request_id_headers: Option<Vec<String>>,
59
    pd_disaggregation: bool,
60
61
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
62
63
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
64
65
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
66
67
68
69
70
71
72
73
74
75
76
77
78
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
79
80
81
82
83
84
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
85
86
    // IGW (Inference Gateway) configuration
    enable_igw: bool,
87
88
}

89
90
91
92
93
94
95
impl Router {
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

114
        // Determine routing mode
115
116
117
118
119
120
        let mode = if self.enable_igw {
            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
121
122
123
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
124
125
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
126
127
128
129
130
131
132
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

133
134
        // Convert main policy
        let policy = convert_policy(&self.policy);
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
170
171
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
172
173
174
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
175
            log_level: self.log_level.clone(),
176
            request_id_headers: self.request_id_headers.clone(),
177
178
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: false,
            disable_circuit_breaker: false,
194
195
196
197
198
199
200
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
201
            enable_igw: self.enable_igw,
202
203
204
205
        })
    }
}

206
207
208
#[pymethods]
impl Router {
    #[new]
209
210
211
212
213
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
214
215
216
217
218
219
220
221
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
222
223
        dp_aware = false,
        api_key = None,
224
        log_dir = None,
225
        log_level = None,
226
227
228
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
229
        service_discovery_namespace = None,
230
231
232
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
233
        prometheus_port = None,
234
        prometheus_host = None,
235
        request_timeout_secs = 1800,  // Add configurable request timeout
236
        request_id_headers = None,  // Custom request ID headers
237
        pd_disaggregation = false,  // New flag for PD mode
238
        prefill_urls = None,
239
240
        decode_urls = None,
        prefill_policy = None,
241
        decode_policy = None,
242
        max_concurrent_requests = 256,
243
244
        cors_allowed_origins = vec![],
        // Retry defaults
245
246
247
248
249
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
250
251
        disable_retries = false,
        // Circuit breaker defaults
252
253
254
255
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
256
        disable_circuit_breaker = false,
257
258
259
260
261
262
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
263
264
        // IGW defaults
        enable_igw = false,
265
    ))]
266
    #[allow(clippy::too_many_arguments)]
267
268
269
270
271
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
272
        worker_startup_timeout_secs: u64,
273
        worker_startup_check_interval: u64,
274
        cache_threshold: f32,
275
276
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
277
278
        eviction_interval_secs: u64,
        max_tree_size: usize,
279
        max_payload_size: usize,
280
281
        dp_aware: bool,
        api_key: Option<String>,
282
        log_dir: Option<String>,
283
        log_level: Option<String>,
284
285
286
287
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
288
289
290
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
291
292
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
293
        request_timeout_secs: u64,
294
        request_id_headers: Option<Vec<String>>,
295
        pd_disaggregation: bool,
296
297
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
298
299
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
300
301
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
302
303
304
305
306
307
308
309
310
311
312
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
313
314
315
316
317
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
318
        enable_igw: bool,
319
320
    ) -> PyResult<Self> {
        Ok(Router {
321
322
323
            host,
            port,
            worker_urls,
324
            policy,
325
            worker_startup_timeout_secs,
326
            worker_startup_check_interval,
327
            cache_threshold,
328
329
            balance_abs_threshold,
            balance_rel_threshold,
330
331
            eviction_interval_secs,
            max_tree_size,
332
            max_payload_size,
333
334
            dp_aware,
            api_key,
335
            log_dir,
336
            log_level,
337
338
339
340
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
341
342
343
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
344
345
            prometheus_port,
            prometheus_host,
346
            request_timeout_secs,
347
            request_id_headers,
348
            pd_disaggregation,
349
350
            prefill_urls,
            decode_urls,
351
352
            prefill_policy,
            decode_policy,
353
354
            max_concurrent_requests,
            cors_allowed_origins,
355
356
357
358
359
360
361
362
363
364
365
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
366
367
368
369
370
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
371
            enable_igw,
372
        })
373
374
375
    }

    fn start(&self) -> PyResult<()> {
376
377
378
379
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
380

381
382
383
384
385
386
387
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
388

389
390
391
392
393
394
395
396
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
397
398
399
400
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
401
402
403
404
405
            })
        } else {
            None
        };

406
407
408
409
410
411
412
413
414
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

415
416
417
418
419
420
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
421
422
423
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
424
                router_config,
425
                max_payload_size: self.max_payload_size,
426
                log_dir: self.log_dir.clone(),
427
                log_level: self.log_level.clone(),
428
                service_discovery_config,
429
                prometheus_config,
430
                request_timeout_secs: self.request_timeout_secs,
431
                request_id_headers: self.request_id_headers.clone(),
432
433
            })
            .await
434
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
435
        })
436
437
438
439
    }
}

#[pymodule]
440
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
441
    m.add_class::<PolicyType>()?;
442
443
    m.add_class::<Router>()?;
    Ok(())
444
}