"profiler/include/profile_layernorm_impl.hpp" did not exist on "e08d68d25d4406864c7f4eb8c389b4247da79713"
lib.rs 5.03 KB
Newer Older
1
use pyo3::prelude::*;
2
pub mod logging;
3
use std::collections::HashMap;
4
pub mod router;
5
pub mod server;
6
pub mod service_discovery;
7
pub mod tree;
8

9
#[pyclass(eq)]
10
#[derive(Clone, PartialEq, Debug)]
11
12
13
pub enum PolicyType {
    Random,
    RoundRobin,
14
    CacheAware,
15
16
}

17
#[pyclass]
18
#[derive(Debug, Clone, PartialEq)]
19
20
21
22
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
23
    policy: PolicyType,
24
    worker_startup_timeout_secs: u64,
25
    worker_startup_check_interval: u64,
26
    cache_threshold: f32,
27
28
    balance_abs_threshold: usize,
    balance_rel_threshold: f32,
29
30
    eviction_interval_secs: u64,
    max_tree_size: usize,
31
    max_payload_size: usize,
32
    verbose: bool,
33
    log_dir: Option<String>,
34
35
36
37
    service_discovery: bool,
    selector: HashMap<String, String>,
    service_discovery_port: u16,
    service_discovery_namespace: Option<String>,
38
39
40
41
42
}

#[pymethods]
impl Router {
    #[new]
43
44
45
46
47
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
48
        worker_startup_timeout_secs = 300,
49
        worker_startup_check_interval = 10,
50
        cache_threshold = 0.50,
51
52
        balance_abs_threshold = 32,
        balance_rel_threshold = 1.0001,
53
        eviction_interval_secs = 60,
54
        max_tree_size = 2usize.pow(24),
55
        max_payload_size = 4 * 1024 * 1024,
56
57
        verbose = false,
        log_dir = None,
58
59
60
61
        service_discovery = false,
        selector = HashMap::new(),
        service_discovery_port = 80,
        service_discovery_namespace = None
62
63
64
65
66
67
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
68
        worker_startup_timeout_secs: u64,
69
        worker_startup_check_interval: u64,
70
        cache_threshold: f32,
71
72
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
73
74
        eviction_interval_secs: u64,
        max_tree_size: usize,
75
        max_payload_size: usize,
76
        verbose: bool,
77
        log_dir: Option<String>,
78
79
80
81
        service_discovery: bool,
        selector: HashMap<String, String>,
        service_discovery_port: u16,
        service_discovery_namespace: Option<String>,
82
83
    ) -> PyResult<Self> {
        Ok(Router {
84
85
86
            host,
            port,
            worker_urls,
87
            policy,
88
            worker_startup_timeout_secs,
89
            worker_startup_check_interval,
90
            cache_threshold,
91
92
            balance_abs_threshold,
            balance_rel_threshold,
93
94
            eviction_interval_secs,
            max_tree_size,
95
            max_payload_size,
96
            verbose,
97
            log_dir,
98
99
100
101
            service_discovery,
            selector,
            service_discovery_port,
            service_discovery_namespace,
102
        })
103
104
105
    }

    fn start(&self) -> PyResult<()> {
106
        let policy_config = match &self.policy {
107
108
            PolicyType::Random => router::PolicyConfig::RandomConfig {
                timeout_secs: self.worker_startup_timeout_secs,
109
                interval_secs: self.worker_startup_check_interval,
110
111
112
            },
            PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
                timeout_secs: self.worker_startup_timeout_secs,
113
                interval_secs: self.worker_startup_check_interval,
114
            },
115
            PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
116
                timeout_secs: self.worker_startup_timeout_secs,
117
                interval_secs: self.worker_startup_check_interval,
118
                cache_threshold: self.cache_threshold,
119
120
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
121
122
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
123
124
            },
        };
125

126
127
128
129
130
131
132
133
134
135
136
137
138
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(service_discovery::ServiceDiscoveryConfig {
                enabled: true,
                selector: self.selector.clone(),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
            })
        } else {
            None
        };

139
        actix_web::rt::System::new().block_on(async move {
140
141
142
143
144
145
            server::startup(server::ServerConfig {
                host: self.host.clone(),
                port: self.port,
                worker_urls: self.worker_urls.clone(),
                policy_config,
                verbose: self.verbose,
146
                max_payload_size: self.max_payload_size,
147
                log_dir: self.log_dir.clone(),
148
                service_discovery_config,
149
150
            })
            .await
151
152
153
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Ok(())
        })
154
155
156
157
    }
}

#[pymodule]
158
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
159
    m.add_class::<PolicyType>()?;
160
161
    m.add_class::<Router>()?;
    Ok(())
162
}