router.rs 34.6 KB
Newer Older
1
use crate::tree::Tree;
2
use ::metrics::{counter, gauge, histogram};
3
4
5
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
Byron Hsu's avatar
Byron Hsu committed
6
use futures_util::{StreamExt, TryStreamExt};
7
use serde_json::Value;
8
use std::collections::HashMap;
9
use std::fmt::Debug;
10
use std::sync::atomic::AtomicUsize;
11
use std::sync::{Arc, Mutex, RwLock};
12
13
use std::thread;
use std::time::Duration;
14
use std::time::Instant;
15
use tokio;
16
use tracing::{debug, error, info, warn};
17

18
19
20
21
22
23
24
25
26
27
28
29
/// Collect the headers of `req` as owned `(name, value)` string pairs.
///
/// Header values that are not valid UTF-8 cannot be represented as `&str`
/// and are silently skipped, matching the original filter-map behavior.
fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
    let mut pairs = Vec::new();
    for (name, value) in req.headers().iter() {
        if let Ok(text) = value.to_str() {
            pairs.push((name.to_string(), text.to_string()));
        }
    }
    pairs
}

30
/// Load-balancing policy for distributing requests across a set of workers.
///
/// Every variant shares the worker URL list behind `Arc<RwLock<..>>` so
/// membership changes are visible to all clones, and carries the
/// health-check settings (`timeout_secs`, `interval_secs`) used when
/// waiting for workers to come up.
#[derive(Debug)]
pub enum Router {
    /// Rotates through the worker list, one request per worker in order.
    RoundRobin {
        worker_urls: Arc<RwLock<Vec<String>>>,
        // Index of the last worker used; advanced atomically per request.
        current_index: AtomicUsize,
        // Max seconds to wait for a worker's /health endpoint to succeed.
        timeout_secs: u64,
        // Seconds between /health polls while waiting.
        interval_secs: u64,
    },
    /// Picks a worker uniformly at random for each request.
    Random {
        worker_urls: Arc<RwLock<Vec<String>>>,
        timeout_secs: u64,
        interval_secs: u64,
    },
    /// Routes by approximate prefix-cache affinity, falling back to
    /// shortest-queue balancing when load is imbalanced. See the design
    /// notes below.
    CacheAware {
        /*

            Cache-Aware Load Balancing Router

            This router combines two strategies to optimize both cache utilization and request distribution:

            1. Cache-Aware Routing (Approximate Tree)
            2. Load Balancing (Shortest Queue with Balance Thresholds)

            The router dynamically switches between these strategies based on load conditions:
            - Uses load balancing when the system is imbalanced
            - Uses cache-aware routing when the system is balanced

            A system is considered imbalanced if both conditions are met:
            1. (max - min) > abs_threshold
            2. max > rel_threshold * min

            Strategy Details:

            1. Cache-Aware Routing (Approximate Tree)
            -------------------------------------------
            This strategy maintains an approximate radix tree for each worker based on request history,
            eliminating the need for direct cache state queries. The tree stores raw text characters
            instead of token IDs to avoid tokenization overhead.

            Process:
            a. For each request, find the worker with the highest prefix match
            b. If match rate > cache_threshold:
            Route to the worker with highest match (likely has relevant data cached)
            c. If match rate ≤ cache_threshold:
            Route to the worker with smallest tree size (most available cache capacity)
            d. Background maintenance:
            Periodically evict least recently used leaf nodes to prevent memory overflow

            2. Load Balancing (Shortest Queue)
            -------------------------------------------
            This strategy tracks pending request counts per worker and routes new requests
            to the least busy worker when the system is detected to be imbalanced.

            Configuration Parameters:
            ------------------------
            1. cache_threshold: (float, 0.0 to 1.0)
            Minimum prefix match ratio to use highest-match routing.
            Below this threshold, routes to worker with most available cache space.

            2. balance_abs_threshold: (integer)
            Absolute difference threshold for load imbalance detection.
            System is potentially imbalanced if (max_load - min_load) > abs_threshold

            3. balance_rel_threshold: (float)
            Relative ratio threshold for load imbalance detection.
            System is potentially imbalanced if max_load > min_load * rel_threshold
            Used in conjunction with abs_threshold to determine final imbalance state.

            4. eviction_interval_secs: (integer)
            Interval between LRU eviction cycles for the approximate trees.

            5. max_tree_size: (integer)
            Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
            during the next eviction cycle.
        */
        worker_urls: Arc<RwLock<Vec<String>>>,
        // Approximate radix tree of request prefixes, shared with the
        // background eviction thread.
        tree: Arc<Mutex<Tree>>,
        // In-flight request count per worker URL.
        running_queue: Arc<Mutex<HashMap<String, usize>>>,
        // Total requests dispatched per worker URL (monotonic counter).
        processed_queue: Arc<Mutex<HashMap<String, usize>>>,
        // Minimum prefix-match ratio required to route by cache affinity.
        cache_threshold: f32,
        // Absolute load-gap threshold for imbalance detection.
        balance_abs_threshold: usize,
        // Relative load-ratio threshold for imbalance detection.
        balance_rel_threshold: f32,
        timeout_secs: u64,
        interval_secs: u64,
        // Handle kept alive so the eviction loop runs for the router's
        // lifetime; never joined.
        _eviction_thread: Option<thread::JoinHandle<()>>,
    },
}

