use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt};
use log::{debug, error, info, warn};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use tokio;

fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
    req.headers()
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|v| (name.to_string(), v.to_string()))
        })
        .collect()
}

#[derive(Debug)]
pub enum Router {
    RoundRobin {
        worker_urls: Arc<RwLock<Vec<String>>>,
        current_index: AtomicUsize,
        timeout_secs: u64,
        interval_secs: u64,
    },
    Random {
        worker_urls: Arc<RwLock<Vec<String>>>,
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAware {
        /*
            Cache-Aware Load Balancing Router

            This router combines two strategies to optimize both cache utilization and request distribution:

            1. Cache-Aware Routing (Approximate Tree)
            2. Load Balancing (Shortest Queue with Balance Thresholds)

            The router dynamically switches between these strategies based on load conditions:
            - Uses load balancing when the system is imbalanced
            - Uses cache-aware routing when the system is balanced

            A system is considered imbalanced if both conditions are met:
            1. (max - min) > abs_threshold
            2. max > rel_threshold * min

            Strategy Details:

            1. Cache-Aware Routing (Approximate Tree)
            -------------------------------------------
            This strategy maintains an approximate radix tree for each worker based on request history,
            eliminating the need for direct cache state queries. The tree stores raw text characters
            instead of token IDs to avoid tokenization overhead.

            Process:
            a. For each request, find the worker with the highest prefix match
            b. If match rate > cache_threshold:
               Route to the worker with highest match (likely has relevant data cached)
            c. If match rate ≤ cache_threshold:
               Route to the worker with smallest tree size (most available cache capacity)
            d. Background maintenance:
               Periodically evict least recently used leaf nodes to prevent memory overflow

            2. Load Balancing (Shortest Queue)
            -------------------------------------------
            This strategy tracks pending request counts per worker and routes new requests
            to the least busy worker when the system is detected to be imbalanced.

            Configuration Parameters:
            ------------------------
            1. cache_threshold: (float, 0.0 to 1.0)
            Minimum prefix match ratio to use highest-match routing.
            Below this threshold, routes to worker with most available cache space.

            2. balance_abs_threshold: (integer)
            Absolute difference threshold for load imbalance detection.
            System is potentially imbalanced if (max_load - min_load) > abs_threshold

            3. balance_rel_threshold: (float)
            Relative ratio threshold for load imbalance detection.
            System is potentially imbalanced if max_load > min_load * rel_threshold
            Used in conjunction with abs_threshold to determine final imbalance state.

            4. eviction_interval_secs: (integer)
            Interval between LRU eviction cycles for the approximate trees.

            5. max_tree_size: (integer)
            Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
            during the next eviction cycle.
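
            Worked example (hypothetical values):
            With balance_abs_threshold = 8 and balance_rel_threshold = 1.5, running
            queues of {A: 12, B: 2} count as imbalanced, since 12 - 2 = 10 > 8 and
            12 > 1.5 * 2, so the request goes to the shortest queue (B). Queues of
            {A: 5, B: 2} stay in cache-aware mode, since 5 - 2 = 3 <= 8; there, with
            cache_threshold = 0.5, a request matching 6 of its 10 characters in a
            worker's tree (rate 0.6) routes to that worker.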
        */
        worker_urls: Arc<RwLock<Vec<String>>>,
        tree: Arc<Mutex<Tree>>,
        running_queue: Arc<Mutex<HashMap<String, usize>>>,
        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        timeout_secs: u64,
        interval_secs: u64,
        _eviction_thread: Option<thread::JoinHandle<()>>,
    },
}

#[derive(Debug, Clone)]
pub enum PolicyConfig {
    RandomConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    RoundRobinConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAwareConfig {
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        eviction_interval_secs: u64,
        max_tree_size: usize,
        timeout_secs: u64,
        interval_secs: u64,
    },
}

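// Illustrative sketch of constructing a cache-aware router via `Router::new`;
// the worker URL and parameter values below are hypothetical, not recommended
// defaults.
//
//     let router = Router::new(
//         vec!["http://127.0.0.1:8000".to_string()],
//         PolicyConfig::CacheAwareConfig {
//             cache_threshold: 0.5,
//             balance_abs_threshold: 32,
//             balance_rel_threshold: 1.5,
//             eviction_interval_secs: 60,
//             max_tree_size: 1 << 24,
//             timeout_secs: 300,
//             interval_secs: 10,
//         },
//     )?;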
impl Router {
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
        // Get timeout and interval from policy config
        let (timeout_secs, interval_secs) = match &policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::CacheAwareConfig {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        // Wait until all workers are healthy
        Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;

        // Create router based on policy config
        Ok(match policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => Router::Random {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => Router::RoundRobin {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                current_index: AtomicUsize::new(0),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::CacheAwareConfig {
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                eviction_interval_secs,
                max_tree_size,
                timeout_secs,
                interval_secs,
            } => {
                let mut running_queue = HashMap::new();
                for url in &worker_urls {
                    running_queue.insert(url.clone(), 0);
                }

                let mut processed_queue = HashMap::new();
                for url in &worker_urls {
                    processed_queue.insert(url.clone(), 0);
                }

                let tree = Arc::new(Mutex::new(Tree::new()));
                let running_queue = Arc::new(Mutex::new(running_queue));
                let processed_queue = Arc::new(Mutex::new(processed_queue));

                // Create background eviction thread
                let tree_clone = Arc::clone(&tree);
                let processed_queue_clone = Arc::clone(&processed_queue);
                let running_queue_clone = Arc::clone(&running_queue);
                let eviction_thread = thread::spawn(move || {
                    loop {
                        // Sleep for the specified interval
                        thread::sleep(Duration::from_secs(eviction_interval_secs));

                        let locked_tree_clone = tree_clone.lock().unwrap();
                        // Run eviction
                        locked_tree_clone.evict_tenant_by_size(max_tree_size);

                        // Print the processed queue
                        let locked_processed_queue = processed_queue_clone.lock().unwrap();
                        info!("Processed Queue: {:?}", locked_processed_queue);

                        // Print the running queue
                        let locked_running_queue = running_queue_clone.lock().unwrap();
                        info!("Running Queue: {:?}", locked_running_queue);
                    }
                });

                for url in &worker_urls {
                    tree.lock().unwrap().insert(&"".to_string(), url);
                }

                Router::CacheAware {
                    worker_urls: Arc::new(RwLock::new(worker_urls)),
                    tree,
                    running_queue,
                    processed_queue,
                    cache_threshold,
                    balance_abs_threshold,
                    balance_rel_threshold,
                    timeout_secs,
                    interval_secs,
                    _eviction_thread: Some(eviction_thread),
                }
            }
        })
    }

    fn wait_for_healthy_workers(
        worker_urls: &[String],
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<(), String> {
        let start_time = std::time::Instant::now();
        let sync_client = reqwest::blocking::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                );
                return Err(format!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                ));
            }

            let mut all_healthy = true;
            let mut unhealthy_workers = Vec::new();

            for url in worker_urls {
                match sync_client.get(&format!("{}/health", url)).send() {
                    Ok(res) => {
                        if !res.status().is_success() {
                            info!(
                                "Worker {} health check is pending with status: {}.",
                                url,
                                res.status()
                            );
                            all_healthy = false;
                            unhealthy_workers.push((url, format!("Status: {}", res.status())));
                        }
                    }
                    Err(e) => {
                        info!("Worker {} health check is pending with error: {}", url, e);
                        all_healthy = false;
                        unhealthy_workers.push((url, format!("Error: {}", e)));
                    }
                }
            }

            if all_healthy {
                info!("All workers are healthy");
                return Ok(());
            } else {
                info!("Unhealthy workers:");
                for (url, reason) in &unhealthy_workers {
                    info!("  {} - {}", url, reason);
                }
                thread::sleep(Duration::from_secs(interval_secs));
            }
        }
    }

    fn select_first_worker(&self) -> Result<String, String> {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                if worker_urls.read().unwrap().is_empty() {
                    Err("No workers are available".to_string())
                } else {
                    Ok(worker_urls.read().unwrap()[0].clone())
                }
            }
        }
    }

    async fn send_request(
        &self,
        client: &reqwest::Client,
        worker_url: &str,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        let mut request_builder = client.get(format!("{}{}", worker_url, route));

        // Copy all headers from the original request, except for /health, which does not need authorization
        if route != "/health" {
            for (name, value) in copy_request_headers(req) {
                request_builder = request_builder.header(name, value);
            }
        }

        match request_builder.send().await {
            Ok(res) => {
                let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

                match res.bytes().await {
                    Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                    Err(e) => HttpResponse::InternalServerError()
                        .body(format!("Failed to read response body: {}", e)),
                }
            }
            Err(e) => HttpResponse::InternalServerError().body(format!(
                "Failed to send request to worker {}: {}",
                worker_url, e
            )),
        }
    }

    pub async fn route_to_first(
        &self,
        client: &reqwest::Client,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
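        // Two-level retry: try a single worker up to MAX_REQUEST_RETRIES times;
        // after that many consecutive failures the worker is removed and another
        // is selected, with MAX_TOTAL_RETRIES bounding attempts across all workers.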
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            match self.select_first_worker() {
                Ok(worker_url) => {
                    let mut request_retries = 0;

                    // Try the same worker multiple times
                    while request_retries < MAX_REQUEST_RETRIES {
                        if total_retries >= 1 {
                            info!("Retrying request after {} failed attempts", total_retries);
                        }

                        let response = self.send_request(client, &worker_url, route, req).await;

                        if response.status().is_success() {
                            return response;
                        } else {
                            // If the worker is still healthy, the failure was caused by the request itself, so return the error response
                            let health_response =
                                self.send_request(client, &worker_url, "/health", req).await;
                            if health_response.status().is_success() {
                                return response;
                            }
                        }

                        warn!(
                            "Request to {} failed (attempt {}/{})",
                            worker_url,
                            request_retries + 1,
                            MAX_REQUEST_RETRIES
                        );

                        request_retries += 1;
                        total_retries += 1;

                        if request_retries == MAX_REQUEST_RETRIES {
                            warn!("Removing failed worker: {}", worker_url);
                            self.remove_worker(&worker_url);
                            break;
                        }
                    }
                }
                Err(e) => return HttpResponse::InternalServerError().body(e),
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

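    // Illustrative request bodies (field values are hypothetical):
    //   /generate            -> {"text": "The capital of France is"}
    //   /v1/chat/completions -> {"messages": [{"role": "user", "content": "Hi"}]}
    //   /v1/completions      -> {"prompt": "Once upon a time"}
    // The extracted text is what gets matched against each worker's radix tree.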
    fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
        // Convert body to JSON
        let json: Value = match serde_json::from_slice(body) {
            Ok(j) => j,
            Err(_) => {
                warn!("Failed to parse JSON from request body.");
                return String::new();
            }
        };

        match route {
            "/generate" => {
                // For /generate, always use the "text" field.
                match json.get("text").and_then(Value::as_str) {
                    Some(text) => text.to_string(),
                    None => {
                        warn!("No 'text' field found in request body for route /generate.");
                        String::new()
                    }
                }
            }
            "/v1/chat/completions" | "/v1/completions" => {
                // For these routes, try "messages", then "prompt", then "text".
                if let Some(messages) = json.get("messages") {
                    serde_json::to_string(messages).unwrap_or_default()
                } else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) {
                    prompt.to_string()
                } else {
                    warn!("Failed to find 'messages', 'prompt' in request body.");
                    String::new()
                }
            }
            _ => {
                warn!("Unknown route: {} - defaulting to fallback string", route);
                String::new()
            }
        }
    }

    // TODO: return Result<String, String> instead of panicking
    fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
        let text = self.get_text_from_request(body, route);

        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
                ..
            } => {
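                // Atomically advance the round-robin cursor; `fetch_update` returns
                // the previous value, which serves as this request's worker index.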
                let idx = current_index
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.read().unwrap().len()),
                    )
                    .unwrap();
                worker_urls.read().unwrap()[idx].clone()
            }

            Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
                [rand::random::<usize>() % worker_urls.read().unwrap().len()]
            .clone(),

            Router::CacheAware {
                worker_urls,
                tree,
                running_queue,
                processed_queue,
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                ..
            } => {
                // TODO: delay scheduling when the cache hit rate is high, since it may cause imbalance; prioritize requests with low hit rates

                let tree = tree.lock().unwrap();
                let mut running_queue = running_queue.lock().unwrap();

                // Get current load statistics
                let max_load = *running_queue.values().max().unwrap_or(&0);
                let min_load = *running_queue.values().min().unwrap_or(&0);

                // Load is considered imbalanced if:
                // 1. (max - min) > abs_threshold AND
                // 2. max > rel_threshold * min
                let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
                    && (max_load as f32) > (min_load as f32 * balance_rel_threshold);

                let selected_url = if is_imbalanced {
                    // Log load balancing trigger and current queue state
                    info!(
                        "Load balancing triggered due to workload imbalance:\n\
                        Max load: {}, Min load: {}\n\
                        Current running queue: {:?}",
                        max_load, min_load, running_queue
                    );

                    // Use shortest queue routing when load is imbalanced
                    running_queue
                        .iter()
                        .min_by_key(|(_url, &count)| count)
                        .map(|(url, _)| url.clone())
                        .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
                } else {
                    // Use cache-aware routing when load is balanced
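                    // Note: for an empty `text`, matched_rate is NaN (0.0 / 0.0), the
                    // threshold comparison below is false, and we fall back to the
                    // smallest tenant.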
                    let (matched_text, matched_worker) = tree.prefix_match(&text);
                    let matched_rate =
                        matched_text.chars().count() as f32 / text.chars().count() as f32;

                    if matched_rate > *cache_threshold {
                        matched_worker.to_string()
                    } else {
                        tree.get_smallest_tenant()
                    }
                };

                // Update queues and tree
                *running_queue.get_mut(&selected_url).unwrap() += 1;

                *processed_queue
                    .lock()
                    .unwrap()
                    .get_mut(&selected_url)
                    .unwrap() += 1;
                tree.insert(&text, &selected_url);

                selected_url
            }
        };

        worker_url
    }

    async fn send_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
        worker_url: &str,
    ) -> HttpResponse {
        let is_stream = serde_json::from_slice::<serde_json::Value>(body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            .unwrap_or(false);

        let mut request_builder = client
            .post(format!("{}{}", worker_url, route))
            .body(body.to_vec());

        // Copy all headers from original request
        for (name, value) in copy_request_headers(req) {
            request_builder = request_builder.header(name, value);
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
                // The request never reached the worker, so roll back the running
                // queue increment from worker selection to avoid a counter leak.
                if let Router::CacheAware { running_queue, .. } = self {
                    if let Ok(mut queue) = running_queue.lock() {
                        if let Some(count) = queue.get_mut(worker_url) {
                            *count = count.saturating_sub(1);
                        }
                    }
                }
                return HttpResponse::InternalServerError().body(format!(
                    "Failed to send request to worker {}: {}",
                    worker_url, e
                ));
            }
        };

        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(e) => {
                    let error_msg = format!("Failed to get response body: {}", e);
                    HttpResponse::InternalServerError().body(error_msg)
                }
            };

            // Then decrement running queue counter if using CacheAware
            if let Router::CacheAware { running_queue, .. } = self {
                if let Ok(mut queue) = running_queue.lock() {
                    if let Some(count) = queue.get_mut(worker_url) {
                        *count = count.saturating_sub(1);
                    }
                }
            }

            response
        } else if let Router::CacheAware { running_queue, .. } = self {
            let running_queue = Arc::clone(running_queue);
            let worker_url = worker_url.to_string();

            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(
                    res.bytes_stream()
                        .map_err(|_| {
                            actix_web::error::ErrorInternalServerError("Failed to read stream")
                        })
                        .inspect(move |bytes| {
                            // Watch for the SSE terminator so the running-queue counter
                            // is decremented when the stream completes; skip Err chunks
                            // instead of panicking on unwrap.
                            if let Ok(bytes) = bytes.as_ref() {
                                if bytes
                                    .as_ref()
                                    .windows(12)
                                    .any(|window| window == b"data: [DONE]")
                                {
                                    let mut locked_queue = running_queue.lock().unwrap();
                                    if let Some(count) = locked_queue.get_mut(&worker_url) {
                                        *count = count.saturating_sub(1);
                                    }
                                    debug!("Streaming is done");
                                }
                            }
                        }),
                )
        } else {
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
                    actix_web::error::ErrorInternalServerError("Failed to read stream")
                }))
        }
    }

    pub async fn route_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
    ) -> HttpResponse {
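        // Same two-level retry scheme as route_to_first: MAX_REQUEST_RETRIES
        // attempts per worker, MAX_TOTAL_RETRIES attempts in total.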
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            let worker_url = self.select_generate_worker(body, route);
            let mut request_retries = 0;

            // Try the same worker multiple times
            while request_retries < MAX_REQUEST_RETRIES {
                if total_retries >= 1 {
                    info!("Retrying request after {} failed attempts", total_retries);
                }
                let response = self
                    .send_generate_request(client, req, body, route, &worker_url)
                    .await;

                if response.status().is_success() {
                    return response;
                } else {
                    // If the worker is still healthy, the failure was caused by the request itself, so return the error response
                    let health_response =
                        self.send_request(client, &worker_url, "/health", req).await;
                    if health_response.status().is_success() {
                        return response;
                    }
                }

                warn!(
                    "Generate request to {} failed (attempt {}/{})",
                    worker_url,
                    request_retries + 1,
                    MAX_REQUEST_RETRIES
                );

                request_retries += 1;
                total_retries += 1;

                if request_retries == MAX_REQUEST_RETRIES {
                    warn!("Removing failed worker: {}", worker_url);
                    self.remove_worker(&worker_url);
                    break;
                }
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

    pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        let (timeout_secs, interval_secs) = match self {
            Router::Random {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::RoundRobin {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::CacheAware {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        let start_time = std::time::Instant::now();
        let client = reqwest::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_url
                );
                return Err(format!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_url
                ));
            }

            match client.get(&format!("{}/health", worker_url)).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        match self {
                            Router::RoundRobin { worker_urls, .. }
                            | Router::Random { worker_urls, .. }
                            | Router::CacheAware { worker_urls, .. } => {
                                info!("Worker {} health check passed", worker_url);
                                let mut urls = worker_urls.write().unwrap();
                                if urls.contains(&worker_url.to_string()) {
                                    return Err(format!("Worker {} already exists", worker_url));
                                }
                                info!("Added worker: {}", worker_url);
                                urls.push(worker_url.to_string());
                            }
                        }

                        // If cache aware, initialize the queues for the new worker
                        if let Router::CacheAware {
                            running_queue,
                            processed_queue,
                            tree,
                            ..
                        } = self
                        {
                            // Add worker to running queue with initial count of 0
                            running_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to processed queue with initial count of 0
                            processed_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to tree
                            tree.lock().unwrap().insert(&"".to_string(), &worker_url);
                        }

                        return Ok(format!("Successfully added worker: {}", worker_url));
                    } else {
                        info!(
                            "Worker {} health check is pending with status: {}.",
                            worker_url,
                            res.status()
                        );
                        // Warn users if the URL is missing an http:// or https:// prefix
                        if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
                        {
                            warn!("The worker URL {} does not have an http:// or https:// prefix. Please add the prefix to the URL.", worker_url);
                        }

                        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                        continue;
                    }
                }
                Err(e) => {
                    info!(
                        "Worker {} health check is pending with error: {}",
                        worker_url, e
                    );

                    // Warn users if the URL is missing an http:// or https:// prefix
                    if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
                        warn!("The worker URL {} does not have an http:// or https:// prefix. Please add the prefix to the URL.", worker_url);
                    }

                    tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                    continue;
                }
            }
        }
    }

    pub fn remove_worker(&self, worker_url: &str) {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                let mut urls = worker_urls.write().unwrap();
                if let Some(index) = urls.iter().position(|url| url == &worker_url) {
                    urls.remove(index);
                    info!("Removed worker: {}", worker_url);
                } else {
                    warn!("Worker {} not found, skipping removal", worker_url);
                    return;
                }
            }
        }

        // if cache aware, remove the worker from the tree
        if let Router::CacheAware {
            tree,
            running_queue,
            processed_queue,
            ..
        } = self
        {
            tree.lock().unwrap().remove_tenant(&worker_url);
            running_queue.lock().unwrap().remove(worker_url);
            processed_queue.lock().unwrap().remove(worker_url);
            info!(
                "Removed worker from tree and cleaned up queues: {}",
                worker_url
            );
        }
    }
}