router.rs 53.8 KB
Newer Older
1
use crate::core::{HealthChecker, Worker, WorkerFactory};
2
3
use crate::pd_router::PDRouter;
use crate::pd_types::PDSelectionPolicy;
4
use crate::tree::Tree;
5
use ::metrics::{counter, gauge, histogram};
6
7
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
Byron Hsu's avatar
Byron Hsu committed
8
use futures_util::{StreamExt, TryStreamExt};
9
use std::fmt::Debug;
10
use std::sync::atomic::AtomicUsize;
11
use std::sync::{Arc, Mutex, RwLock};
12
13
use std::thread;
use std::time::Duration;
14
use std::time::Instant;
15
use tokio;
16
use tracing::{debug, error, info, warn};
17

18
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
19
20
21
22
23
24
25
26
27
28
29
    req.headers()
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|v| (name.to_string(), v.to_string()))
        })
        .collect()
}

30
/// A load-balancing router over a set of inference workers.
///
/// Each variant embodies one routing policy; the policy is fixed at
/// construction time (see [`PolicyConfig`]).
#[derive(Debug)]
pub enum Router {
    /// Cycles through workers in a fixed rotation.
    RoundRobin {
        workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
        // Index of the next worker; advanced atomically per request.
        current_index: AtomicUsize,
        // Max seconds to wait for workers to become healthy at startup.
        timeout_secs: u64,
        // Seconds between background health-check probes.
        interval_secs: u64,
        // Held only to keep the background health-check task alive.
        _health_checker: Option<HealthChecker>,
    },
    /// Picks a worker uniformly at random for each request.
    Random {
        workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
        timeout_secs: u64,
        interval_secs: u64,
        _health_checker: Option<HealthChecker>,
    },
    /// Disaggregated prefill/decode routing; all routing logic is
    /// delegated to the wrapped [`PDRouter`].
    PrefillDecode {
        pd_router: Arc<PDRouter>,
    },
    CacheAware {
        /*
            Cache-Aware Load Balancing Router

            This router combines two strategies to optimize both cache utilization and request distribution:

            1. Cache-Aware Routing (Approximate Tree)
            2. Load Balancing (Shortest Queue with Balance Thresholds)

            The router dynamically switches between these strategies based on load conditions:
            - Uses load balancing when the system is imbalanced
            - Uses cache-aware routing when the system is balanced

            A system is considered imbalanced if both conditions are met:
            1. (max - min) > abs_threshold
            2. max > rel_threshold * min

            Strategy Details:

            1. Cache-Aware Routing (Approximate Tree)
            -------------------------------------------
            This strategy maintains an approximate radix tree for each worker based on request history,
            eliminating the need for direct cache state queries. The tree stores raw text characters
            instead of token IDs to avoid tokenization overhead.

            Process:
            a. For each request, find the worker with the highest prefix match
            b. If match rate > cache_threshold:
            Route to the worker with highest match (likely has relevant data cached)
            c. If match rate ≤ cache_threshold:
            Route to the worker with smallest tree size (most available cache capacity)
            d. Background maintenance:
            Periodically evict least recently used leaf nodes to prevent memory overflow

            2. Load Balancing (Shortest Queue)
            -------------------------------------------
            This strategy tracks pending request counts per worker and routes new requests
            to the least busy worker when the system is detected to be imbalanced.

            Configuration Parameters:
            ------------------------
            1. cache_threshold: (float, 0.0 to 1.0)
            Minimum prefix match ratio to use highest-match routing.
            Below this threshold, routes to worker with most available cache space.

            2. balance_abs_threshold: (integer)
            Absolute difference threshold for load imbalance detection.
            System is potentially imbalanced if (max_load - min_load) > abs_threshold

            3. balance_rel_threshold: (float)
            Relative ratio threshold for load imbalance detection.
            System is potentially imbalanced if max_load > min_load * rel_threshold
            Used in conjunction with abs_threshold to determine final imbalance state.

            4. eviction_interval_secs: (integer)
            Interval between LRU eviction cycles for the approximate trees.

            5. max_tree_size: (integer)
            Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
            during the next eviction cycle.
        */
        workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
        // Approximate radix tree mapping request prefixes to worker URLs.
        tree: Arc<Mutex<Tree>>,
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        timeout_secs: u64,
        interval_secs: u64,
        // Held only to keep the background LRU eviction thread alive.
        _eviction_thread: Option<thread::JoinHandle<()>>,
        _health_checker: Option<HealthChecker>,
    },
}

121
/// Per-policy configuration consumed by `Router::new` to construct the
/// matching [`Router`] variant.
#[derive(Debug, Clone)]
pub enum PolicyConfig {
    RandomConfig {
        // Max seconds to wait for workers to become healthy at startup.
        timeout_secs: u64,
        // Seconds between background health-check probes.
        interval_secs: u64,
    },
    RoundRobinConfig {
        timeout_secs: u64,
        interval_secs: u64,
    },
    CacheAwareConfig {
        // Minimum prefix-match ratio required to route by cache affinity.
        cache_threshold: f32,
        // Absolute and relative load-gap thresholds; both must be exceeded
        // for the router to consider the system imbalanced.
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
        // Seconds between LRU eviction passes over the radix tree.
        eviction_interval_secs: u64,
        // Max nodes per tree before eviction starts trimming.
        max_tree_size: usize,
        timeout_secs: u64,
        interval_secs: u64,
    },
    PrefillDecodeConfig {
        selection_policy: PDSelectionPolicy,
        prefill_urls: Vec<(String, Option<u16>)>, // (url, bootstrap_port)
        decode_urls: Vec<String>,
        timeout_secs: u64,
        interval_secs: u64,
    },
}

