// router.rs — request routing across SGLang worker URLs.
1
use crate::tree::Tree;
2
3
4
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
5
use futures_util::{Stream, StreamExt, TryStreamExt};
6
use std::collections::HashMap;
7
use std::fmt::Debug;
8
9
use std::hash::Hash;
use std::pin::Pin;
10
11
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
12
13
use std::thread;
use std::time::Duration;
14
15

/// Load-balancing policy for distributing requests across worker URLs.
///
/// Built by [`Router::new`] from a [`PolicyConfig`]; each incoming request
/// is forwarded by [`Router::dispatch`].
#[derive(Debug)]
pub enum Router {
    /// Cycles through workers in order.
    RoundRobin {
        worker_urls: Vec<String>,
        // Index of the last worker used; advanced atomically (fetch_update)
        // so concurrent dispatches each observe a distinct slot.
        current_index: AtomicUsize,
    },
    /// Picks a worker uniformly at random for every request.
    Random {
        worker_urls: Vec<String>,
    },
    /// Combines approximate-radix-tree cache affinity with shortest-queue
    /// load balancing; the full strategy is described in the comment below.
    CacheAware {
        /*
        Cache-Aware Load Balancing Router

        This router combines two strategies to optimize both cache utilization and request distribution:

        1. Cache-Aware Routing (Approximate Tree)
        2. Load Balancing (Shortest Queue)

        For each incoming request, the router chooses between these strategies:
        - With probability P: Uses cache-aware routing
        - With probability (1-P): Uses load balancing
        where P is configured via `cache_routing_prob`

        Strategy Details:

        1. Cache-Aware Routing (Approximate Tree)
        -------------------------------------------
        This strategy maintains an approximate radix tree for each worker based on request history,
        eliminating the need for direct cache state queries. The tree stores raw text characters
        instead of token IDs to avoid tokenization overhead.

        Process:
        a. For each request, find the worker with the highest prefix match
        b. If match rate > cache_threshold:
        Route to the worker with highest match (likely has relevant data cached)
        c. If match rate ≤ cache_threshold:
        Route to the worker with smallest tree size (most available cache capacity)
        d. Background maintenance:
        Periodically evict least recently used leaf nodes to prevent memory overflow

        2. Load Balancing (Shortest Queue)
        -------------------------------------------
        This strategy tracks pending request counts per worker and routes new requests
        to the least busy worker for optimal load distribution.

        Configuration Parameters:
        ------------------------
        1. cache_routing_prob: (float, 0.0 to 1.0)
        - 0.0: Exclusively use load balancing
        - 1.0: Exclusively use cache-aware routing
        - Between 0-1: Probability of using cache-aware routing vs load balancing

        2. cache_threshold: (float, 0.0 to 1.0)
        Minimum prefix match ratio to use highest-match routing.
        Below this threshold, routes to worker with most available cache space.

        3. eviction_interval_secs: (integer)
        Interval between LRU eviction cycles for the approximate trees.

        4. max_tree_size: (integer)
        Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
        during the next eviction cycle.
        */
        worker_urls: Vec<String>,
        // Shared approximate radix tree mapping request-text prefixes to
        // workers ("tenants"). The Mutex also serializes the whole
        // read-route-update sequence in dispatch (see comment there).
        tree: Arc<Mutex<Tree>>,
        // In-flight request count per worker URL; incremented at selection,
        // decremented when the response (or its stream) completes.
        running_queue: Arc<Mutex<HashMap<String, usize>>>,
        // Total requests ever routed to each worker; only incremented, and
        // periodically printed by the eviction thread.
        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
        // Minimum prefix-match ratio required to use the best-match worker.
        cache_threshold: f32,
        // Probability of choosing cache-aware routing over shortest-queue.
        cache_routing_prob: f32,
        _eviction_thread: Option<thread::JoinHandle<()>>, // Store thread handle
    },
}

88
/// Policy selection passed to [`Router::new`].
#[derive(Debug)]
pub enum PolicyConfig {
    /// Uniform-random worker selection.
    RandomConfig,
    /// Round-robin worker selection.
    RoundRobinConfig,
    /// Cache-aware routing; parameters mirror the strategy comment on
    /// `Router::CacheAware`.
    CacheAwareConfig {
        /// Minimum prefix-match ratio (0.0–1.0) needed to route to the
        /// highest-matching worker instead of the smallest tree.
        cache_threshold: f32,
        /// Probability (0.0–1.0) that a request uses cache-aware routing;
        /// otherwise shortest-queue load balancing is used.
        cache_routing_prob: f32,
        /// Seconds between LRU eviction passes in the background thread.
        eviction_interval_secs: u64,
        /// Maximum nodes per approximate tree before eviction trims it.
        max_tree_size: usize,
    },
}

