use crate::tree::Tree;

use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use tokio;
use tracing::{debug, error, info, warn};

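// Collect the incoming request's headers as (name, value) string pairs,
// skipping any header whose value is not valid UTF-8.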
fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
    req.headers()
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|v| (name.to_string(), v.to_string()))
        })
        .collect()
}

#[derive(Debug)]
pub enum Router {
    RoundRobin {
        worker_urls: Arc<RwLock<Vec<String>>>,
        current_index: AtomicUsize,
        timeout_secs: u64,
        interval_secs: u64,
    },
    Random {
        worker_urls: Arc<RwLock<Vec<String>>>,
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAware {
        /*
            Cache-Aware Load Balancing Router

            This router combines two strategies to optimize both cache utilization and request distribution:

            1. Cache-Aware Routing (Approximate Tree)
            2. Load Balancing (Shortest Queue with Balance Thresholds)

            The router dynamically switches between these strategies based on load conditions:
            - Uses load balancing when the system is imbalanced
            - Uses cache-aware routing when the system is balanced

            A system is considered imbalanced if both conditions are met:
            1. (max - min) > abs_threshold
            2. max > rel_threshold * min
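
            For illustration (hypothetical numbers, not defaults): with
            balance_abs_threshold = 32 and balance_rel_threshold = 1.5, running
            counts of {A: 50, B: 5} give max - min = 45 > 32 and 50 > 1.5 * 5,
            so the system is treated as imbalanced and the request is sent to
            the shortest queue (worker B).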

            Strategy Details:

            1. Cache-Aware Routing (Approximate Tree)
            -------------------------------------------
            This strategy maintains an approximate radix tree for each worker based on request history,
            eliminating the need for direct cache state queries. The tree stores raw text characters
            instead of token IDs to avoid tokenization overhead.

            Process:
            a. For each request, find the worker with the highest prefix match
            b. If match rate > cache_threshold:
            Route to the worker with highest match (likely has relevant data cached)
            c. If match rate ≤ cache_threshold:
            Route to the worker with smallest tree size (most available cache capacity)
            d. Background maintenance:
            Periodically evict least recently used leaf nodes to prevent memory overflow
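
            Worked example (hypothetical numbers): for a 100-character request
            whose longest matching prefix on any worker's tree is 60 characters,
            the match rate is 0.6; with cache_threshold = 0.5 the request is
            routed to that worker, otherwise to the worker with the smallest tree.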

            2. Load Balancing (Shortest Queue)
            -------------------------------------------
            This strategy tracks pending request counts per worker and routes new requests
            to the least busy worker when the system is detected to be imbalanced.
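
            For example (hypothetical counts), with pending requests of
            {A: 9, B: 2, C: 5} the next request would be routed to worker B.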

            Configuration Parameters:
            ------------------------
            1. cache_threshold: (float, 0.0 to 1.0)
            Minimum prefix match ratio to use highest-match routing.
            Below this threshold, routes to worker with most available cache space.

            2. balance_abs_threshold: (integer)
            Absolute difference threshold for load imbalance detection.
            System is potentially imbalanced if (max_load - min_load) > abs_threshold

            3. balance_rel_threshold: (float)
            Relative ratio threshold for load imbalance detection.
            System is potentially imbalanced if max_load > min_load * rel_threshold
            Used in conjunction with abs_threshold to determine final imbalance state.

            4. eviction_interval_secs: (integer)
            Interval between LRU eviction cycles for the approximate trees.

            5. max_tree_size: (integer)
            Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
            during the next eviction cycle.
        */
        worker_urls: Arc<RwLock<Vec<String>>>,
        tree: Arc<Mutex<Tree>>,
        running_queue: Arc<Mutex<HashMap<String, usize>>>,
        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        timeout_secs: u64,
        interval_secs: u64,
        _eviction_thread: Option<thread::JoinHandle<()>>,
    },
}

#[derive(Debug, Clone)]
pub enum PolicyConfig {
    RandomConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    RoundRobinConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAwareConfig {
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        eviction_interval_secs: u64,
        max_tree_size: usize,
        timeout_secs: u64,
        interval_secs: u64,
    },
}
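
// A minimal usage sketch (the values below are purely illustrative, not defaults):
//
//     let config = PolicyConfig::CacheAwareConfig {
//         cache_threshold: 0.5,
//         balance_abs_threshold: 32,
//         balance_rel_threshold: 1.5,
//         eviction_interval_secs: 60,
//         max_tree_size: 1 << 24,
//         timeout_secs: 300,
//         interval_secs: 10,
//     };
//     let router = Router::new(vec!["http://127.0.0.1:8000".to_string()], config)?;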

impl Router {
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
        // Get timeout and interval from policy config
        let (timeout_secs, interval_secs) = match &policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::CacheAwareConfig {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        // Wait until all workers are healthy
        Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;

        // Create router based on policy...
        Ok(match policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => Router::Random {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => Router::RoundRobin {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                current_index: std::sync::atomic::AtomicUsize::new(0),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::CacheAwareConfig {
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                eviction_interval_secs,
                max_tree_size,
                timeout_secs,
                interval_secs,
            } => {
                let mut running_queue = HashMap::new();
                for url in &worker_urls {
                    running_queue.insert(url.clone(), 0);
                }

                let mut processed_queue = HashMap::new();
                for url in &worker_urls {
                    processed_queue.insert(url.clone(), 0);
                }

                let tree = Arc::new(Mutex::new(Tree::new()));
                let running_queue = Arc::new(Mutex::new(running_queue));
                let processed_queue = Arc::new(Mutex::new(processed_queue));

                // Create background eviction thread
                let tree_clone = Arc::clone(&tree);
                let processed_queue_clone = Arc::clone(&processed_queue);
                let running_queue_clone = Arc::clone(&running_queue);
                let eviction_thread = thread::spawn(move || {
                    loop {
                        // Sleep for the specified interval
                        thread::sleep(Duration::from_secs(eviction_interval_secs));

                        let locked_tree_clone = tree_clone.lock().unwrap();
                        // Run eviction
                        locked_tree_clone.evict_tenant_by_size(max_tree_size);

                        // Print the process queue
                        let locked_processed_queue = processed_queue_clone.lock().unwrap();
                        info!("Processed Queue: {:?}", locked_processed_queue);

                        // Print the running queue
                        let locked_running_queue = running_queue_clone.lock().unwrap();
                        info!("Running Queue: {:?}", locked_running_queue);
                    }
                });

                for url in &worker_urls {
                    tree.lock().unwrap().insert(&"".to_string(), url);
                }

                Router::CacheAware {
                    worker_urls: Arc::new(RwLock::new(worker_urls)),
                    tree,
                    running_queue,
                    processed_queue,
                    cache_threshold,
                    balance_abs_threshold,
                    balance_rel_threshold,
                    timeout_secs,
                    interval_secs,
                    _eviction_thread: Some(eviction_thread),
                }
            }
        })
    }

    /// Get a reference to the worker URLs shared across threads
    pub fn get_worker_urls(&self) -> Arc<RwLock<Vec<String>>> {
        match self {
            Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
            Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
            Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
        }
    }

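    // Poll each worker's /health endpoint every `interval_secs` seconds until
    // all respond successfully, giving up with an error after `timeout_secs`.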
    fn wait_for_healthy_workers(
        worker_urls: &[String],
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<(), String> {
        let start_time = std::time::Instant::now();
        let sync_client = reqwest::blocking::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                );
                return Err(format!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                ));
            }

            let mut all_healthy = true;
            let mut unhealthy_workers = Vec::new();

            for url in worker_urls {
                match sync_client.get(&format!("{}/health", url)).send() {
                    Ok(res) => {
                        if !res.status().is_success() {
                            let msg = format!(
                                "Worker health check is pending with status {}",
                                res.status()
                            );
                            info!("{}", msg);
                            all_healthy = false;
                            unhealthy_workers.push((url, msg));
                        }
                    }
                    Err(_) => {
                        let msg = "Worker is not ready yet".to_string();
                        info!("{}", msg);
                        all_healthy = false;
                        unhealthy_workers.push((url, msg));
                    }
                }
            }

            if all_healthy {
                info!("All workers are healthy");
                return Ok(());
            } else {
                info!("Initializing workers:");
                for (url, reason) in &unhealthy_workers {
                    info!("  {} - {}", url, reason);
                }
                thread::sleep(Duration::from_secs(interval_secs));
            }
        }
    }

    fn select_first_worker(&self) -> Result<String, String> {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                if worker_urls.read().unwrap().is_empty() {
                    Err("No workers are available".to_string())
                } else {
                    Ok(worker_urls.read().unwrap()[0].clone())
                }
            }
        }
    }

    async fn send_request(
        &self,
        client: &reqwest::Client,
        worker_url: &str,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        let mut request_builder = client.get(format!("{}{}", worker_url, route));

        // Copy all headers from the original request, except for /health requests, which do not need authorization headers
        if route != "/health" {
            for (name, value) in copy_request_headers(req) {
                request_builder = request_builder.header(name, value);
            }
        }

        match request_builder.send().await {
            Ok(res) => {
                let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

                match res.bytes().await {
                    Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                    Err(e) => HttpResponse::InternalServerError()
                        .body(format!("Failed to read response body: {}", e)),
                }
            }
            Err(e) => HttpResponse::InternalServerError().body(format!(
                "Failed to send request to worker {}: {}",
                worker_url, e
            )),
        }
    }

    pub async fn route_to_first(
        &self,
        client: &reqwest::Client,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
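        // Retry policy: try the same worker up to MAX_REQUEST_RETRIES times;
        // once that limit is reached the worker is removed from the pool and a
        // new one is selected, with MAX_TOTAL_RETRIES bounding attempts overall.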
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            match self.select_first_worker() {
                Ok(worker_url) => {
                    let mut request_retries = 0;

                    // Try the same worker multiple times
                    while request_retries < MAX_REQUEST_RETRIES {
                        if total_retries >= 1 {
                            info!("Retrying request after {} failed attempts", total_retries);
                        }

                        let response = self.send_request(client, &worker_url, route, req).await;

                        if response.status().is_success() {
                            return response;
                        } else {
                            // if the worker is healthy, it means the request is bad, so return the error response
                            let health_response =
                                self.send_request(client, &worker_url, "/health", req).await;
                            if health_response.status().is_success() {
                                return response;
                            }
                        }

                        warn!(
                            "Request to {} failed (attempt {}/{})",
                            worker_url,
                            request_retries + 1,
                            MAX_REQUEST_RETRIES
                        );

                        request_retries += 1;
                        total_retries += 1;

                        if request_retries == MAX_REQUEST_RETRIES {
                            warn!("Removing failed worker: {}", worker_url);
                            self.remove_worker(&worker_url);
                            break;
                        }
                    }
                }
                Err(e) => return HttpResponse::InternalServerError().body(e),
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

    fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
        // Convert body to JSON
        let json: Value = match serde_json::from_slice(body) {
            Ok(j) => j,
            Err(_) => {
                warn!("Failed to parse JSON from request body.");
                return String::new();
            }
        };

        match route {
            "/generate" => {
                // For /generate, always use the "text" field.
                match json.get("text").and_then(Value::as_str) {
                    Some(text) => text.to_string(),
                    None => {
                        warn!("No 'text' field found in request body for route /generate.");
                        String::new()
                    }
                }
            }
            "/v1/chat/completions" | "/v1/completions" => {
                // For these routes, try "messages" first, then fall back to "prompt".
                if let Some(messages) = json.get("messages") {
                    serde_json::to_string(messages).unwrap_or_default()
                } else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) {
                    prompt.to_string()
                } else {
                    warn!("Failed to find 'messages', 'prompt' in request body.");
                    String::new()
                }
            }
            _ => {
                warn!("Unknown route: {} - defaulting to fallback string", route);
                String::new()
            }
        }
    }

    // TODO: return Result<String, String> instead of panicking
    fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
        let text = self.get_text_from_request(&body, route);

        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
                ..
            } => {
                let idx = current_index
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.read().unwrap().len()),
                    )
                    .unwrap();
                worker_urls.read().unwrap()[idx].clone()
            }

            Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
                [rand::random::<usize>() % worker_urls.read().unwrap().len()]
            .clone(),

            Router::CacheAware {
                worker_urls,
                tree,
                running_queue,
                processed_queue,
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                ..
            } => {
                // TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones

                let tree = tree.lock().unwrap();
                let mut running_queue = running_queue.lock().unwrap();

                // Get current load statistics
                let max_load = *running_queue.values().max().unwrap_or(&0);
                let min_load = *running_queue.values().min().unwrap_or(&0);

                // Load is considered imbalanced if:
                // 1. (max - min) > abs_threshold AND
                // 2. max > rel_threshold * min
                let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
                    && (max_load as f32) > (min_load as f32 * balance_rel_threshold);

                let selected_url = if is_imbalanced {
                    // Log load balancing trigger and current queue state
                    info!(
                        "Load balancing triggered due to workload imbalance:\n\
                        Max load: {}, Min load: {}\n\
                        Current running queue: {:?}",
                        max_load, min_load, running_queue
                    );

                    // Use shortest queue routing when load is imbalanced
                    running_queue
                        .iter()
                        .min_by_key(|(_url, &count)| count)
                        .map(|(url, _)| url.clone())
                        .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
                } else {
                    // Use cache-aware routing when load is balanced
                    let (matched_text, matched_worker) = tree.prefix_match(&text);
                    let matched_rate =
                        matched_text.chars().count() as f32 / text.chars().count() as f32;

                    if matched_rate > *cache_threshold {
                        matched_worker.to_string()
                    } else {
                        tree.get_smallest_tenant()
                    }
                };

                // Update queues and tree
                *running_queue.get_mut(&selected_url).unwrap() += 1;

                *processed_queue
                    .lock()
                    .unwrap()
                    .get_mut(&selected_url)
                    .unwrap() += 1;
                tree.insert(&text, &selected_url);

                selected_url
            }
        };

        worker_url
    }

    async fn send_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
        worker_url: &str,
    ) -> HttpResponse {
        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            .unwrap_or(false);

        let mut request_builder = client
            .post(format!("{}{}", worker_url, route))
            .body(body.to_vec());

        // Copy all headers from original request
        for (name, value) in copy_request_headers(req) {
            request_builder = request_builder.header(name, value);
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(_) => return HttpResponse::InternalServerError().finish(),
        };

        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(e) => {
                    let error_msg = format!("Failed to get response body: {}", e);
                    HttpResponse::InternalServerError().body(error_msg)
                }
            };

            // Then decrement running queue counter if using CacheAware
            if let Router::CacheAware { running_queue, .. } = self {
                if let Ok(mut queue) = running_queue.lock() {
                    if let Some(count) = queue.get_mut(worker_url) {
                        *count = count.saturating_sub(1);
                    }
                }
            }

            response
        } else if let Router::CacheAware { running_queue, .. } = self {
            let running_queue = Arc::clone(running_queue);
            let worker_url = worker_url.to_string();

            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(
                    res.bytes_stream()
                        .map_err(|_| {
                            actix_web::error::ErrorInternalServerError("Failed to read stream")
                        })
                        .inspect(move |bytes| {
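                            // Watch the SSE stream for the terminating
                            // "data: [DONE]" chunk and, once it appears, decrement
                            // this worker's running-request counter.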
                            let bytes = bytes.as_ref().unwrap();
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                let mut locked_queue = running_queue.lock().unwrap();
                                let count = locked_queue.get_mut(&worker_url).unwrap();
                                *count = count.saturating_sub(1);
                                debug!("Streaming is done!!")
                            }
                        }),
                )
        } else {
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
                    actix_web::error::ErrorInternalServerError("Failed to read stream")
                }))
        }
    }

    pub async fn route_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
    ) -> HttpResponse {
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            let worker_url = self.select_generate_worker(body, route);
            let mut request_retries = 0;

            // Try the same worker multiple times
            while request_retries < MAX_REQUEST_RETRIES {
                if total_retries >= 1 {
                    info!("Retrying request after {} failed attempts", total_retries);
                }
                let response = self
                    .send_generate_request(client, req, body, route, &worker_url)
                    .await;

                if response.status().is_success() {
                    return response;
                } else {
                    // if the worker is healthy, it means the request is bad, so return the error response
                    let health_response =
                        self.send_request(client, &worker_url, "/health", req).await;
                    if health_response.status().is_success() {
                        return response;
                    }
                }

                warn!(
                    "Generate request to {} failed (attempt {}/{})",
                    worker_url,
                    request_retries + 1,
                    MAX_REQUEST_RETRIES
                );

                request_retries += 1;
                total_retries += 1;

                if request_retries == MAX_REQUEST_RETRIES {
                    warn!("Removing failed worker: {}", worker_url);
                    self.remove_worker(&worker_url);
                    break;
                }
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

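    // Wait for the new worker to pass its /health check (within this router's
    // timeout), then register it; for the cache-aware policy it is also added
    // to the running/processed queues and the radix tree.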
    pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        let (timeout_secs, interval_secs) = match self {
            Router::Random {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::RoundRobin {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::CacheAware {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        let start_time = std::time::Instant::now();
        let client = reqwest::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_url
                );
                return Err(format!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_url
                ));
            }

            match client.get(&format!("{}/health", worker_url)).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        match self {
                            Router::RoundRobin { worker_urls, .. }
                            | Router::Random { worker_urls, .. }
                            | Router::CacheAware { worker_urls, .. } => {
                                info!("Worker {} health check passed", worker_url);
                                let mut urls = worker_urls.write().unwrap();
                                if urls.contains(&worker_url.to_string()) {
                                    return Err(format!("Worker {} already exists", worker_url));
                                }
                                info!("Added worker: {}", worker_url);
                                urls.push(worker_url.to_string());
                            }
                        }

                        // If cache aware, initialize the queues for the new worker
                        if let Router::CacheAware {
                            running_queue,
                            processed_queue,
                            tree,
                            ..
                        } = self
                        {
                            // Add worker to running queue with initial count of 0
                            running_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to processed queue with initial count of 0
                            processed_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to tree
                            tree.lock().unwrap().insert(&"".to_string(), &worker_url);
                        }

                        return Ok(format!("Successfully added worker: {}", worker_url));
                    } else {
                        info!(
                            "Worker {} health check is pending with status: {}.",
                            worker_url,
                            res.status()
                        );
                        // if the url does not have http or https prefix, warn users
                        if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
                        {
                            warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                        }

                        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                        continue;
                    }
                }
                Err(e) => {
                    info!(
                        "Worker {} health check is pending with error: {}",
                        worker_url, e
                    );

                    // if the url does not have http or https prefix, warn users
                    if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
                        warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                    }

                    tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                    continue;
                }
            }
        }
    }

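    // Drop the worker from the URL list; for the cache-aware policy, also remove
    // it from the radix tree and the running/processed queues.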
    pub fn remove_worker(&self, worker_url: &str) {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                let mut urls = worker_urls.write().unwrap();
                if let Some(index) = urls.iter().position(|url| url == &worker_url) {
                    urls.remove(index);
                    info!("Removed worker: {}", worker_url);
                } else {
                    warn!("Worker {} not found, skipping removal", worker_url);
                    return;
                }
            }
        }

        // if cache aware, remove the worker from the tree
        if let Router::CacheAware {
            tree,
            running_queue,
            processed_queue,
            ..
        } = self
        {
            tree.lock().unwrap().remove_tenant(&worker_url);
            running_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
            processed_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
            info!(
                "Removed worker from tree and cleaned up queues: {}",
                worker_url
            );
        }
    }
}