lib.rs 18.1 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
pub mod data_connector;
8
#[cfg(feature = "grpc-client")]
9
pub mod grpc_client;
10
pub mod mcp;
11
pub mod metrics;
12
pub mod middleware;
13
pub mod policies;
14
pub mod protocols;
15
pub mod reasoning_parser;
16
pub mod routers;
17
pub mod server;
18
pub mod service_discovery;
19
pub mod tokenizer;
20
pub mod tool_parser;
21
pub mod tree;
22
use crate::metrics::PrometheusConfig;
23

24
#[pyclass(eq)]
25
#[derive(Clone, PartialEq, Debug)]
26
27
28
pub enum PolicyType {
    Random,
    RoundRobin,
29
    CacheAware,
30
    PowerOfTwo,
31
32
}

33
#[pyclass]
34
#[derive(Debug, Clone, PartialEq)]
35
36
37
38
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
39
    policy: PolicyType,
40
    worker_startup_timeout_secs: u64,
41
    worker_startup_check_interval: u64,
42
    cache_threshold: f32,
43
44
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
45
46
    eviction_interval_secs: u64,
    max_tree_size: usize,
47
    max_payload_size: usize,
48
49
    dp_aware: bool,
    api_key: Option<String>,
50
    log_dir: Option<String>,
51
    log_level: Option<String>,
52
53
54
55
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
56
57
58
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
59
60
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
61
    request_timeout_secs: u64,
62
    request_id_headers: Option<Vec<String>>,
63
    pd_disaggregation: bool,
64
65
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
66
67
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
68
69
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
70
71
72
73
74
75
76
77
78
79
80
81
82
    // Retry configuration
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    // Circuit breaker configuration
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
83
84
85
86
87
88
    // Health check configuration
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
89
90
    // IGW (Inference Gateway) configuration
    enable_igw: bool,
91
92
93
    queue_size: usize,
    queue_timeout_secs: u64,
    rate_limit_tokens_per_second: Option<usize>,
94
95
96
97
98
99
    // Connection mode (determined from worker URLs)
    connection_mode: config::ConnectionMode,
    // Model path for tokenizer
    model_path: Option<String>,
    // Explicit tokenizer path
    tokenizer_path: Option<String>,
100
101
}

102
impl Router {
103
104
    /// Determine connection mode from worker URLs
    fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
105
        // Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
106
107
108
109
110
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                return config::ConnectionMode::Grpc;
            }
        }
111
        // Default to HTTP for all other cases (including http://, https://, or no scheme)
112
113
114
        config::ConnectionMode::Http
    }

115
116
117
118
119
120
    /// Convert PyO3 Router to RouterConfig
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
        // Convert policy helper function
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
                    load_check_interval_secs: 5, // Default value
                },
            }
        };

139
        // Determine routing mode
140
141
142
143
144
145
        let mode = if self.enable_igw {
            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
146
147
148
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
149
150
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
151
152
153
154
155
156
157
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

158
159
        // Convert main policy
        let policy = convert_policy(&self.policy);
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
191
            connection_mode: self.connection_mode.clone(),
192
193
194
195
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
196
197
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
198
199
200
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
201
            log_level: self.log_level.clone(),
202
            request_id_headers: self.request_id_headers.clone(),
203
            max_concurrent_requests: self.max_concurrent_requests,
204
205
206
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
207
            cors_allowed_origins: self.cors_allowed_origins.clone(),
208
209
210
211
212
213
214
215
216
217
218
219
220
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
221
222
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
223
224
225
226
227
228
229
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
230
            enable_igw: self.enable_igw,
231
232
            model_path: self.model_path.clone(),
            tokenizer_path: self.tokenizer_path.clone(),
233
            history_backend: config::HistoryBackend::Memory,
234
            oracle: None,
235
236
237
238
        })
    }
}

