//! `lib.rs` — crate root exposing the router to Python through PyO3 (`PolicyType`, `Router`).
use pyo3::prelude::*;
pub mod router;
3
pub mod server;
4
pub mod tree;
5

6
7
8
9
10
#[pyclass(eq)]
#[derive(Clone, PartialEq)]
pub enum PolicyType {
    Random,
    RoundRobin,
11
    CacheAware,
12
13
}

14
15
16
17
18
#[pyclass]
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
19
    policy: PolicyType,
20
    worker_startup_timeout_secs: u64,
21
    worker_startup_check_interval: u64,
22
    cache_threshold: f32,
23
24
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
25
26
    eviction_interval_secs: u64,
    max_tree_size: usize,
27
    max_payload_size: usize,
28
    verbose: bool,
29
30
31
32
33
}

#[pymethods]
impl Router {
    #[new]
34
35
36
37
38
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
39
        worker_startup_timeout_secs = 300,
40
        worker_startup_check_interval = 10,
41
        cache_threshold = 0.50,
42
43
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
44
        eviction_interval_secs = 60,
45
        max_tree_size = 2usize.pow(24),
46
        max_payload_size = 4 * 1024 * 1024,
47
        verbose = false
48
49
50
51
52
53
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
54
        worker_startup_timeout_secs: u64,
55
        worker_startup_check_interval: u64,
56
        cache_threshold: f32,
57
58
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
59
60
        eviction_interval_secs: u64,
        max_tree_size: usize,
61
        max_payload_size: usize,
62
        verbose: bool,
63
64
    ) -> PyResult<Self> {
        Ok(Router {
65
66
67
            host,
            port,
            worker_urls,
68
            policy,
69
            worker_startup_timeout_secs,
70
            worker_startup_check_interval,
71
            cache_threshold,
72
73
            balance_abs_threshold,
            balance_rel_threshold,
74
75
            eviction_interval_secs,
            max_tree_size,
76
            max_payload_size,
77
            verbose,
78
        })
79
80
81
    }

    fn start(&self) -> PyResult<()> {
82
        let policy_config = match &self.policy {
83
84
            PolicyType::Random => router::PolicyConfig::RandomConfig {
                timeout_secs: self.worker_startup_timeout_secs,
85
                interval_secs: self.worker_startup_check_interval,
86
87
88
            },
            PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
                timeout_secs: self.worker_startup_timeout_secs,
89
                interval_secs: self.worker_startup_check_interval,
90
            },
91
            PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
92
                timeout_secs: self.worker_startup_timeout_secs,
93
                interval_secs: self.worker_startup_check_interval,
94
                cache_threshold: self.cache_threshold,
95
96
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
97
98
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
99
100
            },
        };
101
102

        actix_web::rt::System::new().block_on(async move {
103
104
105
106
107
108
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
                worker_urls: self.worker_urls.clone(),
                policy_config,
                verbose: self.verbose,
109
                max_payload_size: self.max_payload_size,
110
111
            })
            .await
112
113
114
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Ok(())
        })
115
116
117
118
    }
}

#[pymodule]
119
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
120
    m.add_class::<PolicyType>()?;
121
122
    m.add_class::<Router>()?;
    Ok(())
123
}