100
101
/// Extracts the routable text from a request body for the given route.
///
/// - `generate`            → the `"text"` field
/// - `v1/chat/completions` → the `"messages"` array re-serialized as JSON
/// - `v1/completions`      → the `"prompt"` field
///
/// Returns an empty string for unknown routes, missing fields, or a body
/// that is not valid JSON — a malformed body must not panic the handler
/// (the original `.unwrap()` here would crash on bad input).
fn get_text_from_request(body: &Bytes, route: &str) -> String {
    // Parse the body as JSON; fall back to "" instead of panicking so an
    // invalid payload degrades to "no cache-routing text".
    let json: serde_json::Value = match serde_json::from_slice(body) {
        Ok(v) => v,
        Err(_) => return String::new(),
    };

    match route {
        "generate" => json
            .get("text")
            .and_then(|t| t.as_str())
            .unwrap_or("")
            .to_string(),
        // Keep the messages as raw JSON text, preserving all formatting,
        // so prefix matching sees exactly what was sent.
        "v1/chat/completions" => json
            .get("messages")
            .map(|messages| serde_json::to_string(messages).unwrap_or_default())
            .unwrap_or_default(),
        "v1/completions" => json
            .get("prompt")
            .and_then(|t| t.as_str())
            .unwrap_or("")
            .to_string(),
        _ => String::new(),
    }
}
121
impl Router {
122
123
124
125
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
        match policy_config {
            PolicyConfig::RandomConfig => Router::Random { worker_urls },
            PolicyConfig::RoundRobinConfig => Router::RoundRobin {
126
127
128
                worker_urls,
                current_index: std::sync::atomic::AtomicUsize::new(0),
            },
129
            PolicyConfig::CacheAwareConfig {
130
                cache_threshold,
131
132
133
                cache_routing_prob,
                eviction_interval_secs,
                max_tree_size,
134
            } => {
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
                let mut running_queue = HashMap::new();
                for url in &worker_urls {
                    running_queue.insert(url.clone(), 0);
                }

                let mut processed_queue = HashMap::new();
                for url in &worker_urls {
                    processed_queue.insert(url.clone(), 0);
                }

                let tree = Arc::new(Mutex::new(Tree::new()));
                let running_queue = Arc::new(Mutex::new(running_queue));
                let processed_queue = Arc::new(Mutex::new(processed_queue));

                // Create background eviction thread
                let tree_clone = Arc::clone(&tree);
                let processed_queue_clone = Arc::clone(&processed_queue);
                let eviction_thread = thread::spawn(move || {
                    loop {
                        // Sleep for the specified interval
                        thread::sleep(Duration::from_secs(eviction_interval_secs));

                        let locked_tree_clone = tree_clone.lock().unwrap();
                        // Run eviction
                        locked_tree_clone.evict_tenant_data(max_tree_size);

                        // Print the process queue
                        let locked_processed_queue = processed_queue_clone.lock().unwrap();
                        println!("Processed Queue: {:?}", locked_processed_queue);
                    }
                });
166
167

                for url in &worker_urls {
168
                    tree.lock().unwrap().insert(&"".to_string(), url);
169
170
                }

171
                Router::CacheAware {
172
                    worker_urls,
173
174
175
                    tree,
                    running_queue,
                    processed_queue,
176
                    cache_threshold,
177
178
                    cache_routing_prob,
                    _eviction_thread: Some(eviction_thread),
179
180
                }
            }
181
182
183
        }
    }

184
185
    pub fn get_first(&self) -> Option<String> {
        match self {
186
187
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls }
188
            | Router::CacheAware { worker_urls, .. } => {
189
190
191
192
193
194
                if worker_urls.is_empty() {
                    None
                } else {
                    Some(worker_urls[0].clone())
                }
            }
195
196
197
        }
    }

