use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt};
use log::{debug, error, info, warn};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use tokio;

#[derive(Debug)]
pub enum Router {
    RoundRobin {
        worker_urls: Arc<RwLock<Vec<String>>>,
        current_index: AtomicUsize,
        timeout_secs: u64,
        interval_secs: u64,
    },
    Random {
        worker_urls: Arc<RwLock<Vec<String>>>,
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAware {
        /*
            Cache-Aware Load Balancing Router

            This router combines two strategies to optimize both cache utilization and request distribution:

            1. Cache-Aware Routing (Approximate Tree)
            2. Load Balancing (Shortest Queue with Balance Thresholds)

            The router dynamically switches between these strategies based on load conditions:
            - Uses load balancing when the system is imbalanced
            - Uses cache-aware routing when the system is balanced

            A system is considered imbalanced if both conditions are met:
            1. (max - min) > abs_threshold
            2. max > rel_threshold * min
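
            Example (illustrative numbers): with abs_threshold = 8 and
            rel_threshold = 1.5, running queues {A: 12, B: 2} give
            max - min = 10 > 8 and 12 > 1.5 * 2 = 3, so the system is
            treated as imbalanced and shortest-queue routing takes over.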

            Strategy Details:

            1. Cache-Aware Routing (Approximate Tree)
            -------------------------------------------
            This strategy maintains an approximate radix tree for each worker based on request history,
            eliminating the need for direct cache state queries. The tree stores raw text characters
            instead of token IDs to avoid tokenization overhead.

            Process:
            a. For each request, find the worker with the highest prefix match
            b. If match rate > cache_threshold:
               Route to the worker with highest match (likely has relevant data cached)
            c. If match rate ≤ cache_threshold:
               Route to the worker with smallest tree size (most available cache capacity)
            d. Background maintenance:
               Periodically evict least recently used leaf nodes to prevent memory overflow

            2. Load Balancing (Shortest Queue)
            -------------------------------------------
            This strategy tracks pending request counts per worker and routes new requests
            to the least busy worker when the system is detected to be imbalanced.

            Configuration Parameters:
            ------------------------
            1. cache_threshold: (float, 0.0 to 1.0)
            Minimum prefix match ratio to use highest-match routing.
            Below this threshold, routes to worker with most available cache space.

            2. balance_abs_threshold: (integer)
            Absolute difference threshold for load imbalance detection.
            System is potentially imbalanced if (max_load - min_load) > abs_threshold

            3. balance_rel_threshold: (float)
            Relative ratio threshold for load imbalance detection.
            System is potentially imbalanced if max_load > min_load * rel_threshold
            Used in conjunction with abs_threshold to determine final imbalance state.

            4. eviction_interval_secs: (integer)
            Interval between LRU eviction cycles for the approximate trees.

            5. max_tree_size: (integer)
            Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
            during the next eviction cycle.
        */
        worker_urls: Arc<RwLock<Vec<String>>>,
        tree: Arc<Mutex<Tree>>,
        running_queue: Arc<Mutex<HashMap<String, usize>>>,
        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        timeout_secs: u64,
        interval_secs: u64,
        _eviction_thread: Option<thread::JoinHandle<()>>,
    },
}

#[derive(Debug, Clone)]
pub enum PolicyConfig {
    RandomConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    RoundRobinConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAwareConfig {
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        eviction_interval_secs: u64,
        max_tree_size: usize,
        timeout_secs: u64,
        interval_secs: u64,
    },
}
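
// A minimal construction sketch (illustrative values, not defaults verified
// here): rebalance when the busiest worker has more than 32 extra pending
// requests and more than 1.0001x the load of the least busy one, evicting
// tree nodes every 120s:
//
// let config = PolicyConfig::CacheAwareConfig {
//     cache_threshold: 0.5,
//     balance_abs_threshold: 32,
//     balance_rel_threshold: 1.0001,
//     eviction_interval_secs: 120,
//     max_tree_size: 1 << 24,
//     timeout_secs: 300,
//     interval_secs: 10,
// };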

impl Router {
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
        // Get timeout and interval from policy config
        let (timeout_secs, interval_secs) = match &policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::CacheAwareConfig {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        // Wait until all workers are healthy
        Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;

        // Create router based on policy...
        Ok(match policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => Router::Random {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => Router::RoundRobin {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                current_index: std::sync::atomic::AtomicUsize::new(0),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::CacheAwareConfig {
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                eviction_interval_secs,
                max_tree_size,
                timeout_secs,
                interval_secs,
            } => {
                let mut running_queue = HashMap::new();
                for url in &worker_urls {
                    running_queue.insert(url.clone(), 0);
                }

                let mut processed_queue = HashMap::new();
                for url in &worker_urls {
                    processed_queue.insert(url.clone(), 0);
                }

                let tree = Arc::new(Mutex::new(Tree::new()));
                let running_queue = Arc::new(Mutex::new(running_queue));
                let processed_queue = Arc::new(Mutex::new(processed_queue));

                // Create background eviction thread
                let tree_clone = Arc::clone(&tree);
                let processed_queue_clone = Arc::clone(&processed_queue);
                let running_queue_clone = Arc::clone(&running_queue);
                let eviction_thread = thread::spawn(move || {
                    loop {
                        // Sleep for the specified interval
                        thread::sleep(Duration::from_secs(eviction_interval_secs));

                        let locked_tree_clone = tree_clone.lock().unwrap();
                        // Run eviction
                        locked_tree_clone.evict_tenant_by_size(max_tree_size);

                        // Print the processed queue
                        let locked_processed_queue = processed_queue_clone.lock().unwrap();
                        info!("Processed Queue: {:?}", locked_processed_queue);

                        // Print the running queue
                        let locked_running_queue = running_queue_clone.lock().unwrap();
                        info!("Running Queue: {:?}", locked_running_queue);
                    }
                });

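                // Seed the tree with an empty prefix per worker so every
                // worker starts out as a known tenant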
                for url in &worker_urls {
                    tree.lock().unwrap().insert(&"".to_string(), url);
                }

                Router::CacheAware {
                    worker_urls: Arc::new(RwLock::new(worker_urls)),
                    tree,
                    running_queue,
                    processed_queue,
                    cache_threshold,
                    balance_abs_threshold,
                    balance_rel_threshold,
                    timeout_secs,
                    interval_secs,
                    _eviction_thread: Some(eviction_thread),
                }
            }
        })
    }
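
    // Usage sketch (hypothetical worker URLs): build a round-robin router,
    // waiting up to 300s for every worker's /health endpoint to pass.
    //
    // let router = Router::new(
    //     vec!["http://worker-0:8000".to_string(), "http://worker-1:8000".to_string()],
    //     PolicyConfig::RoundRobinConfig { timeout_secs: 300, interval_secs: 10 },
    // )?;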

    fn wait_for_healthy_workers(
        worker_urls: &[String],
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<(), String> {
        let start_time = std::time::Instant::now();
        let sync_client = reqwest::blocking::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for workers to become healthy",
                    timeout_secs
                );
                return Err(format!(
                    "Timeout {}s waiting for workers to become healthy",
                    timeout_secs
                ));
            }

            let mut all_healthy = true;
            let mut unhealthy_workers = Vec::new();

            for url in worker_urls {
                match sync_client.get(&format!("{}/health", url)).send() {
                    Ok(res) => {
                        if !res.status().is_success() {
                            info!(
                                "Worker {} health check is pending with status: {}.",
                                url,
                                res.status()
                            );
                            all_healthy = false;
                            unhealthy_workers.push((url, format!("Status: {}", res.status())));
                        }
                    }
                    Err(e) => {
                        info!("Worker {} health check is pending with error: {}", url, e);
                        all_healthy = false;
                        unhealthy_workers.push((url, format!("Error: {}", e)));
                    }
                }
            }

            if all_healthy {
                info!("All workers are healthy");
                return Ok(());
            } else {
                info!("Unhealthy workers:");
                for (url, reason) in &unhealthy_workers {
                    info!("  {} - {}", url, reason);
                }
                thread::sleep(Duration::from_secs(interval_secs));
            }
        }
    }

    fn select_first_worker(&self) -> Result<String, String> {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                if worker_urls.read().unwrap().is_empty() {
                    Err("No workers are available".to_string())
                } else {
                    Ok(worker_urls.read().unwrap()[0].clone())
                }
            }
        }
    }

    async fn send_request(
        &self,
        client: &reqwest::Client,
        worker_url: &str,
        route: &str,
    ) -> HttpResponse {
        match client.get(format!("{}{}", worker_url, route)).send().await {
            Ok(res) => {
                let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

                match res.bytes().await {
                    Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                    Err(e) => HttpResponse::InternalServerError()
                        .body(format!("Failed to read response body: {}", e)),
                }
            }
            Err(e) => HttpResponse::InternalServerError().body(format!(
                "Failed to send request to worker {}: {}",
                worker_url, e
            )),
        }
    }

    pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
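        // Retry policy: up to MAX_REQUEST_RETRIES attempts against one worker;
        // a worker that exhausts its attempts is removed from the pool, and the
        // loop moves on until MAX_TOTAL_RETRIES attempts have failed overall.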
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            match self.select_first_worker() {
                Ok(worker_url) => {
                    let mut request_retries = 0;

                    // Try the same worker multiple times
                    while request_retries < MAX_REQUEST_RETRIES {
                        if total_retries >= 1 {
                            info!("Retrying request after {} failed attempts", total_retries);
                        }

                        let response = self.send_request(client, &worker_url, route).await;

                        if response.status().is_success() {
                            return response;
                        }

                        warn!(
                            "Request to {} failed (attempt {}/{})",
                            worker_url,
                            request_retries + 1,
                            MAX_REQUEST_RETRIES
                        );

                        request_retries += 1;
                        total_retries += 1;

                        if request_retries == MAX_REQUEST_RETRIES {
                            warn!("Removing failed worker: {}", worker_url);
                            self.remove_worker(&worker_url);
                            break;
                        }
                    }
                }
                Err(e) => return HttpResponse::InternalServerError().body(e),
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

    fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
        // Convert the body to JSON; fall back to Null (which yields "") on malformed input
        let json: serde_json::Value = serde_json::from_slice(body).unwrap_or_default();

        if route == "generate" {
            // get the "text" field
            let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
            return text.to_string();
        } else if route == "v1/chat/completions" {
            // get the messages field as raw text
            if let Some(messages) = json.get("messages") {
                // Convert messages back to a string, preserving all JSON formatting
                return serde_json::to_string(messages).unwrap_or_default();
            }
        } else if route == "v1/completions" {
            let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
            return prompt.to_string();
        }

        return "".to_string();
    }
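
    // For example, a "generate" body of {"text": "hi"} yields "hi", a
    // "v1/completions" body of {"prompt": "hi"} yields "hi", and a
    // "v1/chat/completions" body yields its "messages" array re-serialized
    // as a JSON string.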

    // TODO: return Result<String, String> instead of panicking
    fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
        let text = self.get_text_from_request(&body, route);

        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
                ..
            } => {
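                // Atomically advance the shared index, wrapping around the
                // current number of workers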
                let idx = current_index
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.read().unwrap().len()),
                    )
                    .unwrap();
                worker_urls.read().unwrap()[idx].clone()
            }

            Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
                [rand::random::<usize>() % worker_urls.read().unwrap().len()]
            .clone(),

            Router::CacheAware {
                worker_urls,
                tree,
                running_queue,
                processed_queue,
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                ..
            } => {
                // TODO: delay scheduling when the cache hit rate is high, since it may cause imbalance; prioritize low-hit-rate requests

                let tree = tree.lock().unwrap();
                let mut running_queue = running_queue.lock().unwrap();

                // Get current load statistics
                let max_load = *running_queue.values().max().unwrap_or(&0);
                let min_load = *running_queue.values().min().unwrap_or(&0);

                // Load is considered imbalanced if:
                // 1. (max - min) > abs_threshold AND
                // 2. max > rel_threshold * min
                let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
                    && (max_load as f32) > (min_load as f32 * balance_rel_threshold);

                let selected_url = if is_imbalanced {
                    // Log load balancing trigger and current queue state
                    info!(
                        "Load balancing triggered due to workload imbalance:\n\
                        Max load: {}, Min load: {}\n\
                        Current running queue: {:?}",
                        max_load, min_load, running_queue
                    );

                    // Use shortest queue routing when load is imbalanced
                    running_queue
                        .iter()
                        .min_by_key(|(_url, &count)| count)
                        .map(|(url, _)| url.clone())
                        .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
                } else {
                    // Use cache-aware routing when load is balanced
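                    // Note: for empty request text, matched_rate below is NaN,
                    // and NaN > cache_threshold is false, so routing falls
                    // through to the smallest tree.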
                    let (matched_text, matched_worker) = tree.prefix_match(&text);
                    let matched_rate =
                        matched_text.chars().count() as f32 / text.chars().count() as f32;

                    if matched_rate > *cache_threshold {
                        matched_worker.to_string()
                    } else {
                        tree.get_smallest_tenant()
                    }
                };

                // Update queues and tree
                *running_queue.get_mut(&selected_url).unwrap() += 1;

                *processed_queue
                    .lock()
                    .unwrap()
                    .get_mut(&selected_url)
                    .unwrap() += 1;
                tree.insert(&text, &selected_url);

                selected_url
            }
        };

        worker_url
    }

    async fn send_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
        worker_url: &str,
    ) -> HttpResponse {
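        // Treat the request as streaming iff its JSON body has "stream": true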
        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            .unwrap_or(false);

        let res = match client
            .post(format!("{}{}", worker_url, route))
            .header(
                "Content-Type",
                req.headers()
                    .get("Content-Type")
                    .and_then(|h| h.to_str().ok())
                    .unwrap_or("application/json"),
            )
            .body(body.to_vec())
            .send()
            .await
        {
            Ok(res) => res,
            Err(_) => return HttpResponse::InternalServerError().finish(),
        };

        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(e) => {
                    let error_msg = format!("Failed to get response body: {}", e);
                    HttpResponse::InternalServerError().body(error_msg)
                }
            };

            // Then decrement running queue counter if using CacheAware
            if let Router::CacheAware { running_queue, .. } = self {
                if let Ok(mut queue) = running_queue.lock() {
                    if let Some(count) = queue.get_mut(worker_url) {
                        *count = count.saturating_sub(1);
                    }
                }
            }

            response
        } else if let Router::CacheAware { running_queue, .. } = self {
            let running_queue = Arc::clone(running_queue);
            let worker_url = worker_url.to_string();

            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(
                    res.bytes_stream()
                        .map_err(|_| {
                            actix_web::error::ErrorInternalServerError("Failed to read stream")
                        })
                        .inspect(move |bytes| {
                            // Skip error chunks instead of panicking mid-stream
                            let Ok(bytes) = bytes.as_ref() else { return };
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                // The SSE terminator marks this request as done,
                                // so decrement the worker's running counter
                                let mut locked_queue = running_queue.lock().unwrap();
                                if let Some(count) = locked_queue.get_mut(&worker_url) {
                                    *count = count.saturating_sub(1);
                                }
                                debug!("Streaming is done!!")
                            }
                        }),
                )
        } else {
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
                    actix_web::error::ErrorInternalServerError("Failed to read stream")
                }))
        }
    }

    pub async fn route_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
    ) -> HttpResponse {
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            let worker_url = self.select_generate_worker(body, route);
            let mut request_retries = 0;

            // Try the same worker multiple times
            while request_retries < MAX_REQUEST_RETRIES {
                if total_retries >= 1 {
                    info!("Retrying request after {} failed attempts", total_retries);
                }
                let response = self
                    .send_generate_request(client, req, body, route, &worker_url)
                    .await;

                if response.status().is_success() {
                    return response;
                }

                warn!(
                    "Generate request to {} failed (attempt {}/{})",
                    worker_url,
                    request_retries + 1,
                    MAX_REQUEST_RETRIES
                );

                request_retries += 1;
                total_retries += 1;

                if request_retries == MAX_REQUEST_RETRIES {
                    warn!("Removing failed worker: {}", worker_url);
                    self.remove_worker(&worker_url);
                    break;
                }
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

    pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        let (timeout_secs, interval_secs) = match self {
            Router::Random {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::RoundRobin {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::CacheAware {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        let start_time = std::time::Instant::now();
        let client = reqwest::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for worker {} to become healthy",
                    timeout_secs, worker_url
                );
                return Err(format!(
                    "Timeout {}s waiting for worker {} to become healthy",
                    timeout_secs, worker_url
                ));
            }

            match client.get(&format!("{}/health", worker_url)).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        match self {
                            Router::RoundRobin { worker_urls, .. }
                            | Router::Random { worker_urls, .. }
                            | Router::CacheAware { worker_urls, .. } => {
                                info!("Worker {} health check passed", worker_url);
                                let mut urls = worker_urls.write().unwrap();
                                if urls.contains(&worker_url.to_string()) {
                                    return Err(format!("Worker {} already exists", worker_url));
                                }
                                info!("Added worker: {}", worker_url);
                                urls.push(worker_url.to_string());
                            }
                        }

                        // If cache aware, initialize the queues for the new worker
                        if let Router::CacheAware {
                            running_queue,
                            processed_queue,
                            tree,
                            ..
                        } = self
                        {
                            // Add worker to running queue with initial count of 0
                            running_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to processed queue with initial count of 0
                            processed_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to tree
                            tree.lock().unwrap().insert(&"".to_string(), &worker_url);
                        }

                        return Ok(format!("Successfully added worker: {}", worker_url));
                    } else {
                        info!(
                            "Worker {} health check is pending with status: {}.",
                            worker_url,
                            res.status()
                        );
                        // if the url does not have http or https prefix, warn users
                        if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
                        {
                            warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                        }

                        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                        continue;
                    }
                }
                Err(e) => {
                    info!(
                        "Worker {} health check is pending with error: {}",
                        worker_url, e
                    );

                    // if the url does not have http or https prefix, warn users
                    if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
                        warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                    }

                    tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                    continue;
                }
            }
        }
    }

    pub fn remove_worker(&self, worker_url: &str) {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                let mut urls = worker_urls.write().unwrap();
                if let Some(index) = urls.iter().position(|url| url == &worker_url) {
                    urls.remove(index);
                    info!("Removed worker: {}", worker_url);
                } else {
                    warn!("Worker {} not found, skipping removal", worker_url);
                    return;
                }
            }
        }

        // If cache aware, remove the worker from the tree and clean up its queues
        if let Router::CacheAware {
            tree,
            running_queue,
            processed_queue,
            ..
        } = self
        {
            tree.lock().unwrap().remove_tenant(&worker_url);
            running_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
            processed_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
            info!(
                "Removed worker from tree and cleaned up queues: {}",
                worker_url
            );
        }
    }
}
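
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch, added for illustration: with an empty worker list the
    // health check passes vacuously, so construction succeeds and worker
    // selection reports the empty pool.
    #[test]
    fn test_router_with_no_workers() {
        let router = Router::new(
            vec![],
            PolicyConfig::RoundRobinConfig {
                timeout_secs: 1,
                interval_secs: 1,
            },
        )
        .unwrap();
        assert!(router.select_first_worker().is_err());
    }
}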