lib.rs 17.2 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
pub mod data_connector;
8
#[cfg(feature = "grpc-client")]
9
pub mod grpc_client;
10
pub mod mcp;
11
pub mod metrics;
12
pub mod middleware;
13
pub mod policies;
14
pub mod protocols;
15
pub mod reasoning_parser;
16
pub mod routers;
17
pub mod server;
18
pub mod service_discovery;
19
pub mod tokenizer;
20
pub mod tool_parser;
21
pub mod tree;
22
use crate::metrics::PrometheusConfig;
23

24
#[pyclass(eq)]
25
#[derive(Clone, PartialEq, Debug)]
26
27
28
pub enum PolicyType {
    Random,
    RoundRobin,
29
    CacheAware,
30
    PowerOfTwo,
31
32
}

33
#[pyclass]
34
#[derive(Debug, Clone, PartialEq)]
35
36
37
38
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
39
    policy: PolicyType,
40
    worker_startup_timeout_secs: u64,
41
    worker_startup_check_interval: u64,
42
    cache_threshold: f32,
43
44
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
45
46
    eviction_interval_secs: u64,
    max_tree_size: usize,
47
    max_payload_size: usize,
48
49
    dp_aware: bool,
    api_key: Option<String>,
50
    log_dir: Option<String>,
51
    log_level: Option<String>,
52
53
54
55
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
56
57
58
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
59
60
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
61
    request_timeout_secs: u64,
62
    request_id_headers: Option<Vec<String>>,
63
    pd_disaggregation: bool,
64
65
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
66
67
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
68
    max_concurrent_requests: i32,
69
    cors_allowed_origins: Vec<String>,
70
71
72
73
74
75
76
77
78
79
80
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
81
82
83
84
85
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
86
    enable_igw: bool,
87
88
    queue_size: usize,
    queue_timeout_secs: u64,
89
    rate_limit_tokens_per_second: Option<i32>,
90
91
92
    connection_mode: config::ConnectionMode,
    model_path: Option<String>,
    tokenizer_path: Option<String>,
93
    chat_template: Option<String>,
94
95
    reasoning_parser: Option<String>,
    tool_call_parser: Option<String>,
96
97
}

98
impl Router {
99
100
101
102
103
104
105
106
107
108
    /// Determine connection mode from worker URLs
    fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                return config::ConnectionMode::Grpc;
            }
        }
        config::ConnectionMode::Http
    }

109
110
111
112
113
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

114
115
116
117
118
119
120
121
122
123
124
125
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
126
                    load_check_interval_secs: 5,
127
128
129
130
                },
            }
        };

131
132
133
134
135
        let mode = if self.enable_igw {
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
136
137
138
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
139
140
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
141
142
143
144
145
146
147
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

148
        let policy = convert_policy(&self.policy);
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177

        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
178
            connection_mode: self.connection_mode.clone(),
179
180
181
182
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
183
184
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
185
186
187
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
188
            log_level: self.log_level.clone(),
189
            request_id_headers: self.request_id_headers.clone(),
190
            max_concurrent_requests: self.max_concurrent_requests,
191
192
193
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
194
            cors_allowed_origins: self.cors_allowed_origins.clone(),
195
196
197
198
199
200
201
202
203
204
205
206
207
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
208
209
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
210
211
212
213
214
215
216
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
217
            enable_igw: self.enable_igw,
218
219
            model_path: self.model_path.clone(),
            tokenizer_path: self.tokenizer_path.clone(),
220
            chat_template: self.chat_template.clone(),
221
            history_backend: config::HistoryBackend::Memory,
222
            oracle: None,
223
224
            reasoning_parser: self.reasoning_parser.clone(),
            tool_call_parser: self.tool_call_parser.clone(),
225
226
227
228
        })
    }
}

