lib.rs 17 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
pub mod data_connector;
8
#[cfg(feature = "grpc-client")]
9
pub mod grpc_client;
10
pub mod mcp;
11
pub mod metrics;
12
pub mod middleware;
13
pub mod policies;
14
pub mod protocols;
15
pub mod reasoning_parser;
16
pub mod routers;
17
pub mod server;
18
pub mod service_discovery;
19
pub mod tokenizer;
20
pub mod tool_parser;
21
pub mod tree;
22
use crate::metrics::PrometheusConfig;
23

24
#[pyclass(eq)]
25
#[derive(Clone, PartialEq, Debug)]
26
27
28
pub enum PolicyType {
    Random,
    RoundRobin,
29
    CacheAware,
30
    PowerOfTwo,
31
32
}

33
#[pyclass]
34
#[derive(Debug, Clone, PartialEq)]
35
36
37
38
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
39
    policy: PolicyType,
40
    worker_startup_timeout_secs: u64,
41
    worker_startup_check_interval: u64,
42
    cache_threshold: f32,
43
44
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
45
46
    eviction_interval_secs: u64,
    max_tree_size: usize,
47
    max_payload_size: usize,
48
49
    dp_aware: bool,
    api_key: Option<String>,
50
    log_dir: Option<String>,
51
    log_level: Option<String>,
52
53
54
55
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
56
57
58
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
59
60
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
61
    request_timeout_secs: u64,
62
    request_id_headers: Option<Vec<String>>,
63
    pd_disaggregation: bool,
64
65
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
66
67
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
68
69
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
70
71
72
73
74
75
76
77
78
79
80
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
81
82
83
84
85
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
86
    enable_igw: bool,
87
88
89
    queue_size: usize,
    queue_timeout_secs: u64,
    rate_limit_tokens_per_second: Option<usize>,
90
91
92
    connection_mode: config::ConnectionMode,
    model_path: Option<String>,
    tokenizer_path: Option<String>,
93
94
    reasoning_parser: Option<String>,
    tool_call_parser: Option<String>,
95
96
}

97
impl Router {
98
99
100
101
102
103
104
105
106
107
    /// Determine connection mode from worker URLs
    fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                return config::ConnectionMode::Grpc;
            }
        }
        config::ConnectionMode::Http
    }

108
109
110
111
112
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

113
114
115
116
117
118
119
120
121
122
123
124
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
125
                    load_check_interval_secs: 5,
126
127
128
129
                },
            }
        };

130
131
132
133
134
        let mode = if self.enable_igw {
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
135
136
137
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
138
139
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
140
141
142
143
144
145
146
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

147
        let policy = convert_policy(&self.policy);
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176

        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
177
            connection_mode: self.connection_mode.clone(),
178
179
180
181
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
182
183
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
184
185
186
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
187
            log_level: self.log_level.clone(),
188
            request_id_headers: self.request_id_headers.clone(),
189
            max_concurrent_requests: self.max_concurrent_requests,
190
191
192
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
193
            cors_allowed_origins: self.cors_allowed_origins.clone(),
194
195
196
197
198
199
200
201
202
203
204
205
206
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
207
208
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
209
210
211
212
213
214
215
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
216
            enable_igw: self.enable_igw,
217
218
            model_path: self.model_path.clone(),
            tokenizer_path: self.tokenizer_path.clone(),
219
            history_backend: config::HistoryBackend::Memory,
220
            oracle: None,
221
222
            reasoning_parser: self.reasoning_parser.clone(),
            tool_call_parser: self.tool_call_parser.clone(),
223
224
225
226
        })
    }
}

