lib.rs 18.6 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
8
#[cfg(feature = "grpc-client")]
pub mod grpc;
9
pub mod mcp;
10
pub mod metrics;
11
pub mod middleware;
12
pub mod policies;
13
pub mod protocols;
14
pub mod reasoning_parser;
15
pub mod routers;
16
pub mod server;
17
pub mod service_discovery;
18
pub mod tokenizer;
19
pub mod tool_parser;
20
pub mod tree;
21
use crate::metrics::PrometheusConfig;
22

23
#[pyclass(eq)]
24
#[derive(Clone, PartialEq, Debug)]
25
26
27
pub enum PolicyType {
    Random,
    RoundRobin,
28
    CacheAware,
29
    PowerOfTwo,
30
31
}

32
#[pyclass]
33
#[derive(Debug, Clone, PartialEq)]
34
35
36
37
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
38
    policy: PolicyType,
39
    worker_startup_timeout_secs: u64,
40
    worker_startup_check_interval: u64,
41
    cache_threshold: f32,
42
43
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
44
45
    eviction_interval_secs: u64,
    max_tree_size: usize,
46
    max_payload_size: usize,
47
48
    dp_aware: bool,
    api_key: Option<String>,
49
    log_dir: Option<String>,
50
    log_level: Option<String>,
51
52
53
54
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
55
56
57
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
58
59
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
60
    request_timeout_secs: u64,
61
    request_id_headers: Option<Vec<String>>,
62
    pd_disaggregation: bool,
63
64
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
65
66
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
67
68
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
69
70
71
72
73
74
75
76
77
78
79
80
81
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
82
83
84
85
86
87
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
88
89
    // IGW (Inference Gateway) configuration
    enable_igw: bool,
90
91
92
    queue_size: usize,
    queue_timeout_secs: u64,
    rate_limit_tokens_per_second: Option<usize>,
93
94
95
96
97
98
    // Connection mode (determined from worker URLs)
    connection_mode: config::ConnectionMode,
    // Model path for tokenizer
    model_path: Option<String>,
    // Explicit tokenizer path
    tokenizer_path: Option<String>,
99
100
}

101
impl Router {
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
    /// Determine connection mode from worker URLs
    fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
        // Check if any URL is a gRPC endpoint (starts with grpc:// or has port that commonly indicates gRPC)
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                return config::ConnectionMode::Grpc;
            }
            // Also check for common gRPC ports if the scheme isn't specified
            if let Ok(parsed_url) = url::Url::parse(url) {
                if let Some(port) = parsed_url.port() {
                    // Common gRPC ports
                    if port == 50051 || port == 9090 || ((50000..=50100).contains(&port)) {
                        return config::ConnectionMode::Grpc;
                    }
                }
            } else if url.contains(":50051") || url.contains(":9090") || url.contains(":5000") {
                // Fallback check for URLs that might not parse correctly
                return config::ConnectionMode::Grpc;
            }
        }
        // Default to HTTP
        config::ConnectionMode::Http
    }

126
127
128
129
130
131
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

150
        // Determine routing mode
151
152
153
154
155
156
        let mode = if self.enable_igw {
            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
157
158
159
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
160
161
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
162
163
164
165
166
167
168
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

169
170
        // Convert main policy
        let policy = convert_policy(&self.policy);
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
202
            connection_mode: self.connection_mode.clone(),
203
204
205
206
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
207
208
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
209
210
211
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
212
            log_level: self.log_level.clone(),
213
            request_id_headers: self.request_id_headers.clone(),
214
            max_concurrent_requests: self.max_concurrent_requests,
215
216
217
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
218
            cors_allowed_origins: self.cors_allowed_origins.clone(),
219
220
221
222
223
224
225
226
227
228
229
230
231
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
232
233
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
234
235
236
237
238
239
240
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
241
            enable_igw: self.enable_igw,
242
243
            model_path: self.model_path.clone(),
            tokenizer_path: self.tokenizer_path.clone(),
244
245
246
247
        })
    }
}

