lib.rs 17.9 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
8
#[cfg(feature = "grpc-client")]
pub mod grpc;
9
pub mod mcp;
10
pub mod metrics;
11
pub mod middleware;
12
pub mod policies;
13
pub mod protocols;
14
pub mod reasoning_parser;
15
pub mod routers;
16
pub mod server;
17
pub mod service_discovery;
18
pub mod tokenizer;
19
pub mod tool_parser;
20
pub mod tree;
21
use crate::metrics::PrometheusConfig;
22

23
#[pyclass(eq)]
24
#[derive(Clone, PartialEq, Debug)]
25
26
27
pub enum PolicyType {
    Random,
    RoundRobin,
28
    CacheAware,
29
    PowerOfTwo,
30
31
}

32
#[pyclass]
33
#[derive(Debug, Clone, PartialEq)]
34
35
36
37
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
38
    policy: PolicyType,
39
    worker_startup_timeout_secs: u64,
40
    worker_startup_check_interval: u64,
41
    cache_threshold: f32,
42
43
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
44
45
    eviction_interval_secs: u64,
    max_tree_size: usize,
46
    max_payload_size: usize,
47
48
    dp_aware: bool,
    api_key: Option<String>,
49
    log_dir: Option<String>,
50
    log_level: Option<String>,
51
52
53
54
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
55
56
57
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
58
59
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
60
    request_timeout_secs: u64,
61
    request_id_headers: Option<Vec<String>>,
62
    pd_disaggregation: bool,
63
64
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
65
66
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
67
68
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
69
70
71
72
73
74
75
76
77
78
79
80
81
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
82
83
84
85
86
87
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
88
89
    // IGW (Inference Gateway) configuration
    enable_igw: bool,
90
91
92
    queue_size: usize,
    queue_timeout_secs: u64,
    rate_limit_tokens_per_second: Option<usize>,
93
94
95
96
97
98
    // Connection mode (determined from worker URLs)
    connection_mode: config::ConnectionMode,
    // Model path for tokenizer
    model_path: Option<String>,
    // Explicit tokenizer path
    tokenizer_path: Option<String>,
99
100
}

101
impl Router {
102
103
    /// Determine connection mode from worker URLs
    fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
104
        // Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
105
106
107
108
109
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                return config::ConnectionMode::Grpc;
            }
        }
110
        // Default to HTTP for all other cases (including http://, https://, or no scheme)
111
112
113
        config::ConnectionMode::Http
    }

114
115
116
117
118
119
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

138
        // Determine routing mode
139
140
141
142
143
144
        let mode = if self.enable_igw {
            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
145
146
147
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
148
149
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
150
151
152
153
154
155
156
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

157
158
        // Convert main policy
        let policy = convert_policy(&self.policy);
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
190
            connection_mode: self.connection_mode.clone(),
191
192
193
194
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
195
196
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
197
198
199
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
200
            log_level: self.log_level.clone(),
201
            request_id_headers: self.request_id_headers.clone(),
202
            max_concurrent_requests: self.max_concurrent_requests,
203
204
205
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
206
            cors_allowed_origins: self.cors_allowed_origins.clone(),
207
208
209
210
211
212
213
214
215
216
217
218
219
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
220
221
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
222
223
224
225
226
227
228
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
229
            enable_igw: self.enable_igw,
230
231
            model_path: self.model_path.clone(),
            tokenizer_path: self.tokenizer_path.clone(),
232
233
234
235
        })
    }
}

