lib.rs 16.6 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod config;
3
pub mod logging;
4
use std::collections::HashMap;
5

6
pub mod core;
7
pub mod data_connector;
8
#[cfg(feature = "grpc-client")]
9
pub mod grpc_client;
10
pub mod mcp;
11
pub mod metrics;
12
pub mod middleware;
13
pub mod policies;
14
pub mod protocols;
15
pub mod reasoning_parser;
16
pub mod routers;
17
pub mod server;
18
pub mod service_discovery;
19
pub mod tokenizer;
20
pub mod tool_parser;
21
pub mod tree;
22
use crate::metrics::PrometheusConfig;
23

24
#[pyclass(eq)]
25
#[derive(Clone, PartialEq, Debug)]
26
27
28
pub enum PolicyType {
    Random,
    RoundRobin,
29
    CacheAware,
30
    PowerOfTwo,
31
32
}

33
#[pyclass]
34
#[derive(Debug, Clone, PartialEq)]
35
36
37
38
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
39
    policy: PolicyType,
40
    worker_startup_timeout_secs: u64,
41
    worker_startup_check_interval: u64,
42
    cache_threshold: f32,
43
44
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
45
46
    eviction_interval_secs: u64,
    max_tree_size: usize,
47
    max_payload_size: usize,
48
49
    dp_aware: bool,
    api_key: Option<String>,
50
    log_dir: Option<String>,
51
    log_level: Option<String>,
52
53
54
55
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
56
57
58
    prefill_selector: HashMap<String, String>,
    decode_selector: HashMap<String, String>,
    bootstrap_port_annotation: String,
59
60
    prometheus_port: Option<u16>,
    prometheus_host: Option<String>,
61
    request_timeout_secs: u64,
62
    request_id_headers: Option<Vec<String>>,
63
    pd_disaggregation: bool,
64
65
    prefill_urls: Option<Vec<(String, Option<u16>)>>,
    decode_urls: Option<Vec<String>>,
66
67
    prefill_policy: Option<PolicyType>,
    decode_policy: Option<PolicyType>,
68
69
    max_concurrent_requests: usize,
    cors_allowed_origins: Vec<String>,
70
71
72
73
74
75
76
77
78
79
80
    retry_max_retries: u32,
    retry_initial_backoff_ms: u64,
    retry_max_backoff_ms: u64,
    retry_backoff_multiplier: f32,
    retry_jitter_factor: f32,
    disable_retries: bool,
    cb_failure_threshold: u32,
    cb_success_threshold: u32,
    cb_timeout_duration_secs: u64,
    cb_window_duration_secs: u64,
    disable_circuit_breaker: bool,
81
82
83
84
85
    health_failure_threshold: u32,
    health_success_threshold: u32,
    health_check_timeout_secs: u64,
    health_check_interval_secs: u64,
    health_check_endpoint: String,
86
    enable_igw: bool,
87
88
89
    queue_size: usize,
    queue_timeout_secs: u64,
    rate_limit_tokens_per_second: Option<usize>,
90
91
92
    connection_mode: config::ConnectionMode,
    model_path: Option<String>,
    tokenizer_path: Option<String>,
93
94
}

95
impl Router {
96
97
98
99
100
101
102
103
104
105
    /// Determine connection mode from worker URLs
    fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
        for url in worker_urls {
            if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                return config::ConnectionMode::Grpc;
            }
        }
        config::ConnectionMode::Http
    }

106
107
108
109
110
    pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
        use config::{
            DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
        };

111
112
113
114
115
116
117
118
119
120
121
122
        let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
            match policy {
                PolicyType::Random => ConfigPolicyConfig::Random,
                PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin,
                PolicyType::CacheAware => ConfigPolicyConfig::CacheAware {
                    cache_threshold: self.cache_threshold,
                    balance_abs_threshold: self.balance_abs_threshold,
                    balance_rel_threshold: self.balance_rel_threshold,
                    eviction_interval_secs: self.eviction_interval_secs,
                    max_tree_size: self.max_tree_size,
                },
                PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
123
                    load_check_interval_secs: 5,
124
125
126
127
                },
            }
        };