149
impl Router {
150
    /// Build a `Router` for `worker_urls` under the given policy.
    ///
    /// For the regular policies this blocks until every worker reports
    /// healthy or `timeout_secs` elapses. `PrefillDecodeConfig` ignores
    /// `worker_urls` entirely; its workers are validated by `PDRouter`.
    ///
    /// Returns an error string on health-check timeout, health-check
    /// thread panic, or PD router construction failure.
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
        // Update active workers gauge
        gauge!("sgl_router_active_workers").set(worker_urls.len() as f64);

        // Get timeout and interval from policy config
        let (timeout_secs, interval_secs) = match &policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::CacheAwareConfig {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            PolicyConfig::PrefillDecodeConfig {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
        };

        // For PrefillDecode, we need to handle workers differently
        match &policy_config {
            PolicyConfig::PrefillDecodeConfig { .. } => {
                // PD mode doesn't use the worker_urls parameter
                // We'll validate PD workers separately
            }
            _ => {
                // Wait until all workers are healthy for regular modes.
                // NOTE(review): the spawned thread is joined immediately, so
                // the caller still blocks; the thread only isolates a panic
                // in the health check from the caller.
                let worker_urls = worker_urls.clone();
                std::thread::spawn(move || {
                    Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)
                })
                .join()
                .map_err(|e| {
                    error!("Health-check thread panicked: {:?}", e);
                    format!("Health-check thread panicked: {e:?}")
                })??;
            }
        }

        // Create Worker trait objects from URLs
        let workers: Vec<Box<dyn Worker>> = worker_urls
            .iter()
            .map(|url| WorkerFactory::create_regular(url.clone()))
            .collect();

        // Create router based on policy...
        Ok(match policy_config {
            PolicyConfig::RandomConfig {
                timeout_secs,
                interval_secs,
            } => {
                let workers = Arc::new(RwLock::new(workers));
                let health_checker =
                    crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
                Router::Random {
                    workers,
                    timeout_secs,
                    interval_secs,
                    _health_checker: Some(health_checker),
                }
            }
            PolicyConfig::RoundRobinConfig {
                timeout_secs,
                interval_secs,
            } => {
                let workers = Arc::new(RwLock::new(workers));
                let health_checker =
                    crate::core::start_health_checker(Arc::clone(&workers), interval_secs);
                Router::RoundRobin {
                    workers,
                    current_index: std::sync::atomic::AtomicUsize::new(0),
                    timeout_secs,
                    interval_secs,
                    _health_checker: Some(health_checker),
                }
            }
            PolicyConfig::CacheAwareConfig {
                cache_threshold,
                balance_abs_threshold,
                balance_rel_threshold,
                eviction_interval_secs,
                max_tree_size,
                timeout_secs,
                interval_secs,
            } => {
                let tree = Arc::new(Mutex::new(Tree::new()));

                // Create background eviction thread
                let tree_clone = Arc::clone(&tree);
                let workers = Arc::new(RwLock::new(workers));
                let workers_clone = Arc::clone(&workers);
                // Runs for the life of the process: periodically trims the
                // tree and logs per-worker load/throughput snapshots.
                let eviction_thread = thread::spawn(move || {
                    loop {
                        // Sleep for the specified interval
                        thread::sleep(Duration::from_secs(eviction_interval_secs));

                        let locked_tree_clone = tree_clone.lock().unwrap();
                        // Run eviction
                        locked_tree_clone.evict_tenant_by_size(max_tree_size);
                        // Release the tree lock before taking the workers lock.
                        drop(locked_tree_clone);

                        // Log worker loads and processed requests
                        let workers_guard = workers_clone.read().unwrap();
                        let loads: Vec<(String, usize)> = workers_guard
                            .iter()
                            .map(|w| (w.url().to_string(), w.load()))
                            .collect();
                        info!("Worker loads: {:?}", loads);

                        let processed: Vec<(String, usize)> = workers_guard
                            .iter()
                            .map(|w| (w.url().to_string(), w.processed_requests()))
                            .collect();
                        info!("Processed requests: {:?}", processed);
                    }
                });

                // Seed the tree so every worker owns the empty prefix and is
                // therefore reachable by prefix matching from the start.
                for worker in workers.read().unwrap().iter() {
                    tree.lock().unwrap().insert("", worker.url());
                }

                let health_checker =
                    crate::core::start_health_checker(Arc::clone(&workers), interval_secs);

                Router::CacheAware {
                    workers,
                    tree,
                    cache_threshold,
                    balance_abs_threshold,
                    balance_rel_threshold,
                    timeout_secs,
                    interval_secs,
                    _eviction_thread: Some(eviction_thread),
                    _health_checker: Some(health_checker),
                }
            }
            PolicyConfig::PrefillDecodeConfig {
                selection_policy,
                prefill_urls,
                decode_urls,
                timeout_secs,
                interval_secs,
            } => {
                // Create PDRouter instance
                let pd_router = PDRouter::new(
                    prefill_urls,
                    decode_urls,
                    selection_policy,
                    timeout_secs,
                    interval_secs,
                )?;

                Router::PrefillDecode {
                    pd_router: Arc::new(pd_router),
                }
            }
        })
    }

316
317
    /// Get the current list of worker URLs
    pub fn get_worker_urls(&self) -> Vec<String> {
318
        match self {
319
320
321
322
323
324
325
326
327
            Router::RoundRobin { workers, .. }
            | Router::Random { workers, .. }
            | Router::CacheAware { workers, .. } => workers
                .read()
                .unwrap()
                .iter()
                .map(|w| w.url().to_string())
                .collect(),
            Router::PrefillDecode { .. } => Vec::new(),
328
329
330
        }
    }