117
/// Startup configuration selecting which `Router` variant to build and with
/// what parameters. Mirrors the `Router` variants one-to-one.
#[derive(Debug, Clone)]
pub enum PolicyConfig {
    /// Build a `Router::Random`.
    RandomConfig {
        // Max seconds to wait for workers to pass /health at startup.
        timeout_secs: u64,
        // Seconds between health-check polls while waiting.
        interval_secs: u64,
    },
    /// Build a `Router::RoundRobin`.
    RoundRobinConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    /// Build a `Router::CacheAware`; see the design notes on that variant
    /// for the meaning of each threshold.
    CacheAwareConfig {
        // Minimum prefix-match ratio (0.0..=1.0) for cache-affinity routing.
        cache_threshold: f32,
        // Absolute load-gap threshold for imbalance detection.
        balance_abs_threshold: usize,
        // Relative load-ratio threshold for imbalance detection.
        balance_rel_threshold: f32,
        // Seconds between background LRU eviction cycles.
        eviction_interval_secs: u64,
        // Max nodes per approximate tree before eviction kicks in.
        max_tree_size: usize,
        timeout_secs: u64,
        interval_secs: u64,
    },
}

138
impl Router {
139
    /// Build a `Router` for `worker_urls` according to `policy_config`.
    ///
    /// Blocks until every worker answers its `/health` endpoint (or the
    /// configured timeout expires, in which case an `Err` with a
    /// descriptive message is returned). For the cache-aware policy this
    /// also spawns a detached background thread that periodically evicts
    /// tree nodes and logs queue state.
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
        // Update active workers gauge
        gauge!("sgl_router_active_workers").set(worker_urls.len() as f64);