248
249
250
#[pymethods]
impl Router {
    #[new]
251
252
253
254
255
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
256
257
258
259
260
261
262
263
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
264
265
        dp_aware = false,
        api_key = None,
266
        log_dir = None,
267
        log_level = None,
268
269
270
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
271
        service_discovery_namespace = None,
272
273
274
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
275
        prometheus_port = None,
276
        prometheus_host = None,
277
        request_timeout_secs = 1800,  // Add configurable request timeout
278
        request_id_headers = None,  // Custom request ID headers
279
        pd_disaggregation = false,  // New flag for PD mode
280
        prefill_urls = None,
281
282
        decode_urls = None,
        prefill_policy = None,
283
        decode_policy = None,
284
        max_concurrent_requests = 256,
285
286
        cors_allowed_origins = vec![],
        // Retry defaults
287
288
289
290
291
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
292
293
        disable_retries = false,
        // Circuit breaker defaults
294
295
296
297
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
298
        disable_circuit_breaker = false,
299
300
301
302
303
304
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
305
306
        // IGW defaults
        enable_igw = false,
307
308
309
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
310
311
312
        // Tokenizer defaults
        model_path = None,
        tokenizer_path = None,
313
    ))]
314
    #[allow(clippy::too_many_arguments)]
315
316
317
318
319
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
320
        worker_startup_timeout_secs: u64,
321
        worker_startup_check_interval: u64,
322
        cache_threshold: f32,
323
324
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
325
326
        eviction_interval_secs: u64,
        max_tree_size: usize,
327
        max_payload_size: usize,
328
329
        dp_aware: bool,
        api_key: Option<String>,
330
        log_dir: Option<String>,
331
        log_level: Option<String>,
332
333
334
335
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
336
337
338
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
339
340
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
341
        request_timeout_secs: u64,
342
        request_id_headers: Option<Vec<String>>,
343
        pd_disaggregation: bool,
344
345
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
346
347
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
348
349
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
350
351
352
353
354
355
356
357
358
359
360
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
361
362
363
364
365
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
366
        enable_igw: bool,
367
368
369
        queue_size: usize,
        queue_timeout_secs: u64,
        rate_limit_tokens_per_second: Option<usize>,
370
371
        model_path: Option<String>,
        tokenizer_path: Option<String>,
372
    ) -> PyResult<Self> {
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
        // Determine connection mode from worker URLs
        let mut all_urls = worker_urls.clone();

        // Add prefill URLs if in PD mode
        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        // Add decode URLs if in PD mode
        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

390
        Ok(Router {
391
392
393
            host,
            port,
            worker_urls,
394
            policy,
395
            worker_startup_timeout_secs,
396
            worker_startup_check_interval,
397
            cache_threshold,
398
399
            balance_abs_threshold,
            balance_rel_threshold,
400
401
            eviction_interval_secs,
            max_tree_size,
402
            max_payload_size,
403
404
            dp_aware,
            api_key,
405
            log_dir,
406
            log_level,
407
408
409
410
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
411
412
413
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
414
415
            prometheus_port,
            prometheus_host,
416
            request_timeout_secs,
417
            request_id_headers,
418
            pd_disaggregation,
419
420
            prefill_urls,
            decode_urls,
421
422
            prefill_policy,
            decode_policy,
423
424
            max_concurrent_requests,
            cors_allowed_origins,
425
426
427
428
429
430
431
432
433
434
435
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
436
437
438
439
440
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
441
            enable_igw,
442
443
444
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
445
446
447
            connection_mode,
            model_path,
            tokenizer_path,
448
        })
449
450
451
    }

    fn start(&self) -> PyResult<()> {
452
453
454
455
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
456

457
458
459
460
461
462
463
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
464

465
466
467
468
469
470
471
472
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
473
474
475
476
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
477
478
479
480
481
            })
        } else {
            None
        };

482
483
484
485
486
487
488
489
490
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

491
492
493
494
495
496
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
497
498
499
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
500
                router_config,
501
                max_payload_size: self.max_payload_size,
502
                log_dir: self.log_dir.clone(),
503
                log_level: self.log_level.clone(),
504
                service_discovery_config,
505
                prometheus_config,
506
                request_timeout_secs: self.request_timeout_secs,
507
                request_id_headers: self.request_id_headers.clone(),
508
509
            })
            .await
510
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
511
        })
512
513
514
515
    }
}

#[pymodule]
516
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
517
    m.add_class::<PolicyType>()?;
518
519
    m.add_class::<Router>()?;
    Ok(())
520
}