use crate::tree::Tree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt};
use log::{debug, error, info, warn};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use tokio;

fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
    req.headers()
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|v| (name.to_string(), v.to_string()))
        })
        .collect()
}

#[derive(Debug)]
pub enum Router {
    RoundRobin {
        worker_urls: Arc<RwLock<Vec<String>>>,
        current_index: AtomicUsize,
        timeout_secs: u64,
        interval_secs: u64,
    },
    Random {
        worker_urls: Arc<RwLock<Vec<String>>>,
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAware {
        /*
            Cache-Aware Load Balancing Router

            This router combines two strategies to optimize both cache utilization and request distribution:

            1. Cache-Aware Routing (Approximate Tree)
            2. Load Balancing (Shortest Queue with Balance Thresholds)

            The router dynamically switches between these strategies based on load conditions:
            - Uses load balancing when the system is imbalanced
            - Uses cache-aware routing when the system is balanced

            A system is considered imbalanced if both conditions are met:
            1. (max - min) > abs_threshold
            2. max > rel_threshold * min

            Strategy Details:

            1. Cache-Aware Routing (Approximate Tree)
            -------------------------------------------
            This strategy maintains an approximate radix tree for each worker based on request history,
            eliminating the need for direct cache state queries. The tree stores raw text characters
            instead of token IDs to avoid tokenization overhead.

            Process:
            a. For each request, find the worker with the highest prefix match
            b. If match rate > cache_threshold:
            Route to the worker with highest match (likely has relevant data cached)
            c. If match rate ≤ cache_threshold:
            Route to the worker with smallest tree size (most available cache capacity)
            d. Background maintenance:
            Periodically evict least recently used leaf nodes to prevent memory overflow

            2. Load Balancing (Shortest Queue)
            -------------------------------------------
            This strategy tracks pending request counts per worker and routes new requests
            to the least busy worker when the system is detected to be imbalanced.

            Configuration Parameters:
            ------------------------
            1. cache_threshold: (float, 0.0 to 1.0)
            Minimum prefix match ratio to use highest-match routing.
            Below this threshold, routes to worker with most available cache space.

            2. balance_abs_threshold: (integer)
            Absolute difference threshold for load imbalance detection.
            System is potentially imbalanced if (max_load - min_load) > abs_threshold

            3. balance_rel_threshold: (float)
            Relative ratio threshold for load imbalance detection.
            System is potentially imbalanced if max_load > min_load * rel_threshold
            Used in conjunction with abs_threshold to determine final imbalance state.

            4. eviction_interval_secs: (integer)
            Interval between LRU eviction cycles for the approximate trees.

            5. max_tree_size: (integer)
            Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
            during the next eviction cycle.
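
            Worked example (illustrative numbers, not actual defaults):
            With balance_abs_threshold = 32 and balance_rel_threshold = 1.5, running
            queues of {worker A: 50, worker B: 5} are imbalanced, since (50 - 5) = 45 > 32
            and 50 > 1.5 * 5 = 7.5, so the request goes to the shortest queue (worker B).
            Queues of {worker A: 40, worker B: 30} remain balanced, since 40 - 30 = 10 <= 32,
            so cache-aware routing applies: with cache_threshold = 0.5, a request whose
            longest prefix match covers 60% of its text goes to the matching worker, while
            a 30% match goes to the worker with the smallest tree.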
        */
        worker_urls: Arc<RwLock<Vec<String>>>,
        tree: Arc<Mutex<Tree>>,
        running_queue: Arc<Mutex<HashMap<String, usize>>>,
        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        timeout_secs: u64,
        interval_secs: u64,
        _eviction_thread: Option<thread::JoinHandle<()>>,
    },
}

#[derive(Debug, Clone)]
pub enum PolicyConfig {
    RandomConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    RoundRobinConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAwareConfig {
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        eviction_interval_secs: u64,
        max_tree_size: usize,
        timeout_secs: u64,
        interval_secs: u64,
    },
}
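
// Illustrative construction sketch (hypothetical values, not shipped defaults):
// `Router::new` blocks until every listed worker responds successfully on /health,
// so the URLs must point at live workers.
//
//     let router = Router::new(
//         vec!["http://127.0.0.1:30000".to_string()],
//         PolicyConfig::CacheAwareConfig {
//             cache_threshold: 0.5,
//             balance_abs_threshold: 32,
//             balance_rel_threshold: 1.5,
//             eviction_interval_secs: 60,
//             max_tree_size: 1 << 24,
//             timeout_secs: 300,
//             interval_secs: 10,
//         },
//     )?;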

impl Router {
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
        // Get timeout and interval from policy config
        let (timeout_secs, interval_secs) = match &policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::CacheAwareConfig {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        // Wait until all workers are healthy
        Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;

        // Create router based on policy...
        Ok(match policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => Router::Random {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => Router::RoundRobin {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                current_index: std::sync::atomic::AtomicUsize::new(0),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::CacheAwareConfig {
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                eviction_interval_secs,
                max_tree_size,
                timeout_secs,
                interval_secs,
            } => {
                let mut running_queue = HashMap::new();
                for url in &worker_urls {
                    running_queue.insert(url.clone(), 0);
                }

                let mut processed_queue = HashMap::new();
                for url in &worker_urls {
                    processed_queue.insert(url.clone(), 0);
                }

                let tree = Arc::new(Mutex::new(Tree::new()));
                let running_queue = Arc::new(Mutex::new(running_queue));
                let processed_queue = Arc::new(Mutex::new(processed_queue));

                // Create background eviction thread
                let tree_clone = Arc::clone(&tree);
                let processed_queue_clone = Arc::clone(&processed_queue);
                let running_queue_clone = Arc::clone(&running_queue);
                let eviction_thread = thread::spawn(move || {
                    loop {
                        // Sleep for the specified interval
                        thread::sleep(Duration::from_secs(eviction_interval_secs));

                        let locked_tree_clone = tree_clone.lock().unwrap();
                        // Run eviction
                        locked_tree_clone.evict_tenant_by_size(max_tree_size);

                        // Print the process queue
                        let locked_processed_queue = processed_queue_clone.lock().unwrap();
                        info!("Processed Queue: {:?}", locked_processed_queue);

                        // Print the running queue
                        let locked_running_queue = running_queue_clone.lock().unwrap();
                        info!("Running Queue: {:?}", locked_running_queue);
                    }
                });

                for url in &worker_urls {
                    tree.lock().unwrap().insert(&"".to_string(), url);
                }

                Router::CacheAware {
                    worker_urls: Arc::new(RwLock::new(worker_urls)),
                    tree,
                    running_queue,
                    processed_queue,
                    cache_threshold,
                    balance_abs_threshold,
                    balance_rel_threshold,
                    timeout_secs,
                    interval_secs,
                    _eviction_thread: Some(eviction_thread),
                }
            }
        })
    }

    fn wait_for_healthy_workers(
        worker_urls: &[String],
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<(), String> {
        let start_time = std::time::Instant::now();
        let sync_client = reqwest::blocking::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                );
                return Err(format!(
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
                ));
            }

            let mut all_healthy = true;
            let mut unhealthy_workers = Vec::new();

            for url in worker_urls {
                match sync_client.get(&format!("{}/health", url)).send() {
                    Ok(res) => {
                        if !res.status().is_success() {
                            info!(
                                "Worker {} health check is pending with status: {}.",
                                url,
                                res.status()
                            );
                            all_healthy = false;
                            unhealthy_workers.push((url, format!("Status: {}", res.status())));
                        }
                    }
                    Err(e) => {
                        info!("Worker {} health check is pending with error: {}", url, e);
                        all_healthy = false;
                        unhealthy_workers.push((url, format!("Error: {}", e)));
                    }
                }
            }

            if all_healthy {
                info!("All workers are healthy");
                return Ok(());
            } else {
                info!("Unhealthy workers:");
                for (url, reason) in &unhealthy_workers {
                    info!("  {} - {}", url, reason);
                }
                thread::sleep(Duration::from_secs(interval_secs));
            }
        }
    }

    fn select_first_worker(&self) -> Result<String, String> {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                if worker_urls.read().unwrap().is_empty() {
                    Err("No workers are available".to_string())
                } else {
                    Ok(worker_urls.read().unwrap()[0].clone())
                }
            }
        }
    }

    async fn send_request(
        &self,
        client: &reqwest::Client,
        worker_url: &str,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        let mut request_builder = client.get(format!("{}{}", worker_url, route));

        // Copy all headers from the original request, except for /health, which does not need authorization
        if route != "/health" {
            for (name, value) in copy_request_headers(req) {
                request_builder = request_builder.header(name, value);
            }
        }

        match request_builder.send().await {
            Ok(res) => {
                let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

                match res.bytes().await {
                    Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                    Err(e) => HttpResponse::InternalServerError()
                        .body(format!("Failed to read response body: {}", e)),
                }
            }
            Err(e) => HttpResponse::InternalServerError().body(format!(
                "Failed to send request to worker {}: {}",
                worker_url, e
            )),
        }
    }

    pub async fn route_to_first(
        &self,
        client: &reqwest::Client,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            match self.select_first_worker() {
                Ok(worker_url) => {
                    let mut request_retries = 0;

                    // Try the same worker multiple times
                    while request_retries < MAX_REQUEST_RETRIES {
                        if total_retries >= 1 {
                            info!("Retrying request after {} failed attempts", total_retries);
                        }

                        let response = self.send_request(client, &worker_url, route, req).await;

                        if response.status().is_success() {
                            return response;
                        } else {
                            // if the worker is healthy, it means the request is bad, so return the error response
                            let health_response =
                                self.send_request(client, &worker_url, "/health", req).await;
                            if health_response.status().is_success() {
                                return response;
                            }
                        }

                        warn!(
                            "Request to {} failed (attempt {}/{})",
                            worker_url,
                            request_retries + 1,
                            MAX_REQUEST_RETRIES
                        );

                        request_retries += 1;
                        total_retries += 1;

                        if request_retries == MAX_REQUEST_RETRIES {
                            warn!("Removing failed worker: {}", worker_url);
                            self.remove_worker(&worker_url);
                            break;
                        }
                    }
                }
                Err(e) => return HttpResponse::InternalServerError().body(e),
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

    fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
        // convert body to json
        let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();

        if route == "generate" {
            // get the "text" field
            let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
            return text.to_string();
        } else if route == "v1/chat/completions" {
            // get the messages field as raw text
            if let Some(messages) = json.get("messages") {
                // Convert messages back to a string, preserving all JSON formatting
                return serde_json::to_string(messages).unwrap_or_default();
            }
        } else if route == "v1/completions" {
            let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
            return prompt.to_string();
        }

        return "".to_string();
    }

    // TODO: return Result<String, String> instead of panicking
    fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
        let text = self.get_text_from_request(&body, route);

        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
                ..
            } => {
                let idx = current_index
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.read().unwrap().len()),
                    )
                    .unwrap();
                worker_urls.read().unwrap()[idx].clone()
            }

            Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
                [rand::random::<usize>() % worker_urls.read().unwrap().len()]
            .clone(),

            Router::CacheAware {
                worker_urls,
                tree,
                running_queue,
                processed_queue,
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                ..
            } => {
                // TODO: delay scheduling when the cache hit rate is high, since it may cause imbalance; prioritize workers with low hit rates
                let tree = tree.lock().unwrap();
                let mut running_queue = running_queue.lock().unwrap();

                // Get current load statistics
                let max_load = *running_queue.values().max().unwrap_or(&0);
                let min_load = *running_queue.values().min().unwrap_or(&0);

                // Load is considered imbalanced if:
                // 1. (max - min) > abs_threshold AND
                // 2. max > rel_threshold * min
                let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
                    && (max_load as f32) > (min_load as f32 * balance_rel_threshold);

                let selected_url = if is_imbalanced {
                    // Log load balancing trigger and current queue state
                    info!(
                        "Load balancing triggered due to workload imbalance:\n\
                        Max load: {}, Min load: {}\n\
                        Current running queue: {:?}",
                        max_load, min_load, running_queue
                    );

                    // Use shortest queue routing when load is imbalanced
                    running_queue
                        .iter()
                        .min_by_key(|(_url, &count)| count)
                        .map(|(url, _)| url.clone())
                        .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
                } else {
                    // Use cache-aware routing when load is balanced
                    let (matched_text, matched_worker) = tree.prefix_match(&text);
                    let matched_rate =
                        matched_text.chars().count() as f32 / text.chars().count() as f32;

                    if matched_rate > *cache_threshold {
                        matched_worker.to_string()
                    } else {
                        tree.get_smallest_tenant()
                    }
                };

                // Update queues and tree
                *running_queue.get_mut(&selected_url).unwrap() += 1;

                *processed_queue
                    .lock()
                    .unwrap()
                    .get_mut(&selected_url)
                    .unwrap() += 1;
                tree.insert(&text, &selected_url);

                selected_url
            }
        };

        worker_url
    }

    async fn send_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
        worker_url: &str,
    ) -> HttpResponse {
        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            .unwrap_or(false);

        let mut request_builder = client
            .post(format!("{}{}", worker_url, route))
            .body(body.to_vec());

        // Copy all headers from original request
        for (name, value) in copy_request_headers(req) {
            request_builder = request_builder.header(name, value);
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(_) => return HttpResponse::InternalServerError().finish(),
        };

        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(e) => {
                    let error_msg = format!("Failed to get response body: {}", e);
                    HttpResponse::InternalServerError().body(error_msg)
                }
            };

            // Then decrement running queue counter if using CacheAware
            if let Router::CacheAware { running_queue, .. } = self {
                if let Ok(mut queue) = running_queue.lock() {
                    if let Some(count) = queue.get_mut(worker_url) {
                        *count = count.saturating_sub(1);
                    }
                }
            }

            response
        } else if let Router::CacheAware { running_queue, .. } = self {
            let running_queue = Arc::clone(running_queue);
            let worker_url = worker_url.to_string();

            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(
                    res.bytes_stream()
                        .map_err(|_| {
                            actix_web::error::ErrorInternalServerError("Failed to read stream")
                        })
                        .inspect(move |bytes| {
                            let bytes = bytes.as_ref().unwrap();
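                            // b"data: [DONE]" is 12 bytes; scan each chunk for this SSE
                            // terminator so the running-queue counter for this worker can be
                            // decremented once the stream completes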
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                let mut locked_queue = running_queue.lock().unwrap();
                                let count = locked_queue.get_mut(&worker_url).unwrap();
                                *count = count.saturating_sub(1);
                                debug!("Streaming is done!!")
                            }
                        }),
                )
        } else {
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
                    actix_web::error::ErrorInternalServerError("Failed to read stream")
                }))
        }
    }

    pub async fn route_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
    ) -> HttpResponse {
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            let worker_url = self.select_generate_worker(body, route);
            let mut request_retries = 0;

            // Try the same worker multiple times
            while request_retries < MAX_REQUEST_RETRIES {
                if total_retries >= 1 {
                    info!("Retrying request after {} failed attempts", total_retries);
                }
                let response = self
                    .send_generate_request(client, req, body, route, &worker_url)
                    .await;

                if response.status().is_success() {
                    return response;
                } else {
                    // if the worker is healthy, it means the request is bad, so return the error response
                    let health_response =
                        self.send_request(client, &worker_url, "/health", req).await;
                    if health_response.status().is_success() {
                        return response;
                    }
                }

                warn!(
                    "Generate request to {} failed (attempt {}/{})",
                    worker_url,
                    request_retries + 1,
                    MAX_REQUEST_RETRIES
                );

                request_retries += 1;
                total_retries += 1;

                if request_retries == MAX_REQUEST_RETRIES {
                    warn!("Removing failed worker: {}", worker_url);
                    self.remove_worker(&worker_url);
                    break;
                }
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

    pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
        let (timeout_secs, interval_secs) = match self {
            Router::Random {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::RoundRobin {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::CacheAware {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        let start_time = std::time::Instant::now();
        let client = reqwest::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
                error!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_url
                );
                return Err(format!(
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_url
                ));
            }

            match client.get(&format!("{}/health", worker_url)).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        match self {
                            Router::RoundRobin { worker_urls, .. }
                            | Router::Random { worker_urls, .. }
                            | Router::CacheAware { worker_urls, .. } => {
                                info!("Worker {} health check passed", worker_url);
                                let mut urls = worker_urls.write().unwrap();
                                if urls.contains(&worker_url.to_string()) {
                                    return Err(format!("Worker {} already exists", worker_url));
                                }
                                info!("Added worker: {}", worker_url);
                                urls.push(worker_url.to_string());
                            }
                        }

                        // If cache aware, initialize the queues for the new worker
                        if let Router::CacheAware {
                            running_queue,
                            processed_queue,
                            tree,
                            ..
                        } = self
                        {
                            // Add worker to running queue with initial count of 0
                            running_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to processed queue with initial count of 0
                            processed_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);

                            // Add worker to tree
                            tree.lock().unwrap().insert(&"".to_string(), &worker_url);
                        }

                        return Ok(format!("Successfully added worker: {}", worker_url));
                    } else {
                        info!(
                            "Worker {} health check is pending with status: {}.",
                            worker_url,
                            res.status()
                        );
                        // if the url does not have http or https prefix, warn users
                        if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
                        {
                            warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                        }

                        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                        continue;
                    }
                }
                Err(e) => {
                    info!(
                        "Worker {} health check is pending with error: {}",
                        worker_url, e
                    );

                    // if the url does not have http or https prefix, warn users
                    if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
                        warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                    }

                    tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                    continue;
                }
            }
        }
    }

    pub fn remove_worker(&self, worker_url: &str) {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => {
                let mut urls = worker_urls.write().unwrap();
                if let Some(index) = urls.iter().position(|url| url == &worker_url) {
                    urls.remove(index);
                    info!("Removed worker: {}", worker_url);
                } else {
                    warn!("Worker {} not found, skipping removal", worker_url);
                    return;
                }
            }
        }

        // if cache aware, remove the worker from the tree
        if let Router::CacheAware {
            tree,
            running_queue,
            processed_queue,
            ..
        } = self
        {
            tree.lock().unwrap().remove_tenant(&worker_url);
            running_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
            processed_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
            info!(
                "Removed worker from tree and cleaned up queues: {}",
                worker_url
            );
        }
    }
}