        // Get timeout and interval from policy config
        let (timeout_secs, interval_secs) = match &policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::CacheAwareConfig {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        // Wait until all workers are healthy
        Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;

        // Create router based on policy...
        Ok(match policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => Router::Random {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => Router::RoundRobin {
                worker_urls: Arc::new(RwLock::new(worker_urls)),
                current_index: std::sync::atomic::AtomicUsize::new(0),
                timeout_secs,
                interval_secs,
            },
            PolicyConfig::CacheAwareConfig {
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                eviction_interval_secs,
                max_tree_size,
                timeout_secs,
                interval_secs,
            } => {
                // Start every worker with zero in-flight requests.
                let mut running_queue = HashMap::new();
                for url in &worker_urls {
                    running_queue.insert(url.clone(), 0);
                }

                // And zero lifetime processed requests.
                let mut processed_queue = HashMap::new();
                for url in &worker_urls {
                    processed_queue.insert(url.clone(), 0);
                }

                let tree = Arc::new(Mutex::new(Tree::new()));
                let running_queue = Arc::new(Mutex::new(running_queue));
                let processed_queue = Arc::new(Mutex::new(processed_queue));

                // Create background eviction thread
                let tree_clone = Arc::clone(&tree);
                let processed_queue_clone = Arc::clone(&processed_queue);
                let running_queue_clone = Arc::clone(&running_queue);
                let eviction_thread = thread::spawn(move || {
                    loop {
                        // Sleep for the specified interval
                        thread::sleep(Duration::from_secs(eviction_interval_secs));

                        let locked_tree_clone = tree_clone.lock().unwrap();
                        // Run eviction
                        locked_tree_clone.evict_tenant_by_size(max_tree_size);

                        // Print the process queue
                        let locked_processed_queue = processed_queue_clone.lock().unwrap();
                        info!("Processed Queue: {:?}", locked_processed_queue);

                        // Print the running queue
                        let locked_running_queue = running_queue_clone.lock().unwrap();
                        info!("Running Queue: {:?}", locked_running_queue);
                    }
                });

                // Seed the tree so every worker is a known tenant before any
                // request arrives.
                for url in &worker_urls {
                    tree.lock().unwrap().insert(&"".to_string(), url);
                }

                Router::CacheAware {
                    worker_urls: Arc::new(RwLock::new(worker_urls)),
                    tree,
                    running_queue,
                    processed_queue,
                    cache_threshold,
                    balance_abs_threshold,
                    balance_rel_threshold,
                    timeout_secs,
                    interval_secs,
                    _eviction_thread: Some(eviction_thread),
                }
            }
        })
    }

248
249
250
251
252
253
254
255
256
    /// Return a clone of the shared handle to this router's worker URL list.
    ///
    /// The returned `Arc<RwLock<..>>` is the same one the router uses
    /// internally, so readers observe live membership changes.
    pub fn get_worker_urls(&self) -> Arc<RwLock<Vec<String>>> {
        // All three variants store the list under the same field name, so a
        // single or-pattern arm covers every policy.
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls, .. }
            | Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
        }
    }

257
258
259
260
261
262
263
264
265
266
    fn wait_for_healthy_workers(
        worker_urls: &[String],
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<(), String> {
        let start_time = std::time::Instant::now();
        let sync_client = reqwest::blocking::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
267
                error!(
268
269
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
270
                );
271
                return Err(format!(
272
273
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
274
275
276
277
278
279
280
281
282
283
                ));
            }

            let mut all_healthy = true;
            let mut unhealthy_workers = Vec::new();

            for url in worker_urls {
                match sync_client.get(&format!("{}/health", url)).send() {
                    Ok(res) => {
                        if !res.status().is_success() {
Byron Hsu's avatar
Byron Hsu committed
284
285
                            let msg = format!(
                                "Worker heatlh check is pending with status {}",
286
287
                                res.status()
                            );
Byron Hsu's avatar
Byron Hsu committed
288
                            info!("{}", msg);
289
                            all_healthy = false;
Byron Hsu's avatar
Byron Hsu committed
290
                            unhealthy_workers.push((url, msg));
291
292
                        }
                    }
Byron Hsu's avatar
Byron Hsu committed
293
294
295
                    Err(_) => {
                        let msg = format!("Worker is not ready yet");
                        info!("{}", msg);
296
                        all_healthy = false;
Byron Hsu's avatar
Byron Hsu committed
297
                        unhealthy_workers.push((url, msg));
298
299
300
301
302
303
304
305
                    }
                }
            }

            if all_healthy {
                info!("All workers are healthy");
                return Ok(());
            } else {
Byron Hsu's avatar
Byron Hsu committed
306
                info!("Initializing workers:");
307
308
309
310
311
312
313
314
                for (url, reason) in &unhealthy_workers {
                    info!("  {} - {}", url, reason);
                }
                thread::sleep(Duration::from_secs(interval_secs));
            }
        }
    }