128
129
130
131
132
        let mode = if self.enable_igw {
            RoutingMode::Regular {
                worker_urls: vec![],
            }
        } else if self.pd_disaggregation {
133
134
135
            RoutingMode::PrefillDecode {
                prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
                decode_urls: self.decode_urls.clone().unwrap_or_default(),
136
137
                prefill_policy: self.prefill_policy.as_ref().map(convert_policy),
                decode_policy: self.decode_policy.as_ref().map(convert_policy),
138
139
140
141
142
143
144
            }
        } else {
            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

145
        let policy = convert_policy(&self.policy);
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174

        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: self.selector.clone(),
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
            })
        } else {
            None
        };

        let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
            (Some(port), Some(host)) => Some(MetricsConfig {
                port,
                host: host.clone(),
            }),
            _ => None,
        };

        Ok(config::RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
175
            connection_mode: self.connection_mode.clone(),
176
177
178
179
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
180
181
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
182
183
184
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
185
            log_level: self.log_level.clone(),
186
            request_id_headers: self.request_id_headers.clone(),
187
            max_concurrent_requests: self.max_concurrent_requests,
188
189
190
            queue_size: self.queue_size,
            queue_timeout_secs: self.queue_timeout_secs,
            rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
191
            cors_allowed_origins: self.cors_allowed_origins.clone(),
192
193
194
195
196
197
198
199
200
201
202
203
204
            retry: config::RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: config::CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
205
206
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
207
208
209
210
211
212
213
            health_check: config::HealthCheckConfig {
                failure_threshold: self.health_failure_threshold,
                success_threshold: self.health_success_threshold,
                timeout_secs: self.health_check_timeout_secs,
                check_interval_secs: self.health_check_interval_secs,
                endpoint: self.health_check_endpoint.clone(),
            },
214
            enable_igw: self.enable_igw,
215
216
            model_path: self.model_path.clone(),
            tokenizer_path: self.tokenizer_path.clone(),
217
            history_backend: config::HistoryBackend::Memory,
218
            oracle: None,
219
220
221
222
        })
    }
}

223
224
225
#[pymethods]
impl Router {
    #[new]
226
227
228
229
230
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
231
232
233
234
235
236
237
        worker_startup_timeout_secs = 600,
        worker_startup_check_interval = 30,
        cache_threshold = 0.3,
        balance_abs_threshold = 64,
        balance_rel_threshold = 1.5,
        eviction_interval_secs = 120,
        max_tree_size = 2usize.pow(26),
238
        max_payload_size = 512 * 1024 * 1024,
239
240
        dp_aware = false,
        api_key = None,
241
        log_dir = None,
242
        log_level = None,
243
244
245
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
246
        service_discovery_namespace = None,
247
248
249
        prefill_selector = HashMap::new(),
        decode_selector = HashMap::new(),
        bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
250
        prometheus_port = None,
251
        prometheus_host = None,
252
253
254
        request_timeout_secs = 1800,
        request_id_headers = None,
        pd_disaggregation = false,
255
        prefill_urls = None,
256
257
        decode_urls = None,
        prefill_policy = None,
258
        decode_policy = None,
259
        max_concurrent_requests = 256,
260
        cors_allowed_origins = vec![],
261
262
263
264
265
        retry_max_retries = 5,
        retry_initial_backoff_ms = 50,
        retry_max_backoff_ms = 30_000,
        retry_backoff_multiplier = 1.5,
        retry_jitter_factor = 0.2,
266
        disable_retries = false,
267
268
269
270
        cb_failure_threshold = 10,
        cb_success_threshold = 3,
        cb_timeout_duration_secs = 60,
        cb_window_duration_secs = 120,
271
        disable_circuit_breaker = false,
272
273
274
275
276
        health_failure_threshold = 3,
        health_success_threshold = 2,
        health_check_timeout_secs = 5,
        health_check_interval_secs = 60,
        health_check_endpoint = String::from("/health"),
277
        enable_igw = false,
278
279
280
        queue_size = 100,
        queue_timeout_secs = 60,
        rate_limit_tokens_per_second = None,
281
282
        model_path = None,
        tokenizer_path = None,
283
    ))]
284
    #[allow(clippy::too_many_arguments)]
285
286
287
288
289
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
290
        worker_startup_timeout_secs: u64,
291
        worker_startup_check_interval: u64,
292
        cache_threshold: f32,
293
294
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
295
296
        eviction_interval_secs: u64,
        max_tree_size: usize,
297
        max_payload_size: usize,
298
299
        dp_aware: bool,
        api_key: Option<String>,
300
        log_dir: Option<String>,
301
        log_level: Option<String>,
302
303
304
305
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
306
307
308
        prefill_selector: HashMap<String, String>,
        decode_selector: HashMap<String, String>,
        bootstrap_port_annotation: String,
309
310
        prometheus_port: Option<u16>,
        prometheus_host: Option<String>,
311
        request_timeout_secs: u64,
312
        request_id_headers: Option<Vec<String>>,
313
        pd_disaggregation: bool,
314
315
        prefill_urls: Option<Vec<(String, Option<u16>)>>,
        decode_urls: Option<Vec<String>>,