331
    pub fn wait_for_healthy_workers(
332
333
334
335
336
        worker_urls: &[String],
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<(), String> {
        let start_time = std::time::Instant::now();
337
338
339
340
        let sync_client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .build()
            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
341
342
343

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
344
                error!(
345
346
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
347
                );
348
                return Err(format!(
349
350
                    "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
                    timeout_secs, worker_urls
351
352
353
354
355
356
357
358
359
360
                ));
            }

            let mut all_healthy = true;
            let mut unhealthy_workers = Vec::new();

            for url in worker_urls {
                match sync_client.get(&format!("{}/health", url)).send() {
                    Ok(res) => {
                        if !res.status().is_success() {
Byron Hsu's avatar
Byron Hsu committed
361
362
                            let msg = format!(
                                "Worker heatlh check is pending with status {}",
363
364
                                res.status()
                            );
Byron Hsu's avatar
Byron Hsu committed
365
                            info!("{}", msg);
366
                            all_healthy = false;
Byron Hsu's avatar
Byron Hsu committed
367
                            unhealthy_workers.push((url, msg));
368
369
                        }
                    }
Byron Hsu's avatar
Byron Hsu committed
370
371
372
                    Err(_) => {
                        let msg = format!("Worker is not ready yet");
                        info!("{}", msg);
373
                        all_healthy = false;
Byron Hsu's avatar
Byron Hsu committed
374
                        unhealthy_workers.push((url, msg));
375
376
377
378
379
380
381
382
                    }
                }
            }

            if all_healthy {
                info!("All workers are healthy");
                return Ok(());
            } else {
Byron Hsu's avatar
Byron Hsu committed
383
                info!("Initializing workers:");
384
385
386
387
388
389
390
391
                for (url, reason) in &unhealthy_workers {
                    info!("  {} - {}", url, reason);
                }
                thread::sleep(Duration::from_secs(interval_secs));
            }
        }
    }

392
393
    fn select_first_worker(&self) -> Result<String, String> {
        match self {
394
395
396
397
398
            Router::RoundRobin { workers, .. }
            | Router::Random { workers, .. }
            | Router::CacheAware { workers, .. } => {
                let workers_guard = workers.read().unwrap();
                if workers_guard.is_empty() {
399
400
                    Err("No workers are available".to_string())
                } else {
401
                    Ok(workers_guard[0].url().to_string())
402
403
                }
            }
404
405
406
407
            Router::PrefillDecode { .. } => {
                // For PD mode, we don't need this method as routing is handled by PDRouter
                Err("PrefillDecode mode doesn't use select_first_worker".to_string())
            }
408
409
410
        }
    }

411
    pub async fn send_request(
412
413
        &self,
        client: &reqwest::Client,
414
        worker_url: &str,
415
        route: &str,
416
        req: &HttpRequest,
417
    ) -> HttpResponse {
418
        let start = Instant::now();
419
420
421
422
423
        let mut request_builder = client.get(format!("{}{}", worker_url, route));

        // Copy all headers from original request except for /health because it does not need authorization
        if route != "/health" {
            for (name, value) in copy_request_headers(req) {
424
425
426
427
428
                // Skip Content-Type and Content-Length as .json() sets them
                if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
                {
                    request_builder = request_builder.header(name, value);
                }
429
430
431
            }
        }

432
        let response = match request_builder.send().await {
433
434
435
436
437
438
439
440
441
442
443
444
445
446
            Ok(res) => {
                let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

                match res.bytes().await {
                    Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                    Err(e) => HttpResponse::InternalServerError()
                        .body(format!("Failed to read response body: {}", e)),
                }
            }
            Err(e) => HttpResponse::InternalServerError().body(format!(
                "Failed to send request to worker {}: {}",
                worker_url, e
            )),
447
448
449
450
451
452
453
454
455
456
457
458
459
        };

        // Record request metrics
        if route != "/health" {
            let duration = start.elapsed();
            counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
            histogram!("sgl_router_request_duration_seconds", "route" => route.to_string())
                .record(duration.as_secs_f64());

            if !response.status().is_success() {
                counter!("sgl_router_request_errors_total", "route" => route.to_string())
                    .increment(1);
            }
460
        }
461
        response
462
463
    }

464
465
466
467
468
469
    /// Proxy `route` to the first available worker, with bounded retries.
    ///
    /// Each selected worker is tried up to `MAX_REQUEST_RETRIES` times;
    /// after exhausting those, the worker is removed and the next one is
    /// tried, up to `MAX_TOTAL_RETRIES` attempts overall. A failure from a
    /// worker whose `/health` endpoint still succeeds is treated as a bad
    /// request and returned to the caller unchanged.
    pub async fn route_to_first(
        &self,
        client: &reqwest::Client,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        const MAX_REQUEST_RETRIES: u32 = 3;
        const MAX_TOTAL_RETRIES: u32 = 6;
        let mut total_retries = 0;

        while total_retries < MAX_TOTAL_RETRIES {
            match self.select_first_worker() {
                Ok(worker_url) => {
                    let mut request_retries = 0;

                    // Try the same worker multiple times
                    while request_retries < MAX_REQUEST_RETRIES {
                        if total_retries >= 1 {
                            info!("Retrying request after {} failed attempts", total_retries);
                        }

                        let response = self.send_request(client, &worker_url, route, req).await;

                        if response.status().is_success() {
                            return response;
                        } else {
                            // if the worker is healthy, it means the request is bad, so return the error response
                            let health_response =
                                self.send_request(client, &worker_url, "/health", req).await;
                            if health_response.status().is_success() {
                                return response;
                            }
                        }

                        warn!(
                            "Request to {} failed (attempt {}/{})",
                            worker_url,
                            request_retries + 1,
                            MAX_REQUEST_RETRIES
                        );

                        request_retries += 1;
                        total_retries += 1;

                        // Worker failed repeatedly and is unhealthy: drop it
                        // and let the outer loop pick a different one.
                        if request_retries == MAX_REQUEST_RETRIES {
                            warn!("Removing failed worker: {}", worker_url);
                            self.remove_worker(&worker_url);
                            break;
                        }
                    }
                }
                Err(e) => return HttpResponse::InternalServerError().body(e),
            }
        }

        HttpResponse::InternalServerError().body("All retry attempts failed")
    }