315
316
317
    fn select_first_worker(&self) -> Result<String, String> {
        match self {
            Router::RoundRobin { worker_urls, .. }
318
            | Router::Random { worker_urls, .. }
319
320
321
322
323
324
325
326
327
328
329
            | Router::CacheAware { worker_urls, .. } => {
                if worker_urls.read().unwrap().is_empty() {
                    Err("No workers are available".to_string())
                } else {
                    Ok(worker_urls.read().unwrap()[0].clone())
                }
            }
        }
    }

    async fn send_request(
330
331
        &self,
        client: &reqwest::Client,
332
        worker_url: &str,
333
        route: &str,
334
        req: &HttpRequest,
335
    ) -> HttpResponse {
336
        let start = Instant::now();
337
338
339
340
341
342
343
344
345
        let mut request_builder = client.get(format!("{}{}", worker_url, route));

        // Copy all headers from original request except for /health because it does not need authorization
        if route != "/health" {
            for (name, value) in copy_request_headers(req) {
                request_builder = request_builder.header(name, value);
            }
        }

346
        let response = match request_builder.send().await {
347
348
349
350
351
352
353
354
355
356
357
358
359
360
            Ok(res) => {
                let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

                match res.bytes().await {
                    Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                    Err(e) => HttpResponse::InternalServerError()
                        .body(format!("Failed to read response body: {}", e)),
                }
            }
            Err(e) => HttpResponse::InternalServerError().body(format!(
                "Failed to send request to worker {}: {}",
                worker_url, e
            )),
361
362
363
364
365
366
367
368
369
370
371
372
373
        };

        // Record request metrics
        if route != "/health" {
            let duration = start.elapsed();
            counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
            histogram!("sgl_router_request_duration_seconds", "route" => route.to_string())
                .record(duration.as_secs_f64());

            if !response.status().is_success() {
                counter!("sgl_router_request_errors_total", "route" => route.to_string())
                    .increment(1);
            }
374
        }
375
        response
376
377
    }