229
230
231
#[pymethods]
impl Router {
    #[new]
232
233
234
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
235
        host = String::from("0.0.0.0"),
236
        port = 3001,
237
238
239
240
241
242
243
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
244
        max_payload_size = 512 * 1024 * 1024,
245
246
        dp_aware = false,
        api_key = None,
247
        log_dir = None,
248
        log_level = None,
249
250
251
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
252
        service_discovery_namespace = None,
253
254
255
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
256
        prometheus_port = None,
257
        prometheus_host = None,
258
259
260
        request_timeout_secs = 1800,
        request_id_headers = None,
        pd_disaggregation = false,
261
        prefill_urls = None,
262
263
        decode_urls = None,
        prefill_policy = None,
264
        decode_policy = None,
265
        max_concurrent_requests = -1,
266
        cors_allowed_origins = vec![],
267
268
269
270
271
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
272
        disable_retries = false,
273
274
275
276
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
277
        disable_circuit_breaker = false,
278
279
280
281
282
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
283
        enable_igw = false,
284
285
286
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
287
288
        model_path = None,
        tokenizer_path = None,
289
        chat_template = None,
290
291
        reasoning_parser = None,
        tool_call_parser = None,
292
    ))]
293
    #[allow(clippy::too_many_arguments)]
294
295
296
297
298
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
299
        worker_startup_timeout_secs: u64,
300
        worker_startup_check_interval: u64,
301
        cache_threshold: f32,
302
303
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
304
305
        eviction_interval_secs: u64,
        max_tree_size: usize,
306
        max_payload_size: usize,
307
308
        dp_aware: bool,
        api_key: Option<String>,
309
        log_dir: Option<String>,
310
        log_level: Option<String>,
311
312
313
314
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
315
316
317
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
318
319
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
320
        request_timeout_secs: u64,
321
        request_id_headers: Option<Vec<String>>,
322
        pd_disaggregation: bool,
323
324
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
325
326
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
327
        max_concurrent_requests: i32,
328
        cors_allowed_origins: Vec<String>,
329
330
331
332
333
334
335
336
337
338
339
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
340
341
342
343
344
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
345
        enable_igw: bool,
346
347
        queue_size: usize,
        queue_timeout_secs: u64,
348
        rate_limit_tokens_per_second: Option<i32>,
349
350
        model_path: Option<String>,
        tokenizer_path: Option<String>,
351
        chat_template: Option<String>,
352
353
        reasoning_parser: Option<String>,
        tool_call_parser: Option<String>,
354
    ) -> PyResult<Self> {
355
356
357
358
359
360
361
362
363
364
365
366
367
368
        let mut all_urls = worker_urls.clone();

        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

369
        Ok(Router {
370
371
372
            host,
            port,
            worker_urls,
373
            policy,
374
            worker_startup_timeout_secs,
375
            worker_startup_check_interval,
376
            cache_threshold,
377
378
            balance_abs_threshold,
            balance_rel_threshold,
379
380
            eviction_interval_secs,
            max_tree_size,
381
            max_payload_size,
382
383
            dp_aware,
            api_key,
384
            log_dir,
385
            log_level,
386
387
388
389
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
390
391
392
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
393
394
            prometheus_port,
            prometheus_host,
395
            request_timeout_secs,
396
            request_id_headers,
397
            pd_disaggregation,
398
399
            prefill_urls,
            decode_urls,
400
401
            prefill_policy,
            decode_policy,
402
403
            max_concurrent_requests,
            cors_allowed_origins,
404
405
406
407
408
409
410
411
412
413
414
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
415
416
417
418
419
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
420
            enable_igw,
421
422
423
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
424
425
426
            connection_mode,
            model_path,
            tokenizer_path,
427
            chat_template,
428
429
            reasoning_parser,
            tool_call_parser,
430
        })
431
432
433
    }

    fn start(&self) -> PyResult<()> {
434
435
436
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
437

438
439
440
441
442
443
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
444

445
446
447
448
449
450
451
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
452
453
454
455
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
456
457
458
459
460
            })
        } else {
            None
        };

461
462
463
464
465
466
467
468
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

469
470
471
472
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        runtime.block_on(async move {
473
474
475
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
476
                router_config,
477
                max_payload_size: self.max_payload_size,
478
                log_dir: self.log_dir.clone(),
479
                log_level: self.log_level.clone(),
480
                service_discovery_config,
481
                prometheus_config,
482
                request_timeout_secs: self.request_timeout_secs,
483
                request_id_headers: self.request_id_headers.clone(),
484
485
            })
            .await
486
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
487
        })
488
489
490
491
    }
}

#[pymodule]
492
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
493
    m.add_class::<PolicyType>()?;
494
495
    m.add_class::<Router>()?;
    Ok(())
496
}