lib.rs 2.27 KB
Newer Older
1
2
use pyo3::prelude::*;
pub mod router;
3
pub mod server;
4
pub mod tree;
5

6
7
8
9
10
#[pyclass(eq)]
#[derive(Clone, PartialEq)]
pub enum PolicyType {
    Random,
    RoundRobin,
11
    CacheAware,
12
13
}

14
15
16
17
18
#[pyclass]
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
19
    policy: PolicyType,
20
21
22
23
    cache_threshold: f32,
    cache_routing_prob: f32,
    eviction_interval_secs: u64,
    max_tree_size: usize,
24
25
26
27
28
}

#[pymethods]
impl Router {
    #[new]
29
30
31
32
33
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
34
35
36
37
        cache_threshold = 0.50,
        cache_routing_prob = 1.0,
        eviction_interval_secs = 60,
        max_tree_size = 2usize.pow(24)
38
39
40
41
42
43
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
44
45
46
47
        cache_threshold: f32,
        cache_routing_prob: f32,
        eviction_interval_secs: u64,
        max_tree_size: usize,
48
49
    ) -> PyResult<Self> {
        Ok(Router {
50
51
52
            host,
            port,
            worker_urls,
53
            policy,
54
            cache_threshold,
55
56
57
            cache_routing_prob,
            eviction_interval_secs,
            max_tree_size,
58
        })
59
60
61
62
63
64
    }

    fn start(&self) -> PyResult<()> {
        let host = self.host.clone();
        let port = self.port;
        let worker_urls = self.worker_urls.clone();
65
66
67
68

        let policy_config = match &self.policy {
            PolicyType::Random => router::PolicyConfig::RandomConfig,
            PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
69
70
71
72
73
            PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
                cache_threshold: self.cache_threshold,
                cache_routing_prob: self.cache_routing_prob,
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
74
75
            },
        };
76
77

        actix_web::rt::System::new().block_on(async move {
78
            server::startup(host, port, worker_urls, policy_config)
79
80
                .await
                .unwrap();
81
82
83
84
85
86
87
        });

        Ok(())
    }
}

#[pymodule]
88
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
89
    m.add_class::<PolicyType>()?;
90
91
    m.add_class::<Router>()?;
    Ok(())
92
}