236
237
238
#[pymethods]
impl Router {
    #[new]
239
240
241
242
243
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
244
245
246
247
248
249
250
251
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
252
253
        dp_aware = false,
        api_key = None,
254
        log_dir = None,
255
        log_level = None,
256
257
258
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
259
        service_discovery_namespace = None,
260
261
262
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
263
        prometheus_port = None,
264
        prometheus_host = None,
265
        request_timeout_secs = 1800,  // Add configurable request timeout
266
        request_id_headers = None,  // Custom request ID headers
267
        pd_disaggregation = false,  // New flag for PD mode
268
        prefill_urls = None,
269
270
        decode_urls = None,
        prefill_policy = None,
271
        decode_policy = None,
272
        max_concurrent_requests = 256,
273
274
        cors_allowed_origins = vec![],
        // Retry defaults
275
276
277
278
279
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
280
281
        disable_retries = false,
        // Circuit breaker defaults
282
283
284
285
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
286
        disable_circuit_breaker = false,
287
288
289
290
291
292
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
293
294
        // IGW defaults
        enable_igw = false,
295
296
297
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
298
299
300
        // Tokenizer defaults
        model_path = None,
        tokenizer_path = None,
301
    ))]
302
    #[allow(clippy::too_many_arguments)]
303
304
305
306
307
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
308
        worker_startup_timeout_secs: u64,
309
        worker_startup_check_interval: u64,
310
        cache_threshold: f32,
311
312
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
313
314
        eviction_interval_secs: u64,
        max_tree_size: usize,
315
        max_payload_size: usize,
316
317
        dp_aware: bool,
        api_key: Option<String>,
318
        log_dir: Option<String>,
319
        log_level: Option<String>,
320
321
322
323
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
324
325
326
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
327
328
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
329
        request_timeout_secs: u64,
330
        request_id_headers: Option<Vec<String>>,
331
        pd_disaggregation: bool,
332
333
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
334
335
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
336
337
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
338
339
340
341
342
343
344
345
346
347
348
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
349
350
351
352
353
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
354
        enable_igw: bool,
355
356
357
        queue_size: usize,
        queue_timeout_secs: u64,
        rate_limit_tokens_per_second: Option<usize>,
358
359
        model_path: Option<String>,
        tokenizer_path: Option<String>,
360
    ) -> PyResult<Self> {
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
        // Determine connection mode from worker URLs
        let mut all_urls = worker_urls.clone();

        // Add prefill URLs if in PD mode
        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        // Add decode URLs if in PD mode
        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

378
        Ok(Router {
379
380
381
            host,
            port,
            worker_urls,
382
            policy,
383
            worker_startup_timeout_secs,
384
            worker_startup_check_interval,
385
            cache_threshold,
386
387
            balance_abs_threshold,
            balance_rel_threshold,
388
389
            eviction_interval_secs,
            max_tree_size,
390
            max_payload_size,
391
392
            dp_aware,
            api_key,
393
            log_dir,
394
            log_level,
395
396
397
398
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
399
400
401
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
402
403
            prometheus_port,
            prometheus_host,
404
            request_timeout_secs,
405
            request_id_headers,
406
            pd_disaggregation,
407
408
            prefill_urls,
            decode_urls,
409
410
            prefill_policy,
            decode_policy,
411
412
            max_concurrent_requests,
            cors_allowed_origins,
413
414
415
416
417
418
419
420
421
422
423
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
424
425
426
427
428
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
429
            enable_igw,
430
431
432
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
433
434
435
            connection_mode,
            model_path,
            tokenizer_path,
436
        })
437
438
439
    }

    fn start(&self) -> PyResult<()> {
440
441
442
443
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
444

445
446
447
448
449
450
451
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
452

453
454
455
456
457
458
459
460
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
461
462
463
464
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
465
466
467
468
469
            })
        } else {
            None
        };

470
471
472
473
474
475
476
477
478
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

479
480
481
482
483
484
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
485
486
487
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
488
                router_config,
489
                max_payload_size: self.max_payload_size,
490
                log_dir: self.log_dir.clone(),
491
                log_level: self.log_level.clone(),
492
                service_discovery_config,
493
                prometheus_config,
494
                request_timeout_secs: self.request_timeout_secs,
495
                request_id_headers: self.request_id_headers.clone(),
496
497
            })
            .await
498
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
499
        })
500
501
502
503
    }
}

#[pymodule]
504
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
505
    m.add_class::<PolicyType>()?;
506
507
    m.add_class::<Router>()?;
    Ok(())
508
}