522
523
524
525
526
527
528
529
530
531
532
533
534
    pub async fn route_to_all(
        &self,
        client: &reqwest::Client,
        route: &str,
        req: &HttpRequest,
    ) -> HttpResponse {
        // Get all worker URLs based on router type
        let worker_urls = match self {
            Router::PrefillDecode { .. } => {
                // For PD mode, route_to_all is not supported directly
                // It should be handled by PDRouter if needed
                return HttpResponse::NotImplemented()
                    .body("route_to_all not implemented for PrefillDecode mode");
535
            }
536
            _ => self.get_worker_urls(),
537
        };
538

539
540
541
542
543
544
545
546
        // Send requests to all workers concurrently
        let mut tasks = Vec::new();
        for worker_url in &worker_urls {
            let mut request_builder = client.post(format!("{}{}", worker_url, route));

            // Copy headers from original request
            for (name, value) in copy_request_headers(req) {
                request_builder = request_builder.header(name, value);
547
            }
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577

            tasks.push(request_builder.send());
        }

        // Wait for all responses
        let results = futures_util::future::join_all(tasks).await;

        // Check if all succeeded
        let all_success = results.iter().all(|r| {
            r.as_ref()
                .map(|res| res.status().is_success())
                .unwrap_or(false)
        });

        if all_success {
            HttpResponse::Ok().body("Operation completed on all servers")
        } else {
            HttpResponse::InternalServerError().body("Operation failed on one or more servers")
        }
    }

    /// Report every worker's current load as JSON in the shape
    /// `{ "prefill": [...], "decode": [...] }`.
    ///
    /// PrefillDecode mode delegates to `PDRouter::get_loads`; all other
    /// modes list their workers under "decode" and leave "prefill" empty so
    /// the response schema matches PD mode. A worker whose load cannot be
    /// fetched is reported with load `-1`.
    pub async fn get_all_loads(
        &self,
        client: &reqwest::Client,
        _req: &HttpRequest,
    ) -> HttpResponse {
        // For PD mode, delegate to PDRouter
        match self {
            Router::PrefillDecode { pd_router } => {
                return pd_router.get_loads(client).await;
            }
            _ => {
                // For non-PD routers, handle normally
            }
        }

        let urls = self.get_worker_urls();
        // Regular routers have no prefill workers; the empty list is kept
        // only so the JSON schema stays identical to PD mode.
        let prefill_urls: Vec<String> = Vec::new();
        let decode_urls = urls;

        // Collect loads from all servers
        let mut prefill_loads = Vec::new();
        let mut decode_loads = Vec::new();

        // Get prefill loads
        // NOTE(review): this loop never runs here (prefill_urls is always
        // empty); kept for symmetry with the PD-style response.
        for url in &prefill_urls {
            let load = self.get_worker_load(client, url).await.unwrap_or(-1);
            prefill_loads.push(serde_json::json!({
                "engine": format!("(Prefill@{})", url),
                "load": load as i64
            }));
        }

        // Get decode loads
        for url in &decode_urls {
            let load = self.get_worker_load(client, url).await.unwrap_or(-1);
            decode_loads.push(serde_json::json!({
                "engine": format!("(Decode@{})", url),
                "load": load as i64
            }));
        }

        HttpResponse::Ok().json(serde_json::json!({
            "prefill": prefill_loads,
            "decode": decode_loads
        }))
    }