227
228
229
#[pymethods]
impl Router {
    #[new]
230
231
232
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
233
        host = String::from("0.0.0.0"),
234
        port = 3001,
235
236
237
238
239
240
241
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
242
        max_payload_size = 512 * 1024 * 1024,
243
244
        dp_aware = false,
        api_key = None,
245
        log_dir = None,
246
        log_level = None,
247
248
249
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
250
        service_discovery_namespace = None,
251
252
253
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
254
        prometheus_port = None,
255
        prometheus_host = None,
256
257
258
        request_timeout_secs = 1800,
        request_id_headers = None,
        pd_disaggregation = false,
259
        prefill_urls = None,
260
261
        decode_urls = None,
        prefill_policy = None,
262
        decode_policy = None,
263
        max_concurrent_requests = 256,
264
        cors_allowed_origins = vec![],
265
266
267
268
269
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
270
        disable_retries = false,
271
272
273
274
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
275
        disable_circuit_breaker = false,
276
277
278
279
280
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
281
        enable_igw = false,
282
283
284
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
285
286
        model_path = None,
        tokenizer_path = None,
287
288
        reasoning_parser = None,
        tool_call_parser = None,
289
    ))]
290
    #[allow(clippy::too_many_arguments)]
291
292
293
294
295
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
296
        worker_startup_timeout_secs: u64,
297
        worker_startup_check_interval: u64,
298
        cache_threshold: f32,
299
300
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
301
302
        eviction_interval_secs: u64,
        max_tree_size: usize,
303
        max_payload_size: usize,
304
305
        dp_aware: bool,
        api_key: Option<String>,
306
        log_dir: Option<String>,
307
        log_level: Option<String>,
308
309
310
311
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
312
313
314
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
315
316
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
317
        request_timeout_secs: u64,
318
        request_id_headers: Option<Vec<String>>,
319
        pd_disaggregation: bool,
320
321
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
322
323
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
324
325
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
326
327
328
329
330
331
332
333
334
335
336
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
337
338
339
340
341
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
342
        enable_igw: bool,
343
344
345
        queue_size: usize,
        queue_timeout_secs: u64,
        rate_limit_tokens_per_second: Option<usize>,
346
347
        model_path: Option<String>,
        tokenizer_path: Option<String>,
348
349
        reasoning_parser: Option<String>,
        tool_call_parser: Option<String>,
350
    ) -> PyResult<Self> {
351
352
353
354
355
356
357
358
359
360
361
362
363
364
        let mut all_urls = worker_urls.clone();

        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

365
        Ok(Router {
366
367
368
            host,
            port,
            worker_urls,
369
            policy,
370
            worker_startup_timeout_secs,
371
            worker_startup_check_interval,
372
            cache_threshold,
373
374
            balance_abs_threshold,
            balance_rel_threshold,
375
376
            eviction_interval_secs,
            max_tree_size,
377
            max_payload_size,
378
379
            dp_aware,
            api_key,
380
            log_dir,
381
            log_level,
382
383
384
385
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
386
387
388
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
389
390
            prometheus_port,
            prometheus_host,
391
            request_timeout_secs,
392
            request_id_headers,
393
            pd_disaggregation,
394
395
            prefill_urls,
            decode_urls,
396
397
            prefill_policy,
            decode_policy,
398
399
            max_concurrent_requests,
            cors_allowed_origins,
400
401
402
403
404
405
406
407
408
409
410
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
411
412
413
414
415
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
416
            enable_igw,
417
418
419
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
420
421
422
            connection_mode,
            model_path,
            tokenizer_path,
423
424
            reasoning_parser,
            tool_call_parser,
425
        })
426
427
428
    }

    fn start(&self) -> PyResult<()> {
429
430
431
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
432

433
434
435
436
437
438
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
439

440
441
442
443
444
445
446
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
447
448
449
450
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
451
452
453
454
455
            })
        } else {
            None
        };

456
457
458
459
460
461
462
463
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

464
465
466
467
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        runtime.block_on(async move {
468
469
470
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
471
                router_config,
472
                max_payload_size: self.max_payload_size,
473
                log_dir: self.log_dir.clone(),
474
                log_level: self.log_level.clone(),
475
                service_discovery_config,
476
                prometheus_config,
477
                request_timeout_secs: self.request_timeout_secs,
478
                request_id_headers: self.request_id_headers.clone(),
479
480
            })
            .await
481
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
482
        })
483
484
485
486
    }
}

#[pymodule]
487
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
488
    m.add_class::<PolicyType>()?;
489
490
    m.add_class::<Router>()?;
    Ok(())
491
}