316
317
        prefill_policy: Option<PolicyType>,
        decode_policy: Option<PolicyType>,
318
319
        max_concurrent_requests: usize,
        cors_allowed_origins: Vec<String>,
320
321
322
323
324
325
326
327
328
329
330
        retry_max_retries: u32,
        retry_initial_backoff_ms: u64,
        retry_max_backoff_ms: u64,
        retry_backoff_multiplier: f32,
        retry_jitter_factor: f32,
        disable_retries: bool,
        cb_failure_threshold: u32,
        cb_success_threshold: u32,
        cb_timeout_duration_secs: u64,
        cb_window_duration_secs: u64,
        disable_circuit_breaker: bool,
331
332
333
334
335
        health_failure_threshold: u32,
        health_success_threshold: u32,
        health_check_timeout_secs: u64,
        health_check_interval_secs: u64,
        health_check_endpoint: String,
336
        enable_igw: bool,
337
338
339
        queue_size: usize,
        queue_timeout_secs: u64,
        rate_limit_tokens_per_second: Option<usize>,
340
341
        model_path: Option<String>,
        tokenizer_path: Option<String>,
342
    ) -> PyResult<Self> {
343
344
345
346
347
348
349
350
351
352
353
354
355
356
        let mut all_urls = worker_urls.clone();

        if let Some(ref prefill_urls) = prefill_urls {
            for (url, _) in prefill_urls {
                all_urls.push(url.clone());
            }
        }

        if let Some(ref decode_urls) = decode_urls {
            all_urls.extend(decode_urls.clone());
        }

        let connection_mode = Self::determine_connection_mode(&all_urls);

357
        Ok(Router {
358
359
360
            host,
            port,
            worker_urls,
361
            policy,
362
            worker_startup_timeout_secs,
363
            worker_startup_check_interval,
364
            cache_threshold,
365
366
            balance_abs_threshold,
            balance_rel_threshold,
367
368
            eviction_interval_secs,
            max_tree_size,
369
            max_payload_size,
370
371
            dp_aware,
            api_key,
372
            log_dir,
373
            log_level,
374
375
376
377
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
378
379
380
            prefill_selector,
            decode_selector,
            bootstrap_port_annotation,
381
382
            prometheus_port,
            prometheus_host,
383
            request_timeout_secs,
384
            request_id_headers,
385
            pd_disaggregation,
386
387
            prefill_urls,
            decode_urls,
388
389
            prefill_policy,
            decode_policy,
390
391
            max_concurrent_requests,
            cors_allowed_origins,
392
393
394
395
396
397
398
399
400
401
402
            retry_max_retries,
            retry_initial_backoff_ms,
            retry_max_backoff_ms,
            retry_backoff_multiplier,
            retry_jitter_factor,
            disable_retries,
            cb_failure_threshold,
            cb_success_threshold,
            cb_timeout_duration_secs,
            cb_window_duration_secs,
            disable_circuit_breaker,
403
404
405
406
407
            health_failure_threshold,
            health_success_threshold,
            health_check_timeout_secs,
            health_check_interval_secs,
            health_check_endpoint,
408
            enable_igw,
409
410
411
            queue_size,
            queue_timeout_secs,
            rate_limit_tokens_per_second,
412
413
414
            connection_mode,
            model_path,
            tokenizer_path,
415
        })
416
417
418
    }

    fn start(&self) -> PyResult<()> {
419
420
421
        let router_config = self.to_router_config().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
        })?;
422

423
424
425
426
427
428
        router_config.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Configuration validation failed: {}",
                e
            ))
        })?;
429

430
431
432
433
434
435
436
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
437
438
439
440
                pd_mode: self.pd_disaggregation,
                prefill_selector: self.prefill_selector.clone(),
                decode_selector: self.decode_selector.clone(),
                bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
441
442
443
444
445
            })
        } else {
            None
        };

446
447
448
449
450
451
452
453
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port.unwrap_or(29000),
            host: self
                .prometheus_host
                .clone()
                .unwrap_or_else(|| "127.0.0.1".to_string()),
        });

454
455
456
457
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

        runtime.block_on(async move {
458
459
460
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
461
                router_config,
462
                max_payload_size: self.max_payload_size,
463
                log_dir: self.log_dir.clone(),
464
                log_level: self.log_level.clone(),
465
                service_discovery_config,
466
                prometheus_config,
467
                request_timeout_secs: self.request_timeout_secs,
468
                request_id_headers: self.request_id_headers.clone(),
469
470
            })
            .await
471
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
472
        })
473
474
475
476
    }
}

#[pymodule]
477
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
478
    m.add_class::<PolicyType>()?;
479
480
    m.add_class::<Router>()?;
    Ok(())
481
}