616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
    // New method to route typed requests directly
    /// Route a strongly-typed generation request to a worker chosen by the
    /// active policy, with bounded retries.
    ///
    /// Mirrors `route_to_first`: up to `MAX_REQUEST_RETRIES` attempts per
    /// worker and `MAX_TOTAL_RETRIES` overall; a failing worker whose
    /// `/health` endpoint still succeeds means the request itself is bad,
    /// so that error response is returned unchanged. Not usable in
    /// PrefillDecode mode, which has its own typed handlers.
    pub async fn route_typed_request<
        T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
    >(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        typed_req: &T,
        route: &str,
    ) -> HttpResponse {
        match self {
            Router::PrefillDecode { .. } => HttpResponse::InternalServerError()
                .body("PD routing should use specialized typed handlers"),
            _ => {
                // Handle retries like the original implementation
                let start = Instant::now();
                const MAX_REQUEST_RETRIES: u32 = 3;
                const MAX_TOTAL_RETRIES: u32 = 6;
                let mut total_retries = 0;

                while total_retries < MAX_TOTAL_RETRIES {
                    // Extract routing text directly from typed request
                    let text = typed_req.extract_text_for_routing();
                    let is_stream = typed_req.is_stream();

                    // Select worker based on text
                    let worker_url = self.select_generate_worker_from_text(&text);
                    let mut request_retries = 0;

                    // Try the same worker multiple times
                    while request_retries < MAX_REQUEST_RETRIES {
                        if total_retries >= 1 {
                            info!("Retrying request after {} failed attempts", total_retries);
                            counter!("sgl_router_retries_total", "route" => route.to_string())
                                .increment(1);
                        }

                        // For CacheAware router, increment load before request
                        // (send_typed_request is told via `load_incremented`
                        // so it can decrement afterwards).
                        let load_incremented = match self {
                            Router::CacheAware { workers, .. } => {
                                let workers_guard = workers.read().unwrap();
                                if let Some(worker) =
                                    workers_guard.iter().find(|w| w.url() == &worker_url)
                                {
                                    worker.increment_load();
                                    gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
                                        .set(worker.load() as f64);
                                    true
                                } else {
                                    false
                                }
                            }
                            _ => false,
                        };

                        // Send typed request directly
                        let response = self
                            .send_typed_request(
                                client,
                                req,
                                typed_req,
                                route,
                                &worker_url,
                                is_stream,
                                load_incremented,
                            )
                            .await;

                        if response.status().is_success() {
                            let duration = start.elapsed();
                            histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
                                .record(duration.as_secs_f64());
                            return response;
                        } else {
                            // if the worker is healthy, it means the request is bad, so return the error response
                            let health_response =
                                self.send_request(client, &worker_url, "/health", req).await;
                            if health_response.status().is_success() {
                                counter!("sgl_router_request_errors_total", "route" => route.to_string())
                                    .increment(1);
                                return response;
                            }
                        }

                        warn!(
                            "Generate request to {} failed (attempt {}/{})",
                            worker_url,
                            request_retries + 1,
                            MAX_REQUEST_RETRIES
                        );

                        request_retries += 1;
                        total_retries += 1;

                        // Worker failed repeatedly and is unhealthy: drop it
                        // and re-select on the next outer iteration.
                        if request_retries == MAX_REQUEST_RETRIES {
                            warn!("Removing failed worker: {}", worker_url);
                            self.remove_worker(&worker_url);
                            break;
                        }
                    }
                }

                counter!("sgl_router_request_errors_total", "route" => route.to_string())
                    .increment(1);
                HttpResponse::InternalServerError().body("All retry attempts failed")
            }
        }
    }

725
    // Helper method to select worker from text (returns index for RoundRobin/Random, URL for CacheAware)
726
727
    fn select_generate_worker_from_text(&self, text: &str) -> String {
        match self {
728
            Router::RoundRobin {
729
                workers,
730
                current_index,
731
                ..
732
            } => {
733
                let workers_guard = workers.read().unwrap();
734
                let idx = current_index
735
736
737
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
738
                        |x| Some((x + 1) % workers_guard.len()),
739
                    )
740
                    .unwrap();
741
                workers_guard[idx].url().to_string()
742
            }
743

744
745
746
747
748
749
            Router::Random { workers, .. } => {
                let workers_guard = workers.read().unwrap();
                workers_guard[rand::random::<usize>() % workers_guard.len()]
                    .url()
                    .to_string()
            }
750

751
            Router::CacheAware {
752
                workers,
753
                tree,
754
                cache_threshold,
755
756
                balance_abs_threshold,
                balance_rel_threshold,
757
758
                ..
            } => {
Byron Hsu's avatar
Byron Hsu committed
759
                let tree = tree.lock().unwrap();
760
                let workers_guard = workers.read().unwrap();
761

762
763
764
765
                // Get current load statistics from workers
                let loads: Vec<usize> = workers_guard.iter().map(|w| w.load()).collect();
                let max_load = *loads.iter().max().unwrap_or(&0);
                let min_load = *loads.iter().min().unwrap_or(&0);
766
767
768
769
770
771
772
773
774

                // Load is considered imbalanced if:
                // 1. (max - min) > abs_threshold AND
                // 2. max > rel_threshold * min
                let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold
                    && (max_load as f32) > (min_load as f32 * balance_rel_threshold);

                let selected_url = if is_imbalanced {
                    // Log load balancing trigger and current queue state
775
776
777
778
779
                    let worker_loads: Vec<(String, usize)> = workers_guard
                        .iter()
                        .map(|w| (w.url().to_string(), w.load()))
                        .collect();

780
                    info!(
781
782
                        "Load balancing triggered due to workload imbalance:\n\
                        Max load: {}, Min load: {}\n\
783
784
                        Current worker loads: {:?}",
                        max_load, min_load, worker_loads
785
786
                    );

787
788
789
790
                    counter!("sgl_router_load_balancing_events_total").increment(1);
                    gauge!("sgl_router_max_load").set(max_load as f64);
                    gauge!("sgl_router_min_load").set(min_load as f64);

791
                    // Use shortest queue routing when load is imbalanced
792
                    workers_guard
793
                        .iter()
794
795
796
                        .min_by_key(|w| w.load())
                        .map(|w| w.url().to_string())
                        .unwrap_or_else(|| workers_guard[0].url().to_string())
797
798
                } else {
                    // Use cache-aware routing when load is balanced
799
800
801
                    let (matched_text, matched_worker) = tree.prefix_match(&text);
                    let matched_rate =
                        matched_text.chars().count() as f32 / text.chars().count() as f32;
802

803
                    if matched_rate > *cache_threshold {
804
                        counter!("sgl_router_cache_hits_total").increment(1);
805
806
                        matched_worker.to_string()
                    } else {
807
                        counter!("sgl_router_cache_misses_total").increment(1);
808
                        tree.get_smallest_tenant()
809
                    }
810
                };
811

812
813
814
815
816
817
                // Find the selected worker and increment processed counter only
                if let Some(worker) = workers_guard.iter().find(|w| w.url() == &selected_url) {
                    worker.increment_processed();
                    counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string())
                        .increment(1);
                }
818

819
820
821
                tree.insert(&text, &selected_url);

                selected_url
822
            }
