lib.rs 3.85 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod logging;
3
pub mod router;
4
pub mod server;
5
pub mod tree;
6

7
#[pyclass(eq)]
8
#[derive(Clone, PartialEq, Debug)]
9
10
11
pub enum PolicyType {
    Random,
    RoundRobin,
12
    CacheAware,
13
14
}

15
#[pyclass]
16
#[derive(Debug, Clone, PartialEq)]
17
18
19
20
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
21
    policy: PolicyType,
22
    worker_startup_timeout_secs: u64,
23
    worker_startup_check_interval: u64,
24
    cache_threshold: f32,
25
26
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
27
28
    eviction_interval_secs: u64,
    max_tree_size: usize,
29
    max_payload_size: usize,
30
    verbose: bool,
31
    log_dir: Option<String>,
32
33
34
35
36
}

#[pymethods]
impl Router {
    #[new]
37
38
39
40
41
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
42
        worker_startup_timeout_secs = 300,
43
        worker_startup_check_interval = 10,
44
        cache_threshold = 0.50,
45
46
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
47
        eviction_interval_secs = 60,
48
        max_tree_size = 2usize.pow(24),
49
        max_payload_size = 4 * 1024 * 1024,
50
51
        verbose = false,
        log_dir = None,
52
53
54
55
56
57
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
58
        worker_startup_timeout_secs: u64,
59
        worker_startup_check_interval: u64,
60
        cache_threshold: f32,
61
62
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
63
64
        eviction_interval_secs: u64,
        max_tree_size: usize,
65
        max_payload_size: usize,
66
        verbose: bool,
67
        log_dir: Option<String>,
68
69
    ) -> PyResult<Self> {
        Ok(Router {
70
71
72
            host,
            port,
            worker_urls,
73
            policy,
74
            worker_startup_timeout_secs,
75
            worker_startup_check_interval,
76
            cache_threshold,
77
78
            balance_abs_threshold,
            balance_rel_threshold,
79
80
            eviction_interval_secs,
            max_tree_size,
81
            max_payload_size,
82
            verbose,
83
            log_dir,
84
        })
85
86
87
    }

    fn start(&self) -> PyResult<()> {
88
        let policy_config = match &self.policy {
89
90
            PolicyType::Random => router::PolicyConfig::RandomConfig {
                timeout_secs: self.worker_startup_timeout_secs,
91
                interval_secs: self.worker_startup_check_interval,
92
93
94
            },
            PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
                timeout_secs: self.worker_startup_timeout_secs,
95
                interval_secs: self.worker_startup_check_interval,
96
            },
97
            PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
98
                timeout_secs: self.worker_startup_timeout_secs,
99
                interval_secs: self.worker_startup_check_interval,
100
                cache_threshold: self.cache_threshold,
101
102
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
103
104
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
105
106
            },
        };
107
108

        actix_web::rt::System::new().block_on(async move {
109
110
111
112
113
114
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
                worker_urls: self.worker_urls.clone(),
                policy_config,
                verbose: self.verbose,
115
                max_payload_size: self.max_payload_size,
116
                log_dir: self.log_dir.clone(),
117
118
            })
            .await
119
120
121
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Ok(())
        })
122
123
124
125
    }
}

#[pymodule]
126
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
127
    m.add_class::<PolicyType>()?;
128
129
    m.add_class::<Router>()?;
    Ok(())
130
}