router.rs 8.32 KB
Newer Older
1
use crate::tree::RadixTree;
2
3
4
5
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::TryStreamExt;
6
use std::collections::HashMap;
7
use std::fmt::Debug;
8
9
10
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
use tokenizers::tokenizer::Tokenizer;
11
12

#[derive(Debug)]
13
14
15
pub enum Router {
    RoundRobin {
        worker_urls: Vec<String>,
16
        current_index: AtomicUsize,
17
18
19
20
    },
    Random {
        worker_urls: Vec<String>,
    },
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    ApproxTree {
        worker_urls: Vec<String>,
        // TODO: don't lock the whole tree
        url_to_tree: Arc<Mutex<HashMap<String, RadixTree>>>,
        tokenizer: Tokenizer,
        url_to_count: Arc<Mutex<HashMap<String, usize>>>,
        cache_threshold: f32,
    },
}

/// Selects which routing policy `Router::new` builds.
#[derive(Debug, Clone)]
pub enum PolicyConfig {
    /// Build a `Router::Random`.
    RandomConfig,
    /// Build a `Router::RoundRobin`.
    RoundRobinConfig,
    /// Build a `Router::ApproxTree`.
    ApproxTreeConfig {
        /// Path of the tokenizer file to load.
        tokenizer_path: String,
        /// Prefix-match rate a worker must exceed to be chosen for its cache
        /// affinity rather than for having the fewest in-flight requests.
        cache_threshold: f32,
    },
}

fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec<u32> {
    // 1. convert body to json
    let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
    // 2. get the text field
    let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
    // 3. tokenize the text field
    let tokens = tokenizer.encode(text, false).unwrap();

    tokens.get_ids().to_vec()
49
50
}

51
impl Router {
52
53
54
55
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
        match policy_config {
            PolicyConfig::RandomConfig => Router::Random { worker_urls },
            PolicyConfig::RoundRobinConfig => Router::RoundRobin {
56
57
58
                worker_urls,
                current_index: std::sync::atomic::AtomicUsize::new(0),
            },
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
            PolicyConfig::ApproxTreeConfig {
                tokenizer_path,
                cache_threshold,
            } => {
                let mut url_to_tree = HashMap::new();
                let mut url_to_count = HashMap::new();

                for url in &worker_urls {
                    url_to_tree.insert(url.clone(), RadixTree::new());
                    url_to_count.insert(url.clone(), 0);
                }

                Router::ApproxTree {
                    worker_urls,
                    url_to_tree: Arc::new(Mutex::new(url_to_tree)),
                    // TODO: rust ::from_pretrained cannot load from local file, so use ::from_file to load local file
                    tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
                    url_to_count: Arc::new(Mutex::new(url_to_count)),
                    cache_threshold,
                }
            }
80
81
82
        }
    }

83
84
    pub fn get_first(&self) -> Option<String> {
        match self {
85
86
87
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls }
            | Router::ApproxTree { worker_urls, .. } => {
88
89
90
91
92
93
                if worker_urls.is_empty() {
                    None
                } else {
                    Some(worker_urls[0].clone())
                }
            }
94
95
96
        }
    }

97
98
99
100
101
102
    pub async fn dispatch(
        &self,
        client: &reqwest::Client,
        req: HttpRequest,
        body: Bytes,
    ) -> HttpResponse {
103
104
105
106
107
        let mut input_ids: Vec<u32> = Vec::new();
        if let Router::ApproxTree { tokenizer, .. } = self {
            input_ids = get_token_ids_from_request(&body, tokenizer);
        }

108
109
110
111
112
        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
            } => {
113
                let idx = current_index
114
115
116
117
118
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.len()),
                    )
119
                    .unwrap();
120

121
                worker_urls[idx].clone()
122
            }
123

124
            Router::Random { worker_urls } => {
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
                worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
            }

            Router::ApproxTree {
                worker_urls,
                url_to_tree,
                url_to_count,
                cache_threshold,
                ..
            } => {
                // TODO: pipeline the locks. Release one earlier.

                let mut max_matched_rate = 0.0;
                let mut max_matched_idx = 0;

                let locked_url_to_tree = url_to_tree.lock().unwrap();

                // 1. Find the highest matched worker
                for (i, url) in worker_urls.iter().enumerate() {
                    let tree = locked_url_to_tree.get(url).unwrap();
                    let matched = tree.prefix_match(&input_ids[..]).len();
                    let matched_rate = matched as f32 / input_ids.len() as f32;

                    if matched_rate > max_matched_rate {
                        max_matched_rate = matched_rate;
                        max_matched_idx = i;
                    }
                }

                // 2. If the rate is higher than the threshold, select the worker. If not, select the worker with the shortest queue
                if max_matched_rate > *cache_threshold {
                    worker_urls[max_matched_idx].clone()
                } else {
                    // pick the shortest queue from url_to_count
                    let locked_url_to_count = url_to_count.lock().unwrap();

                    let mut min_count = std::usize::MAX;
                    let mut min_count_id = 0;

                    for (i, url) in worker_urls.iter().enumerate() {
                        let count = locked_url_to_count.get(url).unwrap();
                        if *count < min_count {
                            min_count = *count;
                            min_count_id = i;
                        }
                    }

                    worker_urls[min_count_id].clone()
                }
174
175
            }
        };
176

177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
        if let Router::ApproxTree {
            url_to_tree,
            url_to_count,
            ..
        } = self
        {
            // Insert input_ids to the tree
            let mut locked_url_to_tree = url_to_tree.lock().unwrap();
            let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap();
            selected_tree.insert(&input_ids[..]);

            let mut locked_url_to_count = url_to_count.lock().unwrap();
            let count = locked_url_to_count.get_mut(&worker_url).unwrap();
            *count += 1;
        }

193
194
195
196
        // Check if client requested streaming
        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            .unwrap_or(false);
197

198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
        let res = match client
            .post(format!("{}/generate", worker_url))
            .header(
                "Content-Type",
                req.headers()
                    .get("Content-Type")
                    .and_then(|h| h.to_str().ok())
                    .unwrap_or("application/json"),
            )
            .body(body.to_vec())
            .send()
            .await
        {
            Ok(res) => res,
            Err(_) => return HttpResponse::InternalServerError().finish(),
        };
214

215
216
        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
217

218
        if !is_stream {
219
220
221
222
223
224
225
            // TODO: do the correction on the tree based on the cached input_ids
            if let Router::ApproxTree { url_to_count, .. } = self {
                let mut locked_url_to_count = url_to_count.lock().unwrap();
                let count = locked_url_to_count.get_mut(&worker_url).unwrap();
                *count -= 1;
            }

226
227
228
229
230
            match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(_) => HttpResponse::InternalServerError().finish(),
            }
        } else {
231
            // TODO: do the correction on the tree based on the cached input_ids. The streaming might be tricker to handle
232
233
234
235
236
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
                    actix_web::error::ErrorInternalServerError("Failed to read string")
                }))
237
238
        }
    }
239
}