823
824
825
826
827
            Router::PrefillDecode { .. } => {
                // For PD mode, we don't use this method
                return "PD_MODE_ERROR".to_string();
            }
        }
828
829
    }

830
831
    // Send typed request directly without conversion
    async fn send_typed_request<T: serde::Serialize>(
832
833
        &self,
        client: &reqwest::Client,
834
        req: &HttpRequest,
835
        typed_req: &T,
836
837
        route: &str,
        worker_url: &str,
838
        is_stream: bool,
839
        load_incremented: bool, // Whether load was incremented for this request
840
    ) -> HttpResponse {
841
842
843
844
845
846
        let start = Instant::now();

        // Debug: Log what we're sending
        if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
            debug!("Sending request to {}: {}", route, json_str);
        }
847

848
        let mut request_builder = client
849
            .post(format!("{}{}", worker_url, route))
850
            .json(typed_req); // Use json() directly with typed request
851
852
853

        // Copy all headers from original request
        for (name, value) in copy_request_headers(req) {
854
855
856
857
            // Skip Content-Type and Content-Length as .json() sets them
            if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
                request_builder = request_builder.header(&name, &value);
            }
858
859
860
        }

        let res = match request_builder.send().await {
861
            Ok(res) => res,
862
863
            Err(e) => {
                error!("Failed to send request to {}: {}", worker_url, e);
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879

                // Decrement load on error for CacheAware router
                if load_incremented {
                    if let Router::CacheAware { workers, .. } = self {
                        if let Ok(workers_guard) = workers.read() {
                            if let Some(worker) =
                                workers_guard.iter().find(|w| w.url() == worker_url)
                            {
                                worker.decrement_load();
                                gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
                                    .set(worker.load() as f64);
                            }
                        }
                    }
                }

880
881
                return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
            }
882
        };
883

884
885
        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
886

887
        if !is_stream {
888
889
            // For non-streaming requests, get response first
            let response = match res.bytes().await {
890
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
891
892
893
894
                Err(e) => {
                    let error_msg = format!("Failed to get response body: {}", e);
                    HttpResponse::InternalServerError().body(error_msg)
                }
895
896
            };

897
898
899
900
901
902
903
904
905
            // Decrement load counter for non-streaming CacheAware requests
            if load_incremented && !is_stream {
                if let Router::CacheAware { workers, .. } = self {
                    if let Ok(workers_guard) = workers.read() {
                        if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
                            worker.decrement_load();
                            gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
                                .set(worker.load() as f64);
                        }
906
907
                    }
                }
908
            }
909

910
911
912
913
914
915
            // Record metrics
            let duration = start.elapsed();
            histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
                .record(duration.as_secs_f64());
            counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);

916
            response
917
918
919
        } else if let Router::CacheAware { workers, .. } = self {
            // For streaming with CacheAware router, we need to manually decrement when done
            let workers = Arc::clone(workers);
920
            let worker_url = worker_url.to_string();
921
922
923
924
925
926
927
928
929

            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(
                    res.bytes_stream()
                        .map_err(|_| {
                            actix_web::error::ErrorInternalServerError("Failed to read stream")
                        })
                        .inspect(move |bytes| {
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
                            if let Ok(bytes) = bytes {
                                if bytes
                                    .as_ref()
                                    .windows(12)
                                    .any(|window| window == b"data: [DONE]")
                                {
                                    if let Ok(workers_guard) = workers.read() {
                                        if let Some(worker) =
                                            workers_guard.iter().find(|w| w.url() == &worker_url)
                                        {
                                            worker.decrement_load();
                                            gauge!("sgl_router_running_requests", "worker" => worker_url.to_string())
                                                .set(worker.load() as f64);
                                            debug!("Streaming is done!!")
                                        }
                                    }
                                }
947
948
949
                            }
                        }),
                )
950
        } else {
951
            // For non-CacheAware routers, just stream without load tracking
952
953
954
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
955
                    actix_web::error::ErrorInternalServerError("Failed to read stream")
956
                }))
957
958
        }
    }
959

960
    pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
        let (timeout_secs, interval_secs) = match self {
            Router::Random {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::RoundRobin {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
            Router::CacheAware {
                timeout_secs,
                interval_secs,
                ..
            } => (*timeout_secs, *interval_secs),
977
978
979
980
            Router::PrefillDecode { .. } => {
                // For PD mode, we don't support adding workers via this method
                return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
            }
981
        };
982
983

        let start_time = std::time::Instant::now();
984
985
986
987
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .build()
            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
988
989
990

        loop {
            if start_time.elapsed() > Duration::from_secs(timeout_secs) {
991
                error!(
992
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
993
994
                    timeout_secs, worker_url
                );
995
                return Err(format!(
996
                    "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
997
998
999
1000
1001
1002
1003
1004
                    timeout_secs, worker_url
                ));
            }

            match client.get(&format!("{}/health", worker_url)).send().await {
                Ok(res) => {
                    if res.status().is_success() {
                        match self {
1005
1006
1007
                            Router::RoundRobin { workers, .. }
                            | Router::Random { workers, .. }
                            | Router::CacheAware { workers, .. } => {
1008
                                info!("Worker {} health check passed", worker_url);
1009
1010
                                let mut workers_guard = workers.write().unwrap();
                                if workers_guard.iter().any(|w| w.url() == worker_url) {
1011
                                    return Err(format!("Worker {} already exists", worker_url));
1012
1013
                                }
                                info!("Added worker: {}", worker_url);
1014
1015
1016
1017
                                let new_worker =
                                    WorkerFactory::create_regular(worker_url.to_string());
                                workers_guard.push(new_worker);
                                gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
1018
                            }
1019
1020
1021
                            Router::PrefillDecode { .. } => {
                                return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
                            }
1022
                        }
1023

1024
1025
                        // If cache aware, add worker to tree
                        if let Router::CacheAware { tree, .. } = self {
1026
                            // Add worker to tree
1027
                            tree.lock().unwrap().insert("", worker_url);
1028
1029
1030
                        }

                        return Ok(format!("Successfully added worker: {}", worker_url));
1031
1032
                    } else {
                        info!(
1033
1034
1035
                            "Worker {} health check is pending with status: {}.",
                            worker_url,
                            res.status()
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
                        );
                        // if the url does not have http or https prefix, warn users
                        if !worker_url.starts_with("http://") && !worker_url.starts_with("https://")
                        {
                            warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                        }

                        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                        continue;
                    }
                }
                Err(e) => {
1048
1049
1050
1051
                    info!(
                        "Worker {} health check is pending with error: {}",
                        worker_url, e
                    );
1052
1053
1054
1055
1056
1057
1058
1059
1060

                    // if the url does not have http or https prefix, warn users
                    if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
                        warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url);
                    }

                    tokio::time::sleep(Duration::from_secs(interval_secs)).await;
                    continue;
                }
1061
1062
1063
            }
        }
    }
