router.rs 13.1 KB
Newer Older
1
use crate::tree::Tree;
2
3
4
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
5
use futures_util::{Stream, StreamExt, TryStreamExt};
6
use std::collections::HashMap;
7
use std::fmt::Debug;
8
9
use std::hash::Hash;
use std::pin::Pin;
10
11
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
12
13
use std::thread;
use std::time::Duration;
14
15

#[derive(Debug)]
pub enum Router {
    /// Rotates through workers in order; the index is advanced atomically so
    /// concurrent dispatches never observe the same slot.
    RoundRobin {
        worker_urls: Vec<String>,
        current_index: AtomicUsize,
    },
    /// Picks a worker uniformly at random for every request.
    Random {
        worker_urls: Vec<String>,
    },
    CacheAware {
        /*
        Cache-Aware Load Balancing Router

        This router combines two strategies to optimize both cache utilization and request distribution:

        1. Cache-Aware Routing (Approximate Tree)
        2. Load Balancing (Shortest Queue)

        For each incoming request, the router chooses between these strategies:
        - With probability P: Uses cache-aware routing
        - With probability (1-P): Uses load balancing
        where P is configured via `cache_routing_prob`

        Strategy Details:

        1. Cache-Aware Routing (Approximate Tree)
        -------------------------------------------
        This strategy maintains an approximate radix tree for each worker based on request history,
        eliminating the need for direct cache state queries. The tree stores raw text characters
        instead of token IDs to avoid tokenization overhead.

        Process:
        a. For each request, find the worker with the highest prefix match
        b. If match rate > cache_threshold:
        Route to the worker with highest match (likely has relevant data cached)
        c. If match rate ≤ cache_threshold:
        Route to the worker with smallest tree size (most available cache capacity)
        d. Background maintenance:
        Periodically evict least recently used leaf nodes to prevent memory overflow

        2. Load Balancing (Shortest Queue)
        -------------------------------------------
        This strategy tracks pending request counts per worker and routes new requests
        to the least busy worker for optimal load distribution.

        Configuration Parameters:
        ------------------------
        1. cache_routing_prob: (float, 0.0 to 1.0)
        - 0.0: Exclusively use load balancing
        - 1.0: Exclusively use cache-aware routing
        - Between 0-1: Probability of using cache-aware routing vs load balancing

        2. cache_threshold: (float, 0.0 to 1.0)
        Minimum prefix match ratio to use highest-match routing.
        Below this threshold, routes to worker with most available cache space.

        3. eviction_interval_secs: (integer)
        Interval between LRU eviction cycles for the approximate trees.

        4. max_tree_size: (integer)
        Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
        during the next eviction cycle.
        */
        worker_urls: Vec<String>,
        // Approximate radix tree of request text routed so far, shared with the
        // background eviction thread.
        tree: Arc<Mutex<Tree>>,
        // In-flight request count per worker URL; decremented when a response
        // completes.
        running_queue: Arc<Mutex<HashMap<String, usize>>>,
        // Total requests ever routed to each worker URL (monotonic counters).
        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
        cache_threshold: f32,
        cache_routing_prob: f32,
        _eviction_thread: Option<thread::JoinHandle<()>>, // Store thread handle
    },
}

88
#[derive(Debug)]
pub enum PolicyConfig {
    /// Route each request to a uniformly random worker.
    RandomConfig,
    /// Rotate through workers in a fixed order.
    RoundRobinConfig,
    /// Probabilistic mix of cache-affinity routing and shortest-queue load
    /// balancing; parameter semantics match the comment on `Router::CacheAware`.
    CacheAwareConfig {
        // Minimum prefix-match ratio (0.0–1.0) required to route by cache affinity.
        cache_threshold: f32,
        // Probability (0.0–1.0) of using cache-aware routing instead of
        // shortest-queue balancing for a given request.
        cache_routing_prob: f32,
        // Seconds between background LRU eviction passes over the trees.
        eviction_interval_secs: u64,
        // Maximum node count per tree; exceeding it triggers eviction on the
        // next pass.
        max_tree_size: usize,
    },
}

100
fn get_text_from_request(body: &Bytes) -> String {
101
102
103
104
    // 1. convert body to json
    let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
    // 2. get the text field
    let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
105
    return text.to_string();
106
107
}

108
impl Router {
109
110
111
112
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
        match policy_config {
            PolicyConfig::RandomConfig => Router::Random { worker_urls },
            PolicyConfig::RoundRobinConfig => Router::RoundRobin {
113
114
115
                worker_urls,
                current_index: std::sync::atomic::AtomicUsize::new(0),
            },
116
            PolicyConfig::CacheAwareConfig {
117
                cache_threshold,
118
119
120
                cache_routing_prob,
                eviction_interval_secs,
                max_tree_size,
121
            } => {
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
                let mut running_queue = HashMap::new();
                for url in &worker_urls {
                    running_queue.insert(url.clone(), 0);
                }

                let mut processed_queue = HashMap::new();
                for url in &worker_urls {
                    processed_queue.insert(url.clone(), 0);
                }

                let tree = Arc::new(Mutex::new(Tree::new()));
                let running_queue = Arc::new(Mutex::new(running_queue));
                let processed_queue = Arc::new(Mutex::new(processed_queue));

                // Create background eviction thread
                let tree_clone = Arc::clone(&tree);
                let processed_queue_clone = Arc::clone(&processed_queue);
                let eviction_thread = thread::spawn(move || {
                    loop {
                        // Sleep for the specified interval
                        thread::sleep(Duration::from_secs(eviction_interval_secs));

                        let locked_tree_clone = tree_clone.lock().unwrap();
                        // Run eviction
                        locked_tree_clone.evict_tenant_data(max_tree_size);

                        // Print the process queue
                        let locked_processed_queue = processed_queue_clone.lock().unwrap();
                        println!("Processed Queue: {:?}", locked_processed_queue);
                    }
                });
153
154

                for url in &worker_urls {
155
                    tree.lock().unwrap().insert(&"".to_string(), url);
156
157
                }

158
                Router::CacheAware {
159
                    worker_urls,
160
161
162
                    tree,
                    running_queue,
                    processed_queue,
163
                    cache_threshold,
164
165
                    cache_routing_prob,
                    _eviction_thread: Some(eviction_thread),
166
167
                }
            }
168
169
170
        }
    }

