"ppdet/slim/quant.py" did not exist on "dcc7bf4f1a243d90d6c4f7c51551cea3f256325f"
router.rs 32.1 KB
Newer Older
1
use crate::tree::Tree;
2
3
4
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
Byron Hsu's avatar
Byron Hsu committed
5
use futures_util::{StreamExt, TryStreamExt};
6
use log::{debug, error, info, warn};
7
use serde_json::Value;
8
use std::collections::HashMap;
9
use std::fmt::Debug;
10
use std::sync::atomic::AtomicUsize;
11
use std::sync::{Arc, Mutex, RwLock};
12
13
use std::thread;
use std::time::Duration;
14
use tokio;
15

fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
    req.headers()
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|v| (name.to_string(), v.to_string()))
        })
        .collect()
}

#[derive(Debug)]
pub enum Router {
    RoundRobin {
        worker_urls: Arc<RwLock<Vec<String>>>,
        current_index: AtomicUsize,
        timeout_secs: u64,
        interval_secs: u64,
    },
    Random {
        worker_urls: Arc<RwLock<Vec<String>>>,
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAware {
        /*
            Cache-Aware Load Balancing Router

            This router combines two strategies to optimize both cache utilization and request distribution:

            1. Cache-Aware Routing (Approximate Tree)
            2. Load Balancing (Shortest Queue with Balance Thresholds)

            The router dynamically switches between these strategies based on load conditions:
            - Uses load balancing when the system is imbalanced
            - Uses cache-aware routing when the system is balanced

            A system is considered imbalanced if both conditions are met:
            1. (max - min) > abs_threshold
            2. max > rel_threshold * min

            Strategy Details:

            1. Cache-Aware Routing (Approximate Tree)
            -------------------------------------------
            This strategy maintains an approximate radix tree for each worker based on request history,
            eliminating the need for direct cache state queries. The tree stores raw text characters
            instead of token IDs to avoid tokenization overhead.

            Process:
            a. For each request, find the worker with the highest prefix match
            b. If match rate > cache_threshold:
            Route to the worker with highest match (likely has relevant data cached)
            c. If match rate ≤ cache_threshold:
            Route to the worker with smallest tree size (most available cache capacity)
            d. Background maintenance:
            Periodically evict least recently used leaf nodes to prevent memory overflow

            2. Load Balancing (Shortest Queue)
            -------------------------------------------
            This strategy tracks pending request counts per worker and routes new requests
            to the least busy worker when the system is detected to be imbalanced.

            Configuration Parameters:
            ------------------------
            1. cache_threshold: (float, 0.0 to 1.0)
            Minimum prefix match ratio to use highest-match routing.
            Below this threshold, routes to worker with most available cache space.

            2. balance_abs_threshold: (integer)
            Absolute difference threshold for load imbalance detection.
            System is potentially imbalanced if (max_load - min_load) > abs_threshold

            3. balance_rel_threshold: (float)
            Relative ratio threshold for load imbalance detection.
            System is potentially imbalanced if max_load > min_load * rel_threshold
            Used in conjunction with abs_threshold to determine final imbalance state.

            4. eviction_interval_secs: (integer)
            Interval between LRU eviction cycles for the approximate trees.

            5. max_tree_size: (integer)
            Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
            during the next eviction cycle.
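
            Worked example (illustrative numbers, not defaults of this router):
            With balance_abs_threshold = 8 and balance_rel_threshold = 1.5, worker
            loads of max = 20 and min = 5 are treated as imbalanced, because
            20 - 5 = 15 > 8 and 20 > 1.5 * 5 = 7.5, so shortest-queue routing is used.
            With loads of max = 12 and min = 10 the system stays balanced
            (12 - 10 = 2 <= 8), so cache-aware routing applies: with
            cache_threshold = 0.5, a request whose longest tree prefix covers 60% of
            its text goes to the matching worker, while a 30% match goes to the
            worker whose tree is currently smallest.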
        */
        worker_urls: Arc<RwLock<Vec<String>>>,
        tree: Arc<Mutex<Tree>>,
        running_queue: Arc<Mutex<HashMap<String, usize>>>,
        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        timeout_secs: u64,
        interval_secs: u64,
        _eviction_thread: Option<thread::JoinHandle<()>>,
    },
}

#[derive(Debug, Clone)]
pub enum PolicyConfig {
    RandomConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    RoundRobinConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAwareConfig {
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        eviction_interval_secs: u64,
        max_tree_size: usize,
        timeout_secs: u64,
        interval_secs: u64,
    },
}
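
// Illustrative sketch only: one way the cache-aware policy could be configured.
// The numeric values below are assumptions chosen for demonstration, not defaults
// shipped with this router.
//
//     let policy = PolicyConfig::CacheAwareConfig {
//         cache_threshold: 0.5,        // prefix-match ratio needed for cache routing
//         balance_abs_threshold: 8,    // max-min load gap before balancing may kick in
//         balance_rel_threshold: 1.5,  // max load must also exceed 1.5x the min load
//         eviction_interval_secs: 60,  // LRU eviction pass for the approximate trees
//         max_tree_size: 1_000_000,    // node budget per tree before eviction
//         timeout_secs: 300,           // worker startup health-check timeout
//         interval_secs: 10,           // health-check polling interval
//     };
//     let router = Router::new(worker_urls, policy)?;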

impl Router {
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
        // Get timeout and interval from policy config
        let (timeout_secs, interval_secs) = match &policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::CacheAwareConfig {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        // Wait until all workers are healthy
        Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;

        // Create router based on policy...
        Ok(match policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => Router::Random {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => Router::RoundRobin {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                current_index: std::sync::atomic::AtomicUsize::new(0),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::CacheAwareConfig {
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                eviction_interval_secs,
                max_tree_size,
                timeout_secs,
                interval_secs,
            } => {
                let mut running_queue = HashMap::new();
                for url in &worker_urls {
                    running_queue.insert(url.clone(), 0);
                }

                let mut processed_queue = HashMap::new();
                for url in &worker_urls {
                    processed_queue.insert(url.clone(), 0);
                }

                let tree = Arc::new(Mutex::new(Tree::new()));
                let running_queue = Arc::new(Mutex::new(running_queue));
                let processed_queue = Arc::new(Mutex::new(processed_queue));

                // Create background eviction thread
                let tree_clone = Arc::clone(&tree);
                let processed_queue_clone = Arc::clone(&processed_queue);
                let running_queue_clone = Arc::clone(&running_queue);
                let eviction_thread = thread::spawn(move || {
                    loop {
                        // Sleep for the specified interval
                        thread::sleep(Duration::from_secs(eviction_interval_secs));

                        let locked_tree_clone = tree_clone.lock().unwrap();
                        // Run eviction
                        locked_tree_clone.evict_tenant_by_size(max_tree_size);

                        // Print the process queue
                        let locked_processed_queue = processed_queue_clone.lock().unwrap();
                        info!("Processed Queue: {:?}", locked_processed_queue);

                        // Print the running queue
                        let locked_running_queue = running_queue_clone.lock().unwrap();
                        info!("Running Queue: {:?}", locked_running_queue);
                    }
                });

                for url in &worker_urls {
                    tree.lock().unwrap().insert(&"".to_string(), url);
                }

                Router::CacheAware {
                    worker_urls: Arc::new(RwLock::new(worker_urls)),
                    tree,
                    running_queue,
                    processed_queue,
                    cache_threshold,
                    balance_abs_threshold,
                    balance_rel_threshold,
                    timeout_secs,
                    interval_secs,
                    _eviction_thread: Some(eviction_thread),
                }
            }
        })
    }

    fn wait_for_healthy_workers(
        worker_urls: &[String],
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<(), String> {
        let start_time = std::time::Instant::now();
        let sync_client = reqwest::blocking::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                );
                return Err(format!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                ));
            }

            let mut all_healthy = true;
            let mut unhealthy_workers = Vec::new();

            for url in worker_urls {
                match sync_client.get(&format!("{}/health", url)).send() {
                    Ok(res) => {
                        if !res.status().is_success() {
                            let msg = format!(
                                "Worker health check is pending with status {}",
                                res.status()
                            );
                            info!("{}", msg);
                            all_healthy = false;
                            unhealthy_workers.push((url, msg));
                        }
                    }
                    Err(_) => {
                        let msg = "Worker is not ready yet".to_string();
                        info!("{}", msg);
                        all_healthy = false;
                        unhealthy_workers.push((url, msg));
                    }
                }
            }

            if all_healthy {
                info!("All workers are healthy");
                return Ok(());
            } else {
                info!("Initializing workers:");
                for (url, reason) in &unhealthy_workers {
                    info!("  {} - {}", url, reason);
                }
                thread::sleep(Duration::from_secs(interval_secs));
            }
        }
    }

    fn select_first_worker(&self) -> Result<String, String> {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                if worker_urls.read().unwrap().is_empty() {
                    Err("No workers are available".to_string())
                } else {
                    Ok(worker_urls.read().unwrap()[0].clone())
                }
            }
        }
    }

    async fn send_request(
        &self,
        client: &reqwest::Client,
        worker_url: &str,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        let mut request_builder = client.get(format!("{}{}", worker_url, route));

        // Copy all headers from the original request, except for /health, which does not require authorization
        if route != "/health" {
            for (name, value) in copy_request_headers(req) {
                request_builder = request_builder.header(name, value);
            }
        }

        match request_builder.send().await {
            Ok(res) => {
                let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

                match res.bytes().await {
                    Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                    Err(e) => HttpResponse::InternalServerError()
                        .body(format!("Failed to read response body: {}", e)),
                }
            }
            Err(e) => HttpResponse::InternalServerError().body(format!(
                "Failed to send request to worker {}: {}",
                worker_url, e
            )),
        }
    }

    pub async fn route_to_first(
        &self,
        client: &reqwest::Client,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            match self.select_first_worker() {
                Ok(worker_url) => {
                    let mut request_retries = 0;

                    // Try the same worker multiple times
                    while request_retries < MAX_REQUEST_RETRIES {
                        if total_retries >= 1 {
                            info!("Retrying request after {} failed attempts", total_retries);
                        }

                        let response = self.send_request(client, &worker_url, route, req).await;

                        if response.status().is_success() {
                            return response;
                        } else {
                            // if the worker is healthy, it means the request is bad, so return the error response
                            let health_response =
                                self.send_request(client, &worker_url, "/health", req).await;
                            if health_response.status().is_success() {
                                return response;
                            }
                        }

                        warn!(
                            "Request to {} failed (attempt {}/{})",
                            worker_url,
                            request_retries + 1,
                            MAX_REQUEST_RETRIES
                        );

                        request_retries += 1;
                        total_retries += 1;

                        if request_retries == MAX_REQUEST_RETRIES {
                            warn!("Removing failed worker: {}", worker_url);
                            self.remove_worker(&worker_url);
                            break;
                        }
                    }
                }
                Err(e) => return HttpResponse::InternalServerError().body(e),
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

    fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
        // Convert body to JSON
        let json: Value = match serde_json::from_slice(body) {
            Ok(j) => j,
            Err(_) => {
                warn!("Failed to parse JSON from request body.");
                return String::new();
            }
        };

        match route {
            "/generate" => {
                // For /generate, always use the "text" field.
                match json.get("text").and_then(Value::as_str) {
                    Some(text) => text.to_string(),
                    None => {
                        warn!("No 'text' field found in request body for route /generate.");
                        String::new()
                    }
                }
            }
            "/v1/chat/completions" | "/v1/completions" => {
                // For these routes, try "messages" first, then fall back to "prompt".
                if let Some(messages) = json.get("messages") {
                    serde_json::to_string(messages).unwrap_or_default()
                } else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) {
                    prompt.to_string()
                } else {
                    warn!("Failed to find 'messages', 'prompt' in request body.");
                    String::new()
                }
            }
            _ => {
                warn!("Unknown route: {} - defaulting to fallback string", route);
                String::new()
            }
        }
    }

    // TODO: return Result<String, String> instead of panicking
    fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
        let text = self.get_text_from_request(&body, route);

        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
                ..
            } => {
                let idx = current_index
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.read().unwrap().len()),
                    )
                    .unwrap();
                worker_urls.read().unwrap()[idx].clone()
            }

            Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
                [rand::random::<usize>() % worker_urls.read().unwrap().len()]
            .clone(),

            Router::CacheAware {
                worker_urls,
                tree,
                running_queue,
                processed_queue,
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                ..
            } => {
                // TODO: delay scheduling when the cache hit rate is high, since it may cause imbalance; prioritize low-hit-rate requests

                let tree = tree.lock().unwrap();
                let mut running_queue = running_queue.lock().unwrap();

                // Get current load statistics
                let max_load = *running_queue.values().max().unwrap_or(&0);
                let min_load = *running_queue.values().min().unwrap_or(&0);

                // Load is considered imbalanced if:
                // 1. (max - min) > abs_threshold AND
                // 2. max > rel_threshold * min
                let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
                    && (max_load as f32) > (min_load as f32 * balance_rel_threshold);

                let selected_url = if is_imbalanced {
                    // Log load balancing trigger and current queue state
                    info!(
                        "Load balancing triggered due to workload imbalance:\n\
                        Max load: {}, Min load: {}\n\
                        Current running queue: {:?}",
                        max_load, min_load, running_queue
                    );

                    // Use shortest queue routing when load is imbalanced
                    running_queue
                        .iter()
                        .min_by_key(|(_url, &count)| count)
                        .map(|(url, _)| url.clone())
                        .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
                } else {
                    // Use cache-aware routing when load is balanced
                    let (matched_text, matched_worker) = tree.prefix_match(&text);
                    let matched_rate =
                        matched_text.chars().count() as f32 / text.chars().count() as f32;

                    if matched_rate > *cache_threshold {
                        matched_worker.to_string()
                    } else {
                        tree.get_smallest_tenant()
                    }
                };

                // Update queues and tree
                *running_queue.get_mut(&selected_url).unwrap() += 1;

                *processed_queue
                    .lock()
                    .unwrap()
                    .get_mut(&selected_url)
                    .unwrap() += 1;
                tree.insert(&text, &selected_url);

                selected_url
            }
        };

        worker_url
    }

    async fn send_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
        worker_url: &str,
    ) -> HttpResponse {
        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            .unwrap_or(false);

        let mut request_builder = client
            .post(format!("{}{}", worker_url, route))
            .body(body.to_vec());

        // Copy all headers from original request
        for (name, value) in copy_request_headers(req) {
            request_builder = request_builder.header(name, value);
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(_) => return HttpResponse::InternalServerError().finish(),
        };

        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(e) => {
                    let error_msg = format!("Failed to get response body: {}", e);
                    HttpResponse::InternalServerError().body(error_msg)
                }
            };

            // Then decrement running queue counter if using CacheAware
            if let Router::CacheAware { running_queue, .. } = self {
                if let Ok(mut queue) = running_queue.lock() {
                    if let Some(count) = queue.get_mut(worker_url) {
                        *count = count.saturating_sub(1);
                    }
                }
            }

            response
        } else if let Router::CacheAware { running_queue, .. } = self {
            let running_queue = Arc::clone(running_queue);
            let worker_url = worker_url.to_string();

            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(
                    res.bytes_stream()
                        .map_err(|_| {
                            actix_web::error::ErrorInternalServerError("Failed to read stream")
                        })
                        .inspect(move |bytes| {
                            // Once the SSE stream emits the "data: [DONE]" marker,
                            // decrement this worker's running-request counter.
                            if let Ok(bytes) = bytes {
                                if bytes
                                    .as_ref()
                                    .windows(12)
                                    .any(|window| window == b"data: [DONE]")
                                {
                                    let mut locked_queue = running_queue.lock().unwrap();
                                    if let Some(count) = locked_queue.get_mut(&worker_url) {
                                        *count = count.saturating_sub(1);
                                    }
                                    debug!("Streaming is done!!");
                                }
                            }
                        }),
                )
        } else {
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
                    actix_web::error::ErrorInternalServerError("Failed to read stream")
                }))
        }
    }

    pub async fn route_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
    ) -> HttpResponse {
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            let worker_url = self.select_generate_worker(body, route);
            let mut request_retries = 0;

            // Try the same worker multiple times
            while request_retries < MAX_REQUEST_RETRIES {
                if total_retries >= 1 {
                    info!("Retrying request after {} failed attempts", total_retries);
                }
                let response = self
                    .send_generate_request(client, req, body, route, &worker_url)
                    .await;

                if response.status().is_success() {
                    return response;
                } else {
                    // if the worker is healthy, it means the request is bad, so return the error response
                    let health_response =
                        self.send_request(client, &worker_url, "/health", req).await;
                    if health_response.status().is_success() {
                        return response;
                    }
                }

                warn!(
                    "Generate request to {} failed (attempt {}/{})",
                    worker_url,
                    request_retries + 1,
                    MAX_REQUEST_RETRIES
                );

                request_retries += 1;
                total_retries += 1;

                if request_retries == MAX_REQUEST_RETRIES {
                    warn!("Removing failed worker: {}", worker_url);
                    self.remove_worker(&worker_url);
                    break;
                }
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

    pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        let (timeout_secs, interval_secs) = match self {
            Router::Random {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::RoundRobin {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::CacheAware {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        let start_time = std::time::Instant::now();
        let client = reqwest::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_url
                );
                return Err(format!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_url
                ));
            }

            match client.get(&format!("{}/health", worker_url)).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        match self {
                            Router::RoundRobin { worker_urls, .. }
                            | Router::Random { worker_urls, .. }
                            | Router::CacheAware { worker_urls, .. } => {
                                info!("Worker {} health check passed", worker_url);
                                let mut urls = worker_urls.write().unwrap();
                                if urls.contains(&worker_url.to_string()) {
                                    return Err(format!("Worker {} already exists", worker_url));
                                }
                                info!("Added worker: {}", worker_url);
                                urls.push(worker_url.to_string());
                            }
                        }

                        // If cache aware, initialize the queues for the new worker
                        if let Router::CacheAware {
                            running_queue,
                            processed_queue,
                            tree,
                            ..
                        } = self
                        {
                            // Add worker to running queue with initial count of 0
                            running_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to processed queue with initial count of 0
                            processed_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to tree
                            tree.lock().unwrap().insert(&"".to_string(), &worker_url);
                        }

                        return Ok(format!("Successfully added worker: {}", worker_url));
                    } else {
                        info!(
                            "Worker {} health check is pending with status: {}.",
                            worker_url,
                            res.status()
                        );
                        // if the url does not have http or https prefix, warn users
                        if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
                        {
                            warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                        }

                        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                        continue;
                    }
                }
                Err(e) => {
                    info!(
                        "Worker {} health check is pending with error: {}",
                        worker_url, e
                    );

                    // if the url does not have http or https prefix, warn users
                    if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
                        warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                    }

                    tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                    continue;
                }
            }
        }
    }

    pub fn remove_worker(&self, worker_url: &str) {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                let mut urls = worker_urls.write().unwrap();
                if let Some(index) = urls.iter().position(|url| url == &worker_url) {
                    urls.remove(index);
                    info!("Removed worker: {}", worker_url);
                } else {
                    warn!("Worker {} not found, skipping removal", worker_url);
                    return;
                }
            }
        }

        // if cache aware, remove the worker from the tree
        if let Router::CacheAware {
            tree,
            running_queue,
            processed_queue,
            ..
        } = self
        {
            tree.lock().unwrap().remove_tenant(&worker_url);
            running_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
            processed_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
            info!(
                "Removed worker from tree and cleaned up queues: {}",
                worker_url
            );
        }
    }
}
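
// The tests below are an illustrative sketch added for documentation purposes, not
// part of the upstream router. They exercise two self-contained pieces of logic
// described above: the load-imbalance rule from the CacheAware doc comment, and the
// request-text extraction used for cache-aware routing. All numeric values are
// assumed examples, not defaults of this crate.
#[cfg(test)]
mod example_tests {
    use super::*;

    #[test]
    fn imbalance_rule_example() {
        // Same arithmetic as in select_generate_worker: imbalanced only if
        // (max - min) > abs_threshold AND max > rel_threshold * min.
        let (max_load, min_load) = (20usize, 5usize);
        let (abs_threshold, rel_threshold) = (8usize, 1.5f32);
        let is_imbalanced = max_load.saturating_sub(min_load) > abs_threshold
            && (max_load as f32) > (min_load as f32 * rel_threshold);
        assert!(is_imbalanced);

        // A small absolute gap does not trigger balancing even if the ratio holds.
        let (max_load, min_load) = (12usize, 10usize);
        let is_imbalanced = max_load.saturating_sub(min_load) > abs_threshold
            && (max_load as f32) > (min_load as f32 * rel_threshold);
        assert!(!is_imbalanced);
    }

    #[test]
    fn get_text_from_request_example() {
        // Construct a Random router directly so no health check or network is needed.
        let router = Router::Random {
            worker_urls: Arc::new(RwLock::new(vec!["http://localhost:8000".to_string()])),
            timeout_secs: 1,
            interval_secs: 1,
        };

        // /generate reads the "text" field.
        let body = Bytes::from(r#"{"text": "hello world"}"#);
        assert_eq!(router.get_text_from_request(&body, "/generate"), "hello world");

        // /v1/completions falls back to "prompt" when "messages" is absent.
        let body = Bytes::from(r#"{"prompt": "hi"}"#);
        assert_eq!(router.get_text_from_request(&body, "/v1/completions"), "hi");
    }
}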