1064

1065
    pub fn remove_worker(&self, worker_url: &str) {
1066
        match self {
1067
1068
1069
1070
1071
1072
            Router::RoundRobin { workers, .. }
            | Router::Random { workers, .. }
            | Router::CacheAware { workers, .. } => {
                let mut workers_guard = workers.write().unwrap();
                if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
                    workers_guard.remove(index);
1073
                    info!("Removed worker: {}", worker_url);
1074
                    gauge!("sgl_router_active_workers").set(workers_guard.len() as f64);
1075
1076
1077
1078
                } else {
                    warn!("Worker {} not found, skipping removal", worker_url);
                    return;
                }
1079
            }
1080
1081
1082
1083
            Router::PrefillDecode { .. } => {
                warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods.");
                return;
            }
1084
1085
1086
        }

        // if cache aware, remove the worker from the tree
1087
        if let Router::CacheAware { tree, .. } = self {
1088
            tree.lock().unwrap().remove_tenant(&worker_url);
1089
            info!("Removed worker from tree: {}", worker_url);
1090
1091
        }
    }
    /// Add a worker with PD mode support
    pub async fn add_pd_worker(
        &self,
        worker_url: &str,
        pod_type: crate::service_discovery::PodType,
        bootstrap_port: Option<u16>,
    ) -> Result<String, String> {
        // Only the PrefillDecode variant can register PD workers.
        let pd_router = match self {
            Router::PrefillDecode { pd_router } => pd_router,
            _ => return Err("add_pd_worker only supported in PD mode".to_string()),
        };

        match pod_type {
            crate::service_discovery::PodType::Prefill => pd_router
                .add_prefill_server(worker_url.to_string(), bootstrap_port)
                .await
                .map_err(|e| e.to_string()),
            crate::service_discovery::PodType::Decode => pd_router
                .add_decode_server(worker_url.to_string())
                .await
                .map_err(|e| e.to_string()),
            crate::service_discovery::PodType::Regular => {
                Err("Regular pod type not supported in PD mode".to_string())
            }
        }
    }

    /// Remove a worker with PD mode support
    pub async fn remove_pd_worker(
        &self,
        worker_url: &str,
        pod_type: crate::service_discovery::PodType,
    ) -> Result<String, String> {
        // Only the PrefillDecode variant can deregister PD workers.
        let pd_router = match self {
            Router::PrefillDecode { pd_router } => pd_router,
            _ => return Err("remove_pd_worker only supported in PD mode".to_string()),
        };

        match pod_type {
            crate::service_discovery::PodType::Prefill => pd_router
                .remove_prefill_server(worker_url)
                .await
                .map_err(|e| e.to_string()),
            crate::service_discovery::PodType::Decode => pd_router
                .remove_decode_server(worker_url)
                .await
                .map_err(|e| e.to_string()),
            crate::service_discovery::PodType::Regular => {
                Err("Regular pod type not supported in PD mode".to_string())
            }
        }
    }

    /// Query a worker's `/get_load` endpoint and parse the reported load.
    /// Returns `None` on any transport, status, body, or parse failure,
    /// logging the reason at debug level.
    async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
        let response = match client.get(&format!("{}/get_load", worker_url)).send().await {
            Ok(res) => res,
            Err(e) => {
                debug!("Failed to get load from {}: {}", worker_url, e);
                return None;
            }
        };

        if !response.status().is_success() {
            debug!(
                "Worker {} returned non-success status: {}",
                worker_url,
                response.status()
            );
            return None;
        }

        let bytes = match response.bytes().await {
            Ok(bytes) => bytes,
            Err(e) => {
                debug!("Failed to read load response from {}: {}", worker_url, e);
                return None;
            }
        };

        match serde_json::from_slice::<serde_json::Value>(&bytes) {
            Ok(data) => data
                .get("load")
                .and_then(|v| v.as_i64())
                .map(|v| v as isize),
            Err(e) => {
                debug!("Failed to parse load response from {}: {}", worker_url, e);
                None
            }
        }
    }

    // PD-specific wrapper methods that delegate to PDRouter
    pub async fn route_pd_health_generate(
        &self,
        _client: &reqwest::Client,
        _req: &HttpRequest,
    ) -> HttpResponse {
        // Delegates to the PD router's own HTTP client; other modes are rejected.
        if let Router::PrefillDecode { pd_router } = self {
            pd_router.health_generate(&pd_router.http_client).await
        } else {
            HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
        }
    }

    /// Route a typed generate request through the PD router; errors outside PD mode.
    pub async fn route_pd_generate_typed(
        &self,
        _client: &reqwest::Client,
        req: &HttpRequest,
        typed_req: crate::pd_types::GenerateReqInput,
        route: &str,
    ) -> HttpResponse {
        if let Router::PrefillDecode { pd_router } = self {
            pd_router
                .route_generate(&pd_router.http_client, req, typed_req, route)
                .await
        } else {
            HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
        }
    }

    /// Route a typed chat request through the PD router; errors outside PD mode.
    pub async fn route_pd_chat_typed(
        &self,
        _client: &reqwest::Client,
        req: &HttpRequest,
        typed_req: crate::pd_types::ChatReqInput,
        route: &str,
    ) -> HttpResponse {
        if let Router::PrefillDecode { pd_router } = self {
            pd_router
                .route_chat(&pd_router.http_client, req, typed_req, route)
                .await
        } else {
            HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
        }
    }

    /// Fetch aggregated server info from the PD router; errors outside PD mode.
    pub async fn get_pd_server_info(
        &self,
        _client: &reqwest::Client,
        _req: &HttpRequest,
    ) -> HttpResponse {
        if let Router::PrefillDecode { pd_router } = self {
            pd_router.get_server_info(&pd_router.http_client).await
        } else {
            HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
        }
    }

    /// Fetch the model list from the PD router; errors outside PD mode.
    pub async fn get_pd_models(
        &self,
        _client: &reqwest::Client,
        req: &HttpRequest,
    ) -> HttpResponse {
        if let Router::PrefillDecode { pd_router } = self {
            pd_router.get_models(&pd_router.http_client, req).await
        } else {
            HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
        }
    }

    /// Flush caches on all PD servers; errors outside PD mode.
    pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse {
        if let Router::PrefillDecode { pd_router } = self {
            pd_router.flush_cache(&pd_router.http_client).await
        } else {
            HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
        }
    }

    /// Fetch model info from the PD router; errors outside PD mode.
    pub async fn get_pd_model_info(
        &self,
        _client: &reqwest::Client,
        req: &HttpRequest,
    ) -> HttpResponse {
        if let Router::PrefillDecode { pd_router } = self {
            pd_router.get_model_info(&pd_router.http_client, req).await
        } else {
            HttpResponse::InternalServerError().body("Not in PrefillDecode mode")
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::service_discovery::PodType;

    fn create_test_regular_router() -> Router {
1278
1279
1280
1281
        let workers = vec![
            WorkerFactory::create_regular("http://worker1:8080".to_string()),
            WorkerFactory::create_regular("http://worker2:8080".to_string()),
        ];
1282
        Router::Random {
1283
            workers: Arc::new(RwLock::new(workers)),
1284
1285
            timeout_secs: 5,
            interval_secs: 1,
1286
            _health_checker: None,
1287
1288
1289
1290
1291
1292
        }
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        // A regular (non-PD) router should report every configured worker URL.
        let router = create_test_regular_router();
        let urls = router.get_worker_urls();

        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
    }

    // #[test]
    // fn test_router_get_worker_urls_pd_mode() {
    //     // For PD mode, get_worker_urls returns empty list
    //     // Note: PDRouter::new requires health checks which fail in tests
    //     // This test would need a mock server or different test setup
    // }

    #[tokio::test]
    async fn test_add_pd_worker_with_regular_router() {
        // PD worker registration must be rejected by non-PD routers.
        let router = create_test_regular_router();

        let err = router
            .add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081))
            .await
            .expect_err("regular router must reject PD worker registration");
        assert!(err.contains("add_pd_worker only supported in PD mode"));
    }

    #[tokio::test]
    async fn test_remove_pd_worker_with_regular_router() {
        // PD worker removal must be rejected by non-PD routers.
        let router = create_test_regular_router();

        let err = router
            .remove_pd_worker("http://worker:8080", PodType::Decode)
            .await
            .expect_err("regular router must reject PD worker removal");
        assert!(err.contains("remove_pd_worker only supported in PD mode"));
    }

    // #[tokio::test]
    // async fn test_add_pd_worker_with_pd_router_regular_type() {
    //     // Note: PDRouter::new requires health checks which fail in tests
    //     // This test would need a mock server or different test setup
    // }

    // #[tokio::test]
    // async fn test_remove_pd_worker_with_pd_router_regular_type() {
    //     // Note: PDRouter::new requires health checks which fail in tests
    //     // This test would need a mock server or different test setup
    // }

    #[test]
    fn test_select_first_worker_regular() {
        // The first configured worker should always be selectable.
        let router = create_test_regular_router();
        match router.select_first_worker() {
            Ok(url) => assert_eq!(url, "http://worker1:8080"),
            Err(e) => panic!("expected a worker, got error: {}", e),
        }
    }

    // #[test]
    // fn test_select_first_worker_pd_mode() {
    //     // Note: PDRouter::new requires health checks which fail in tests
    //     // This test would need a mock server or different test setup
    // }

    #[test]
    fn test_wait_for_healthy_workers_empty_list() {
        // An empty worker list is trivially healthy.
        assert!(Router::wait_for_healthy_workers(&[], 1, 1).is_ok());
    }

    #[test]
    fn test_wait_for_healthy_workers_invalid_urls() {
        // Unresolvable hosts never become healthy, so this fails after the 1s timeout.
        let urls = vec!["http://nonexistent:8080".to_string()];
        let err = Router::wait_for_healthy_workers(&urls, 1, 1).unwrap_err();
        assert!(err.contains("Timeout"));
    }
}