239
240
241
#[pymethods]
impl Router {
    #[new]
242
243
244
245
246
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
247
248
249
250
251
252
253
254
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
        max_payload_size = 512 * 1024 * 1024,  // 512MB default for large batches
255
256
        dp_aware = false,
        api_key = None,
257
        log_dir = None,
258
        log_level = None,
259
260
261
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
262
        service_discovery_namespace = None,
263
264
265
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
266
        prometheus_port = None,
267
        prometheus_host = None,
268
        request_timeout_secs = 1800,  // Add configurable request timeout
269
        request_id_headers = None,  // Custom request ID headers
270
        pd_disaggregation = false,  // New flag for PD mode
271
        prefill_urls = None,
272
273
        decode_urls = None,
        prefill_policy = None,
274
        decode_policy = None,
275
        max_concurrent_requests = 256,
276
277
        cors_allowed_origins = vec![],
        // Retry defaults
278
279
280
281
282
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
283
284
        disable_retries = false,
        // Circuit breaker defaults
285
286
287
288
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
289
        disable_circuit_breaker = false,
290
291
292
293
294
295
        // Health check defaults
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
296
297
        // IGW defaults
        enable_igw = false,
298
299
300
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
301
302
303
        // Tokenizer defaults
        model_path = None,
        tokenizer_path = None,
304
    ))]
305
    #[allow(clippy::too_many_arguments)]
306
307
308
309
310
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
311
        worker_startup_timeout_secs: u64,
312
        worker_startup_check_interval: u64,
313
        cache_threshold: f32,
314
315
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
316
317
        eviction_interval_secs: u64,
        max_tree_size: usize,
318
        max_payload_size: usize,
319
320
        dp_aware: bool,
        api_key: Option<String>,
321
        log_dir: Option<String>,
322
        log_level: Option<String>,
323
324
325
326
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
327
328
329
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
330
331
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
332
        request_timeout_secs: u64,
333
        request_id_headers: Option<Vec<String>>,
334
        pd_disaggregation: bool,
335
336
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
337
338
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
339
340
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
341
342
343
344
345
346
347
348
349
350
351
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
352
353
354
355
356
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
357
        enable_igw: bool,
358
359
360
        queue_size: usize,
        queue_timeout_secs: u64,
        rate_limit_tokens_per_second: Option<usize>,
361
362
        model_path: Option<String>,
        tokenizer_path: Option<String>,
363
    ) -> PyResult<Self> {
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
        // Determine connection mode from worker URLs
        let mut all_urls = worker_urls.clone();

        // Add prefill URLs if in PD mode
        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        // Add decode URLs if in PD mode
        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

381
        Ok(Router {
382
383
384
            host,
            port,
            worker_urls,
385
            policy,
386
            worker_startup_timeout_secs,
387
            worker_startup_check_interval,
388
            cache_threshold,
389
390
            balance_abs_threshold,
            balance_rel_threshold,
391
392
            eviction_interval_secs,
            max_tree_size,
393
            max_payload_size,
394
395
            dp_aware,
            api_key,
396
            log_dir,
397
            log_level,
398
399
400
401
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
402
403
404
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
405
406
            prometheus_port,
            prometheus_host,
407
            request_timeout_secs,
408
            request_id_headers,
409
            pd_disaggregation,
410
411
            prefill_urls,
            decode_urls,
412
413
            prefill_policy,
            decode_policy,
414
415
            max_concurrent_requests,
            cors_allowed_origins,
416
417
418
419
420
421
422
423
424
425
426
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
427
428
429
430
431
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
432
            enable_igw,
433
434
435
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
436
437
438
            connection_mode,
            model_path,
            tokenizer_path,
439
        })
440
441
442
    }

    fn start(&self) -> PyResult<()> {
443
444
445
446
        // Convert to RouterConfig and validate
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
447

448
449
450
451
452
453
454
        // Validate the configuration
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
455

456
457
458
459
460
461
462
463
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
464
465
466
467
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
468
469
470
471
472
            })
        } else {
            None
        };

473
474
475
476
477
478
479
480
481
        // Create Prometheus config if enabled
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

482
483
484
485
486
487
        // Use tokio runtime instead of actix-web System for better compatibility
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        // Block on the async startup function
        runtime.block_on(async move {
488
489
490
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
491
                router_config,
492
                max_payload_size: self.max_payload_size,
493
                log_dir: self.log_dir.clone(),
494
                log_level: self.log_level.clone(),
495
                service_discovery_config,
496
                prometheus_config,
497
                request_timeout_secs: self.request_timeout_secs,
498
                request_id_headers: self.request_id_headers.clone(),
499
500
            })
            .await
501
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
502
        })
503
504
505
506
    }
}

#[pymodule]
507
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
508
    m.add_class::<PolicyType>()?;
509
510
    m.add_class::<Router>()?;
    Ok(())
511
}