lib.rs 2.5 KB
Newer Older
1
// Python Binding
2
3
use pyo3::prelude::*;
pub mod router;
4
5
mod server;
pub mod tree;
6

7
8
9
10
11
12
13
14
/// Worker-selection policy exposed to Python as `PolicyType`.
///
/// `Router::start` maps each variant onto the matching `router::PolicyConfig`
/// variant. `ApproxTree` additionally needs a tokenizer path and a cache threshold.
#[pyclass(eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyType {
    /// Maps to `router::PolicyConfig::RandomConfig`.
    Random,
    /// Maps to `router::PolicyConfig::RoundRobinConfig`; the default policy of `Router`.
    RoundRobin,
    /// Maps to `router::PolicyConfig::ApproxTreeConfig`; requires `tokenizer_path`
    /// and `cache_threshold` to be set on the `Router`.
    ApproxTree,
}

15
16
17
18
19
#[pyclass]
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
20
21
22
    policy: PolicyType,
    tokenizer_path: Option<String>,
    cache_threshold: Option<f32>,
23
24
25
26
27
}

#[pymethods]
impl Router {
    #[new]
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
        tokenizer_path = None,
        cache_threshold = Some(0.50)
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
        tokenizer_path: Option<String>,
        cache_threshold: Option<f32>,
    ) -> PyResult<Self> {
        // Validate required parameters for approx_tree policy
        if matches!(policy, PolicyType::ApproxTree) {
            if tokenizer_path.is_none() {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "tokenizer_path is required for approx_tree policy",
                ));
            }
        }

        Ok(Router {
54
55
56
            host,
            port,
            worker_urls,
57
            policy,
58
59
60
            tokenizer_path,
            cache_threshold,
        })
61
62
63
64
65
66
    }

    fn start(&self) -> PyResult<()> {
        let host = self.host.clone();
        let port = self.port;
        let worker_urls = self.worker_urls.clone();
67
68
69
70
71
72
73
74
75
76
77
78
79
80

        let policy_config = match &self.policy {
            PolicyType::Random => router::PolicyConfig::RandomConfig,
            PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
            PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig {
                tokenizer_path: self
                    .tokenizer_path
                    .clone()
                    .expect("tokenizer_path is required for approx_tree policy"),
                cache_threshold: self
                    .cache_threshold
                    .expect("cache_threshold is required for approx_tree policy"),
            },
        };
81
82

        actix_web::rt::System::new().block_on(async move {
83
            server::startup(host, port, worker_urls, policy_config)
84
85
                .await
                .unwrap();
86
87
88
89
90
91
92
93
        });

        Ok(())
    }
}

/// Python extension module `sglang_router`, exporting the `PolicyType` and `Router` classes.
#[pymodule]
fn sglang_router(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PolicyType>()?;
    m.add_class::<Router>()?;
    Ok(())
}