lib.rs 18 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
pub mod data_connector;
8
9
#[cfg(feature = "grpc-client")]
pub mod grpc;
10
pub mod mcp;
11
pub mod metrics;
12
pub mod middleware;
13
pub mod policies;
14
pub mod protocols;
15
pub mod reasoning_parser;
16
pub mod routers;
17
pub mod server;
18
pub mod service_discovery;
19
pub mod tokenizer;
20
pub mod tool_parser;
21
pub mod tree;
22
use crate::metrics::PrometheusConfig;
23

24
#[pyclass(eq)]
25
#[derive(Clone, PartialEq, Debug)]
26
27
28
pub enum PolicyType {
    Random,
    RoundRobin,
29
    CacheAware,
30
    PowerOfTwo,
31
32
}

33
#[pyclass]
34
#[derive(Debug, Clone, PartialEq)]
35
36
37
38
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
39
    policy: PolicyType,
40
    worker_startup_timeout_secs: u64,
41
    worker_startup_check_interval: u64,
42
    cache_threshold: f32,
43
44
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
45
46
    eviction_interval_secs: u64,
    max_tree_size: usize,
47
    max_payload_size: usize,
48
49
    dp_aware: bool,
    api_key: Option<String>,
50
    log_dir: Option<String>,
51
    log_level: Option<String>,
52
53
54
55
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
56
57
58
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
59
60
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
61
    request_timeout_secs: u64,
62
    request_id_headers: Option<Vec<String>>,
63
    pd_disaggregation: bool,
64
65
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
66
67
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
68
69
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
70
71
72
73
74
75
76
77
78
79
80
81
82
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
83
84
85
86
87
88
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
89
90
    // IGW (Inference Gateway) configuration
    enable_igw: bool,
91
92
93
    queue_size: usize,
    queue_timeout_secs: u64,
    rate_limit_tokens_per_second: Option<usize>,
94
95
96
97
98
99
    // Connection mode (determined from worker URLs)
    connection_mode: config::ConnectionMode,
    // Model path for tokenizer
    model_path: Option<String>,
    // Explicit tokenizer path
    tokenizer_path: Option<String>,
100
101
}

102
impl Router {
103
104
    /// Determine connection mode from worker URLs
    fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
105
        // Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
106
107
108
109
110
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                return config::ConnectionMode::Grpc;
            }
        }
111
        // Default to HTTP for all other cases (including http://, https://, or no scheme)
112
113
114
        config::ConnectionMode::Http
    }

115
116
117
118
119
120
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

139
        // Determine routing mode
140
141
142
143
144
145
        let mode = if self.enable_igw {
            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
146
147
148
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
149
150
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
151
152
153
154
155
156
157
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

158
159
        // Convert main policy
        let policy = convert_policy(&self.policy);
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
191
            connection_mode: self.connection_mode.clone(),
192
193
194
195
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
196
197
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
198
199
200
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
201
            log_level: self.log_level.clone(),
202
            request_id_headers: self.request_id_headers.clone(),
203
            max_concurrent_requests: self.max_concurrent_requests,
204
205
206
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
207
            cors_allowed_origins: self.cors_allowed_origins.clone(),
208
209
210
211
212
213
214
215
216
217
218
219
220
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
221
222
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
223
224
225
226
227
228
229
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
230
            enable_igw: self.enable_igw,
231
232
            model_path: self.model_path.clone(),
            tokenizer_path: self.tokenizer_path.clone(),
233
            history_backend: config::HistoryBackend::Memory,
234
235
236
237
        })
    }
}