198
199
200
201
202
    /// Forwards `body` to a worker chosen by the active policy and relays
    /// the worker's response, handling both buffered and SSE-streaming
    /// replies.
    ///
    /// For `CacheAware`, the per-worker in-flight counter incremented at
    /// selection time is decremented when the buffered response finishes,
    /// or — for streams — when the `data: [DONE]` marker is observed.
    pub async fn dispatch(
        &self,
        client: &reqwest::Client,
        req: HttpRequest,
        body: Bytes,
        route: &str,
    ) -> HttpResponse {
        // Text used for prefix matching (only meaningful for CacheAware).
        let text = get_text_from_request(&body, route);
        // For Debug
        // println!("text: {:?}, route: {:?}", text, route);

        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
            } => {
                // Atomically advance the index modulo the worker count;
                // fetch_update returns the *previous* value, which is the
                // slot this request uses. The closure always returns Some,
                // so the unwrap cannot fail.
                let idx = current_index
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.len()),
                    )
                    .unwrap();

                worker_urls[idx].clone()
            }

            Router::Random { worker_urls } => {
                worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
            }

            Router::CacheAware {
                worker_urls,
                tree,
                running_queue,
                processed_queue,
                cache_threshold,
                cache_routing_prob,
                ..
            } => {
                // even though the tree is thread-safe, we still put a lock to ensure the whole op (tree read + queue read + tree write + queue write) is atomic to handle some edge cases (e.g. multiple requests with long prefix entering at the same time)
                // NOTE: both guards are dropped when this match arm ends,
                // before any .await below — no std::sync lock is held across
                // an await point.
                let mut tree = tree.lock().unwrap();
                let mut running_queue = running_queue.lock().unwrap();

                // Generate a random float between 0 and 1 for probability check
                let sampled_p: f32 = rand::random();

                let selected_url = if sampled_p < *cache_routing_prob {
                    // Cache-aware routing logic
                    let (matched_text, matched_worker) = tree.prefix_match(&text);
                    // NOTE(review): if `text` is empty this is 0/0 = NaN;
                    // NaN > threshold is false, so routing falls through to
                    // the smallest tenant — confirm this is the intent.
                    let matched_rate =
                        matched_text.chars().count() as f32 / text.chars().count() as f32;

                    if matched_rate > *cache_threshold {
                        // Strong prefix match: this worker likely has the
                        // relevant data cached.
                        matched_worker.to_string()
                    } else {
                        // For Debug
                        // let m_map: HashMap<String, usize> = tree
                        //     .tenant_char_count
                        //     .iter()
                        //     .map(|entry| (entry.key().clone(), *entry.value()))
                        //     .collect();

                        // println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);

                        // Weak match: pick the worker with the smallest tree
                        // (most available cache capacity).
                        tree.get_smallest_tenant()
                    }
                } else {
                    // Shortest queue routing logic
                    running_queue
                        .iter()
                        .min_by_key(|(_url, &count)| count)
                        .map(|(url, _)| url.clone())
                        .unwrap_or_else(|| worker_urls[0].clone())
                };

                // Update running queue
                let count = running_queue.get_mut(&selected_url).unwrap();
                *count += 1;

                // Update processed queue
                let mut locked_processed_queue = processed_queue.lock().unwrap();
                let count = locked_processed_queue.get_mut(&selected_url).unwrap();
                *count += 1;

                // Update tree with the new request
                tree.insert(&text, &selected_url);

                selected_url
            }
        };

        // A request is treated as streaming iff the body carries
        // `"stream": true`; unparseable bodies default to non-streaming.
        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            .unwrap_or(false);

        // Forward the request verbatim, preserving the caller's
        // Content-Type (defaulting to application/json when absent).
        let res = match client
            .post(format!("{}/{}", worker_url.clone(), route))
            .header(
                "Content-Type",
                req.headers()
                    .get("Content-Type")
                    .and_then(|h| h.to_str().ok())
                    .unwrap_or("application/json"),
            )
            .body(body.to_vec())
            .send()
            .await
        {
            Ok(res) => res,
            Err(_) => return HttpResponse::InternalServerError().finish(),
        };

        // Translate reqwest's status into actix's status type.
        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(_) => HttpResponse::InternalServerError().finish(),
            };

            // Then decrement running queue counter if using CacheAware
            if let Router::CacheAware { running_queue, .. } = self {
                if let Ok(mut queue) = running_queue.lock() {
                    if let Some(count) = queue.get_mut(&worker_url) {
                        *count = count.saturating_sub(1);
                    }
                }
            }

            response
        } else if let Router::CacheAware { running_queue, .. } = self {
            // Streaming + CacheAware: watch the SSE stream for the
            // `data: [DONE]` sentinel and decrement the in-flight counter
            // for this worker when it appears.
            let running_queue = Arc::clone(running_queue);
            let worker_url = worker_url.clone();

            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(
                    res.bytes_stream()
                        .map_err(|_| {
                            actix_web::error::ErrorInternalServerError("Failed to read stream")
                        })
                        .inspect(move |bytes| {
                            // NOTE(review): inspect sees &Result, so an Err
                            // chunk would panic this unwrap — consider
                            // matching on Ok instead.
                            let bytes = bytes.as_ref().unwrap();
                            // 12 == b"data: [DONE]".len(); a marker split
                            // across two chunks would be missed.
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                let mut locked_queue = running_queue.lock().unwrap();
                                let count = locked_queue.get_mut(&worker_url).unwrap();
                                *count = count.saturating_sub(1);
                                // print
                                // println!("streaming is done!!")
                            }
                        }),
                )
        } else {
            // Streaming for stateless policies: just relay the byte stream.
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
                    actix_web::error::ErrorInternalServerError("Failed to read stream")
                }))
        }
    }
366
}