171
172
    pub fn get_first(&self) -> Option<String> {
        match self {
173
174
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls }
175
            | Router::CacheAware { worker_urls, .. } => {
176
177
178
179
180
181
                if worker_urls.is_empty() {
                    None
                } else {
                    Some(worker_urls[0].clone())
                }
            }
182
183
184
        }
    }

185
186
187
188
189
190
    pub async fn dispatch(
        &self,
        client: &reqwest::Client,
        req: HttpRequest,
        body: Bytes,
    ) -> HttpResponse {
191
        let text = get_text_from_request(&body);
192

193
194
195
196
197
        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
            } => {
198
                let idx = current_index
199
200
201
202
203
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.len()),
                    )
204
                    .unwrap();
205

206
                worker_urls[idx].clone()
207
            }
208

209
            Router::Random { worker_urls } => {
210
211
212
                worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
            }

213
            Router::CacheAware {
214
                worker_urls,
215
216
217
                tree,
                running_queue,
                processed_queue,
218
                cache_threshold,
219
                cache_routing_prob,
220
221
                ..
            } => {
222
                // even though the tree is thread-safe, we still put a lock to ensure the whole op (tree read + queue read + tree write + queue write) is atomic to handle some edge cases (e.g. multiple requests with long prefix entering at the same time)
223

224
225
                let mut tree = tree.lock().unwrap();
                let mut running_queue = running_queue.lock().unwrap();
226

227
228
                // Generate a random float between 0 and 1 for probability check
                let sampled_p: f32 = rand::random();
229

230
231
232
233
234
                let selected_url = if sampled_p < *cache_routing_prob {
                    // Cache-aware routing logic
                    let (matched_text, matched_worker) = tree.prefix_match(&text);
                    let matched_rate =
                        matched_text.chars().count() as f32 / text.chars().count() as f32;
235

236
237
238
239
240
241
242
243
                    if matched_rate > *cache_threshold {
                        matched_worker.to_string()
                    } else {
                        let m_map: HashMap<String, usize> = tree
                            .tenant_char_count
                            .iter()
                            .map(|entry| (entry.key().clone(), *entry.value()))
                            .collect();
244

245
246
247
                        println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);

                        tree.get_smallest_tenant()
248
                    }
249
250
251
252
253
254
255
256
                } else {
                    // Shortest queue routing logic
                    running_queue
                        .iter()
                        .min_by_key(|(_url, &count)| count)
                        .map(|(url, _)| url.clone())
                        .unwrap_or_else(|| worker_urls[0].clone())
                };
257

258
259
260
261
262
263
264
265
266
267
268
269
270
                // Update running queue
                let count = running_queue.get_mut(&selected_url).unwrap();
                *count += 1;

                // Update processed queue
                let mut locked_processed_queue = processed_queue.lock().unwrap();
                let count = locked_processed_queue.get_mut(&selected_url).unwrap();
                *count += 1;

                // Update tree with the new request
                tree.insert(&text, &selected_url);

                selected_url
271
272
            }
        };
273

274
275
276
        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            .unwrap_or(false);
277

278
        let res = match client
279
            .post(format!("{}/generate", worker_url.clone()))
280
281
282
283
284
285
286
287
288
289
290
291
292
293
            .header(
                "Content-Type",
                req.headers()
                    .get("Content-Type")
                    .and_then(|h| h.to_str().ok())
                    .unwrap_or("application/json"),
            )
            .body(body.to_vec())
            .send()
            .await
        {
            Ok(res) => res,
            Err(_) => return HttpResponse::InternalServerError().finish(),
        };
294

295
296
        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
297

298
        if !is_stream {
299
300
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
301
302
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(_) => HttpResponse::InternalServerError().finish(),
303
304
305
306
307
308
309
310
311
            };

            // Then decrement running queue counter if using CacheAware
            if let Router::CacheAware { running_queue, .. } = self {
                if let Ok(mut queue) = running_queue.lock() {
                    if let Some(count) = queue.get_mut(&worker_url) {
                        *count = count.saturating_sub(1);
                    }
                }
312
            }
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340

            response
        } else if let Router::CacheAware { running_queue, .. } = self {
            let running_queue = Arc::clone(running_queue);
            let worker_url = worker_url.clone();

            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(
                    res.bytes_stream()
                        .map_err(|_| {
                            actix_web::error::ErrorInternalServerError("Failed to read stream")
                        })
                        .inspect(move |bytes| {
                            let bytes = bytes.as_ref().unwrap();
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                let mut locked_queue = running_queue.lock().unwrap();
                                let count = locked_queue.get_mut(&worker_url).unwrap();
                                *count = count.saturating_sub(1);
                                // print
                                // println!("streaming is done!!")
                            }
                        }),
                )
341
342
343
344
        } else {
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
345
                    actix_web::error::ErrorInternalServerError("Failed to read stream")
346
                }))
347
348
        }
    }
349
}