238
239
240
#[pymethods]
impl Router {
    #[new]
241
242
243
244
245
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
246
247
248
249
250
251
252
253
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
254
255
        dp_aware = false,
        api_key = None,
256
        log_dir = None,
257
        log_level = None,
258
259
260
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
261
        service_discovery_namespace = None,
262
263
264
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
265
        prometheus_port = None,
266
        prometheus_host = None,
267
        request_timeout_secs = 1800,  // Add configurable request timeout
268
        request_id_headers = None,  // Custom request ID headers
269
        pd_disaggregation = false,  // New flag for PD mode
270
        prefill_urls = None,
271
272
        decode_urls = None,
        prefill_policy = None,
273
        decode_policy = None,
274
        max_concurrent_requests = 256,
275
276
        cors_allowed_origins = vec![],
        // Retry defaults
277
278
279
280
281
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
282
283
        disable_retries = false,
        // Circuit breaker defaults
284
285
286
287
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
288
        disable_circuit_breaker = false,
289
290
291
292
293
294
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
295
296
        // IGW defaults
        enable_igw = false,
297
298
299
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
300
301
302
        // Tokenizer defaults
        model_path = None,
        tokenizer_path = None,
303
    ))]
304
    #[allow(clippy::too_many_arguments)]
305
306
307
308
309
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
310
        worker_startup_timeout_secs: u64,
311
        worker_startup_check_interval: u64,
312
        cache_threshold: f32,
313
314
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
315
316
        eviction_interval_secs: u64,
        max_tree_size: usize,
317
        max_payload_size: usize,
318
319
        dp_aware: bool,
        api_key: Option<String>,
320
        log_dir: Option<String>,
321
        log_level: Option<String>,
322
323
324
325
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
326
327
328
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
329
330
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
331
        request_timeout_secs: u64,
332
        request_id_headers: Option<Vec<String>>,
333
        pd_disaggregation: bool,
334
335
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
336
337
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
338
339
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
340
341
342
343
344
345
346
347
348
349
350
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
351
352
353
354
355
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
356
        enable_igw: bool,
357
358
359
        queue_size: usize,
        queue_timeout_secs: u64,
        rate_limit_tokens_per_second: Option<usize>,
360
361
        model_path: Option<String>,
        tokenizer_path: Option<String>,
362
    ) -> PyResult<Self> {
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
        // Determine connection mode from worker URLs
        let mut all_urls = worker_urls.clone();

        // Add prefill URLs if in PD mode
        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        // Add decode URLs if in PD mode
        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

380
        Ok(Router {
381
382
383
            host,
            port,
            worker_urls,
384
            policy,
385
            worker_startup_timeout_secs,
386
            worker_startup_check_interval,
387
            cache_threshold,
388
389
            balance_abs_threshold,
            balance_rel_threshold,
390
391
            eviction_interval_secs,
            max_tree_size,
392
            max_payload_size,
393
394
            dp_aware,
            api_key,
395
            log_dir,
396
            log_level,
397
398
399
400
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
401
402
403
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
404
405
            prometheus_port,
            prometheus_host,
406
            request_timeout_secs,
407
            request_id_headers,
408
            pd_disaggregation,
409
410
            prefill_urls,
            decode_urls,
411
412
            prefill_policy,
            decode_policy,
413
414
            max_concurrent_requests,
            cors_allowed_origins,
415
416
417
418
419
420
421
422
423
424
425
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
426
427
428
429
430
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
431
            enable_igw,
432
433
434
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
435
436
437
            connection_mode,
            model_path,
            tokenizer_path,
438
        })
439
440
441
    }

    fn start(&self) -> PyResult<()> {
442
443
444
445
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
446

447
448
449
450
451
452
453
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
454

455
456
457
458
459
460
461
462
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
463
464
465
466
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
467
468
469
470
471
            })
        } else {
            None
        };

472
473
474
475
476
477
478
479
480
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

481
482
483
484
485
486
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
487
488
489
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
490
                router_config,
491
                max_payload_size: self.max_payload_size,
492
                log_dir: self.log_dir.clone(),
493
                log_level: self.log_level.clone(),
494
                service_discovery_config,
495
                prometheus_config,
496
                request_timeout_secs: self.request_timeout_secs,
497
                request_id_headers: self.request_id_headers.clone(),
498
499
            })
            .await
500
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
501
        })
502
503
504
505
    }
}

#[pymodule]
506
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
507
    m.add_class::<PolicyType>()?;
508
509
    m.add_class::<Router>()?;
    Ok(())
510
}