378
379
380
381
382
383
    /// Send a GET for `route` to the first available worker, with retries.
    ///
    /// Per worker: up to `MAX_REQUEST_RETRIES` attempts; across all workers:
    /// up to `MAX_TOTAL_RETRIES`. A worker that exhausts its per-worker
    /// attempts is removed from the pool. If a request fails but the worker
    /// still passes `/health`, the failure is treated as a bad request and
    /// the error response is returned to the caller as-is.
    pub async fn route_to_first(
        &self,
        client: &reqwest::Client,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            match self.select_first_worker() {
                Ok(worker_url) => {
                    let mut request_retries = 0;

                    // Try the same worker multiple times
                    while request_retries < MAX_REQUEST_RETRIES {
                        if total_retries >= 1 {
                            info!("Retrying request after {} failed attempts", total_retries);
                        }

                        let response = self.send_request(client, &worker_url, route, req).await;

                        if response.status().is_success() {
                            return response;
                        } else {
                            // if the worker is healthy, it means the request is bad, so return the error response
                            let health_response =
                                self.send_request(client, &worker_url, "/health", req).await;
                            if health_response.status().is_success() {
                                return response;
                            }
                        }

                        warn!(
                            "Request to {} failed (attempt {}/{})",
                            worker_url,
                            request_retries + 1,
                            MAX_REQUEST_RETRIES
                        );

                        request_retries += 1;
                        total_retries += 1;

                        // Give up on this worker entirely after its retry
                        // budget is spent; the outer loop will pick the next
                        // first worker (the list shrinks on removal).
                        if request_retries == MAX_REQUEST_RETRIES {
                            warn!("Removing failed worker: {}", worker_url);
                            self.remove_worker(&worker_url);
                            break;
                        }
                    }
                }
                // No workers registered at all — nothing left to try.
                Err(e) => return HttpResponse::InternalServerError().body(e),
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

    fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
437
438
439
440
441
442
        // Convert body to JSON
        let json: Value = match serde_json::from_slice(body) {
            Ok(j) => j,
            Err(_) => {
                warn!("Failed to parse JSON from request body.");
                return String::new();
443
            }
444
        };
445

446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
        match route {
            "/generate" => {
                // For /generate, always use the "text" field.
                match json.get("text").and_then(Value::as_str) {
                    Some(text) => text.to_string(),
                    None => {
                        warn!("No 'text' field found in request body for route /generate.");
                        String::new()
                    }
                }
            }
            "/v1/chat/completions" | "/v1/completions" => {
                // For these routes, try "messages", then "prompt", then "text".
                if let Some(messages) = json.get("messages") {
                    serde_json::to_string(messages).unwrap_or_default()
                } else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) {
                    prompt.to_string()
                } else {
                    warn!("Failed to find 'messages', 'prompt' in request body.");
                    String::new()
                }
            }
            _ => {
                warn!("Unknown route: {} - defaulting to fallback string", route);
                String::new()
            }
        }
473
474
475
476
477
    }

    // TODO: return Result<String, String> instead of panicking
    /// Pick the worker that should serve this generate request, per policy.
    ///
    /// For `CacheAware` this also increments the chosen worker's running and
    /// processed counters and records the request text in the tree.
    ///
    /// NOTE(review): panics if the worker list is empty — RoundRobin/Random
    /// compute `% len` (division by zero) and CacheAware indexes `[0]`;
    /// callers must guarantee at least one worker.
    fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
        let text = self.get_text_from_request(&body, route);

        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
                ..
            } => {
                // Atomically advance the rotation index modulo the current
                // list length; fetch_update with Some(..) never fails.
                let idx = current_index
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.read().unwrap().len()),
                    )
                    .unwrap();
                worker_urls.read().unwrap()[idx].clone()
            }

            Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
                [rand::random::<usize>() % worker_urls.read().unwrap().len()]
            .clone(),

            Router::CacheAware {
                worker_urls,
                tree,
                running_queue,
                processed_queue,
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                ..
            } => {
                // TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones

                // Lock order: tree, then running_queue (processed_queue is
                // locked briefly later).
                let tree = tree.lock().unwrap();
                let mut running_queue = running_queue.lock().unwrap();

                // Get current load statistics
                let max_load = *running_queue.values().max().unwrap_or(&0);
                let min_load = *running_queue.values().min().unwrap_or(&0);

                // Load is considered imbalanced if:
                // 1. (max - min) > abs_threshold AND
                // 2. max > rel_threshold * min
                let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
                    && (max_load as f32) > (min_load as f32 * balance_rel_threshold);

                let selected_url = if is_imbalanced {
                    // Log load balancing trigger and current queue state
                    info!(
                        "Load balancing triggered due to workload imbalance:\n\
                        Max load: {}, Min load: {}\n\
                        Current running queue: {:?}",
                        max_load, min_load, running_queue
                    );

                    counter!("sgl_router_load_balancing_events_total").increment(1);
                    gauge!("sgl_router_max_load").set(max_load as f64);
                    gauge!("sgl_router_min_load").set(min_load as f64);

                    // Use shortest queue routing when load is imbalanced
                    running_queue
                        .iter()
                        .min_by_key(|(_url, &count)| count)
                        .map(|(url, _)| url.clone())
                        .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
                } else {
                    // Use cache-aware routing when load is balanced.
                    // prefix_match presumably returns the longest cached
                    // prefix and its owning worker — confirm in tree.rs.
                    let (matched_text, matched_worker) = tree.prefix_match(&text);
                    // NOTE(review): matched_rate is NaN for empty text
                    // (0.0/0.0); NaN > threshold is false, so empty requests
                    // fall through to the smallest tenant. Likely intended.
                    let matched_rate =
                        matched_text.chars().count() as f32 / text.chars().count() as f32;

                    if matched_rate > *cache_threshold {
                        counter!("sgl_router_cache_hits_total").increment(1);
                        matched_worker.to_string()
                    } else {
                        counter!("sgl_router_cache_misses_total").increment(1);
                        tree.get_smallest_tenant()
                    }
                };

                // Update queues and tree
                *running_queue.get_mut(&selected_url).unwrap() += 1;

                *processed_queue
                    .lock()
                    .unwrap()
                    .get_mut(&selected_url)
                    .unwrap() += 1;

                gauge!("sgl_router_running_requests", "worker" => selected_url.to_string())
                    .set(*running_queue.get(&selected_url).unwrap() as f64);
                counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()).increment(1);

                tree.insert(&text, &selected_url);

                selected_url
            }
        };

        worker_url
    }

    async fn send_generate_request(
        &self,
        client: &reqwest::Client,
583
584
        req: &HttpRequest,
        body: &Bytes,
585
586
587
        route: &str,
        worker_url: &str,
    ) -> HttpResponse {
588
589
590
        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            .unwrap_or(false);
591

592
        let mut request_builder = client
593
            .post(format!("{}{}", worker_url, route))
594
595
596
597
598
599
600
601
            .body(body.to_vec());

        // Copy all headers from original request
        for (name, value) in copy_request_headers(req) {
            request_builder = request_builder.header(name, value);
        }

        let res = match request_builder.send().await {
602
603
604
            Ok(res) => res,
            Err(_) => return HttpResponse::InternalServerError().finish(),
        };
605

606
607
        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
608

609
        if !is_stream {
610
611
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
612
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
613
614
615
616
                Err(e) => {
                    let error_msg = format!("Failed to get response body: {}", e);
                    HttpResponse::InternalServerError().body(error_msg)
                }
617
618
619
620
621
            };

            // Then decrement running queue counter if using CacheAware
            if let Router::CacheAware { running_queue, .. } = self {
                if let Ok(mut queue) = running_queue.lock() {
622
                    if let Some(count) = queue.get_mut(worker_url) {
623
624
625
                        *count = count.saturating_sub(1);
                    }
                }
626
            }
627
628
629
630

            response
        } else if let Router::CacheAware { running_queue, .. } = self {
            let running_queue = Arc::clone(running_queue);
631
            let worker_url = worker_url.to_string();
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649

            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(
                    res.bytes_stream()
                        .map_err(|_| {
                            actix_web::error::ErrorInternalServerError("Failed to read stream")
                        })
                        .inspect(move |bytes| {
                            let bytes = bytes.as_ref().unwrap();
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                let mut locked_queue = running_queue.lock().unwrap();
                                let count = locked_queue.get_mut(&worker_url).unwrap();
                                *count = count.saturating_sub(1);
650
                                debug!("Streaming is done!!")
651
652
653
                            }
                        }),
                )
654
655
656
657
        } else {
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
658
                    actix_web::error::ErrorInternalServerError("Failed to read stream")
659
                }))
660
661
        }
    }
662

663
664
665
    /// Route a generate-style request to a policy-selected worker, with
    /// retries and metrics.
    ///
    /// Per worker: up to `MAX_REQUEST_RETRIES` attempts; overall: up to
    /// `MAX_TOTAL_RETRIES`. A worker that exhausts its per-worker budget is
    /// removed from the pool, after which a fresh worker is selected. If a
    /// request fails but the worker still passes `/health`, the failure is
    /// treated as a bad request and the error response is returned as-is.
    pub async fn route_generate_request(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        body: &Bytes,
        route: &str,
    ) -> HttpResponse {
        let start = Instant::now();
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            // Re-select on every outer iteration so a removed worker is no
            // longer a candidate.
            let worker_url = self.select_generate_worker(body, route);
            let mut request_retries = 0;

            // Try the same worker multiple times
            while request_retries < MAX_REQUEST_RETRIES {
                if total_retries >= 1 {
                    info!("Retrying request after {} failed attempts", total_retries);
                    counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1);
                }

                let response = self
                    .send_generate_request(client, req, body, route, &worker_url)
                    .await;

                if response.status().is_success() {
                    let duration = start.elapsed();
                    histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()).record(duration.as_secs_f64());
                    return response;
                } else {
                    // if the worker is healthy, it means the request is bad, so return the error response
                    let health_response =
                        self.send_request(client, &worker_url, "/health", req).await;
                    if health_response.status().is_success() {
                        counter!("sgl_router_request_errors_total", "route" => route.to_string())
                            .increment(1);
                        return response;
                    }
                }

                warn!(
                    "Generate request to {} failed (attempt {}/{})",
                    worker_url,
                    request_retries + 1,
                    MAX_REQUEST_RETRIES
                );

                request_retries += 1;
                total_retries += 1;

                // Worker exhausted its budget: drop it and re-select.
                if request_retries == MAX_REQUEST_RETRIES {
                    warn!("Removing failed worker: {}", worker_url);
                    self.remove_worker(&worker_url);
                    break;
                }
            }
        }

        counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1);
        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

727
    pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
        let (timeout_secs, interval_secs) = match self {
            Router::Random {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::RoundRobin {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::CacheAware {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
744
        };
745
746
747
748
749
750

        let start_time = std::time::Instant::now();
        let client = reqwest::Client::new();

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
751
                error!(
752
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
753
754
                    timeout_secs, worker_url
                );
755
                return Err(format!(
756
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
757
758
759
760
761
762
763
764
765
                    timeout_secs, worker_url
                ));
            }

            match client.get(&format!("{}/health", worker_url)).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        match self {
                            Router::RoundRobin { worker_urls, .. }
766
                            | Router::Random { worker_urls, .. }
767
768
769
                            | Router::CacheAware { worker_urls, .. } => {
                                info!("Worker {} health check passed", worker_url);
                                let mut urls = worker_urls.write().unwrap();
770
                                if urls.contains(&worker_url.to_string()) {
771
                                    return Err(format!("Worker {} already exists", worker_url));
772
773
                                }
                                info!("Added worker: {}", worker_url);
774
                                urls.push(worker_url.to_string());
775
                                gauge!("sgl_router_active_workers").set(urls.len() as f64);
776
777
                            }
                        }
778
779
780
781
782
783
784
785
786
787

                        // If cache aware, initialize the queues for the new worker
                        if let Router::CacheAware {
                            running_queue,
                            processed_queue,
                            tree,
                            ..
                        } = self
                        {
                            // Add worker to running queue with initial count of 0
788
789
790
791
                            running_queue
                                .lock()
                                .unwrap()
                                .insert(worker_url.to_string(), 0);
792
793
794
795
796

                            // Add worker to processed queue with initial count of 0
                            processed_queue
                                .lock()
                                .unwrap()
797
                                .insert(worker_url.to_string(), 0);
798
799
800
801
802
803

                            // Add worker to tree
                            tree.lock().unwrap().insert(&"".to_string(), &worker_url);
                        }

                        return Ok(format!("Successfully added worker: {}", worker_url));
804
805
                    } else {
                        info!(
806
807
808
                            "Worker {} health check is pending with status: {}.",
                            worker_url,
                            res.status()
809
810
811
812
813
814
815
816
817
818
819
820
                        );
                        // if the url does not have http or https prefix, warn users
                        if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
                        {
                            warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                        }

                        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                        continue;
                    }
                }
                Err(e) => {
821
822
823
824
                    info!(
                        "Worker {} health check is pending with error: {}",
                        worker_url, e
                    );
825
826
827
828
829
830
831
832
833

                    // if the url does not have http or https prefix, warn users
                    if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
                        warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                    }

                    tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                    continue;
                }
834
835
836
            }
        }
    }
837

838
    pub fn remove_worker(&self, worker_url: &str) {
839
840
        match self {
            Router::RoundRobin { worker_urls, .. }
841
            | Router::Random { worker_urls, .. }
842
843
            | Router::CacheAware { worker_urls, .. } => {
                let mut urls = worker_urls.write().unwrap();
844
845
846
                if let Some(index) = urls.iter().position(|url| url == &worker_url) {
                    urls.remove(index);
                    info!("Removed worker: {}", worker_url);
847
                    gauge!("sgl_router_active_workers").set(urls.len() as f64);
848
849
850
851
                } else {
                    warn!("Worker {} not found, skipping removal", worker_url);
                    return;
                }
852
853
854
855
            }
        }

        // if cache aware, remove the worker from the tree
856
857
858
859
860
861
862
        if let Router::CacheAware {
            tree,
            running_queue,
            processed_queue,
            ..
        } = self
        {
863
            tree.lock().unwrap().remove_tenant(&worker_url);
864
865
866
867
868
869
870
871
            running_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
            processed_queue
                .lock()
                .unwrap()
                .remove(&worker_url.to_string());
872
873
874
875
            info!(
                "Removed worker from tree and cleaned up queues: {}",
                worker_url
            );
876
877
        }
    }
878
}