router.rs 41.9 KB
Newer Older
1
use crate::config::types::RetryConfig;
2
use crate::core::{
3
    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
4
};
5
use crate::metrics::RouterMetrics;
6
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
7
use crate::protocols::spec::{
8
    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
9
    RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest,
10
};
11
use crate::routers::header_utils;
12
use crate::routers::RouterTrait;
13
use axum::body::to_bytes;
14
15
16
use axum::{
    body::Body,
    extract::Request,
17
18
19
    http::{
        header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode,
    },
20
21
22
23
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
24
use reqwest::Client;
25
use std::collections::HashMap;
26
use std::sync::Arc;
27
use std::time::{Duration, Instant};
28
use tokio_stream::wrappers::UnboundedReceiverStream;
29
use tracing::{debug, error};
30
31
32
33

/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
    /// Shared registry of all known workers.
    worker_registry: Arc<WorkerRegistry>,
    /// Per-model load-balancing policies plus the default policy.
    policy_registry: Arc<PolicyRegistry>,
    /// HTTP client used for all proxied requests to workers.
    client: Client,
    /// When true, worker URLs carry an "@dp_rank" suffix that is stripped
    /// before sending and injected into request bodies as `data_parallel_rank`.
    dp_aware: bool,
    /// When true, worker selection is scoped to the requested model id;
    /// otherwise the model id is ignored and all workers are candidates.
    enable_igw: bool,
    /// Retry/backoff configuration applied by `route_typed_request`.
    retry_config: RetryConfig,
    /// Latest load snapshot published by the monitor task (held to keep the
    /// watch channel alive; reads currently go through the policy instead).
    _worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
    /// Handle of the background load-monitor task, if one was spawned.
    _load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
}

impl Router {
45
    /// Create a new router with injected policy and client
    ///
    /// Counts the current Regular/HTTP workers for metrics and — only when
    /// the default policy is "power_of_two" — spawns a background task that
    /// periodically polls each worker's `/get_load` endpoint and feeds the
    /// results back into the policy via `monitor_worker_loads`.
    pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
        let workers = ctx.worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // include all workers
        );

        RouterMetrics::set_active_workers(workers.len());

        let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        // Watch channel through which the monitor task publishes load snapshots.
        let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
        let worker_loads = Arc::new(rx);

        let default_policy = ctx.policy_registry.get_default_policy();

        // Load monitoring only matters for power-of-two, which compares the
        // reported load of two sampled candidates.
        let load_monitor_handle = if default_policy.name() == "power_of_two" {
            let monitor_urls = worker_urls.clone();
            // Snapshot each worker's API key up front so the spawned task
            // does not need registry access on every poll.
            let monitor_api_keys = monitor_urls
                .iter()
                .map(|url| {
                    ctx.worker_registry
                        .get_by_url(url)
                        .and_then(|w| w.api_key().clone())
                })
                .collect::<Vec<Option<String>>>();
            let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
            let policy_clone = default_policy.clone();
            let client_clone = ctx.client.clone();

            Some(Arc::new(tokio::spawn(async move {
                Self::monitor_worker_loads(
                    monitor_urls,
                    monitor_api_keys,
                    tx,
                    monitor_interval,
                    policy_clone,
                    client_clone,
                )
                .await;
            })))
        } else {
            None
        };

        Ok(Router {
            worker_registry: ctx.worker_registry.clone(),
            policy_registry: ctx.policy_registry.clone(),
            client: ctx.client.clone(),
            dp_aware: ctx.router_config.dp_aware,
            enable_igw: ctx.router_config.enable_igw,
            retry_config: ctx.router_config.effective_retry_config(),
            _worker_loads: worker_loads,
            _load_monitor_handle: load_monitor_handle,
        })
    }

    fn select_first_worker(&self) -> Result<String, String> {
105
        let workers = self.worker_registry.get_all();
106
107
        let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();
        if healthy_workers.is_empty() {
108
109
            Err("No workers are available".to_string())
        } else {
110
            Ok(healthy_workers[0].url().to_string())
111
112
113
        }
    }

114
    // Helper method to proxy GET requests to the first available worker
115
    async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
116
        let headers = header_utils::copy_request_headers(&req);
117
118
119

        match self.select_first_worker() {
            Ok(worker_url) => {
120
                let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
121
                for (name, value) in headers {
122
123
                    let name_lc = name.to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
124
125
126
                        request_builder = request_builder.header(name, value);
                    }
                }
127

128
129
130
131
                match request_builder.send().await {
                    Ok(res) => {
                        let status = StatusCode::from_u16(res.status().as_u16())
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
132
133
134
135
136

                        // Preserve headers from backend
                        let response_headers =
                            header_utils::preserve_response_headers(res.headers());

137
                        match res.bytes().await {
138
139
140
141
142
143
                            Ok(body) => {
                                let mut response = Response::new(axum::body::Body::from(body));
                                *response.status_mut() = status;
                                *response.headers_mut() = response_headers;
                                response
                            }
144
145
146
147
148
                            Err(e) => (
                                StatusCode::INTERNAL_SERVER_ERROR,
                                format!("Failed to read response: {}", e),
                            )
                                .into_response(),
149
150
                        }
                    }
151
152
153
154
155
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
156
157
                }
            }
158
            Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
159
160
161
        }
    }

162
163
164
165
166
167
    /// Select worker for a specific model considering circuit breaker state
    ///
    /// Returns `None` when no worker is currently available for the request.
    fn select_worker_for_model(
        &self,
        model_id: Option<&str>,
        text: Option<&str>,
    ) -> Option<Arc<dyn Worker>> {
        // Without IGW there is effectively a single model; ignore the
        // requested id so all Regular/HTTP workers remain candidates.
        let effective_model_id = if !self.enable_igw { None } else { model_id };

        // Get workers for the specified model O(1), filtered by connection mode
        let workers = self.worker_registry.get_workers_filtered(
            effective_model_id,
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // get all workers, we'll filter by is_available() next
        );

        // Drop workers that are unhealthy or whose circuit breaker is open.
        let available: Vec<Arc<dyn Worker>> = workers
            .iter()
            .filter(|w| w.is_available())
            .cloned()
            .collect();
        if available.is_empty() {
            return None;
        }

        // Get the appropriate policy for this model
        let policy = match model_id {
            Some(model) => self.policy_registry.get_policy_or_default(model),
            None => self.policy_registry.get_default_policy(),
        };

        // `text` lets content-aware policies route by request payload.
        let idx = policy.select_worker(&available, text)?;
        Some(available[idx].clone())
    }

197
    /// Route a typed generation request with retry, per-model policy
    /// selection, and optional load tracking.
    ///
    /// Each retry attempt re-selects a worker, so a failing worker can be
    /// avoided on the next attempt. Load is incremented only for the
    /// cache_aware policy; on retryable failures it is decremented here
    /// because `send_typed_request` leaves that to the retry loop.
    pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        model_id: Option<&str>,
    ) -> Response {
        let start = Instant::now();
        let is_stream = typed_req.is_stream();
        // Text used by content-aware policies for routing decisions.
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
                let worker = match self.select_worker_for_model(model_id, Some(&text)) {
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };

                // Optional load tracking for cache-aware policy
                // Get the policy for this model to check if it's cache-aware
                let policy = match model_id {
                    Some(model) => self.policy_registry.get_policy_or_default(model),
                    None => self.policy_registry.get_default_policy(),
                };

                let load_incremented = if policy.name() == "cache_aware" {
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
                } else {
                    false
                };

                // Keep a clone for potential cleanup on retry
                let worker_for_cleanup = if load_incremented {
                    Some(worker.clone())
                } else {
                    None
                };

                let response = self
                    .send_typed_request(
                        headers,
                        typed_req,
                        route,
                        worker.url(),
                        is_stream,
                        load_incremented,
                    )
                    .await;

                // Feed the circuit breaker with this attempt's outcome.
                worker.record_outcome(response.status().is_success());

                // For retryable failures, we need to decrement load since send_typed_request
                // won't have done it (it only decrements on success or non-retryable failures)
                if is_retryable_status(response.status()) && load_incremented {
                    if let Some(cleanup_worker) = worker_for_cleanup {
                        cleanup_worker.decrement_load();
                        RouterMetrics::set_running_requests(
                            cleanup_worker.url(),
                            cleanup_worker.load(),
                        );
                    }
                }

                response
            },
            // should_retry predicate
            |res, _attempt| is_retryable_status(res.status()),
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
        )
        .await;

        // Final-outcome metrics: retryable failures that exhausted retries are
        // already counted by the on_exhausted hook, so they are skipped here.
        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
        } else if !is_retryable_status(response.status()) {
            RouterMetrics::record_request_error(route, "non_retryable_error");
        }

        response
    }

296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
    // Helper: return base worker URL (strips DP suffix when enabled)
    fn worker_base_url(&self, worker_url: &str) -> String {
        if !self.dp_aware {
            return worker_url.to_string();
        }
        // Fall back to the raw URL when there is no "@rank" suffix to strip.
        Self::extract_dp_rank(worker_url)
            .map(|(prefix, _)| prefix.to_string())
            .unwrap_or_else(|_| worker_url.to_string())
    }

    // Generic simple routing for GET/POST without JSON body.
    //
    // Fans the request out to every registered worker in turn and returns the
    // first successful response; otherwise the last failure seen, or 502 when
    // no worker produced any response at all.
    async fn route_simple_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
        method: Method,
    ) -> Response {
        // TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
        // Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
            return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
        }

        let mut last_response: Option<Response> = None;
        for worker in workers {
            let worker_url = worker.url();
            // Strip any DP-rank suffix before building the request URL.
            let base = self.worker_base_url(worker_url);

            let url = format!("{}/{}", base, endpoint);
            let mut request_builder = match method {
                Method::GET => self.client.get(url),
                Method::POST => self.client.post(url),
                _ => {
                    return (
                        StatusCode::METHOD_NOT_ALLOWED,
                        "Unsupported method for simple routing",
                    )
                        .into_response()
                }
            };

            // Per-worker API key, if configured for this worker.
            if let Some(api_key) = worker.api_key() {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", api_key));
            }

            // Forward caller headers except body-framing ones.
            if let Some(hdrs) = headers {
                for (name, value) in hdrs {
                    let name_lc = name.as_str().to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }
            }

            match request_builder.send().await {
                Ok(res) => {
                    let status = StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                    let response_headers = header_utils::preserve_response_headers(res.headers());
                    match res.bytes().await {
                        Ok(body) => {
                            let mut response = Response::new(axum::body::Body::from(body));
                            *response.status_mut() = status;
                            *response.headers_mut() = response_headers;
                            if status.is_success() {
                                // First success wins; stop fanning out.
                                return response;
                            }
                            last_response = Some(response);
                        }
                        Err(e) => {
                            last_response = Some(
                                (
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    format!("Failed to read response: {}", e),
                                )
                                    .into_response(),
                            );
                        }
                    }
                }
                Err(e) => {
                    last_response = Some(
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Request failed: {}", e),
                        )
                            .into_response(),
                    );
                }
            }
        }

        last_response
            .unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response())
    }

    // Route a GET request with provided headers to a specific endpoint,
    // using the fan-out logic of `route_simple_request`.
    async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
        self.route_simple_request(headers, endpoint, Method::GET)
            .await
    }

    // Route a POST request with empty body to a specific endpoint,
    // using the fan-out logic of `route_simple_request`.
    async fn route_post_empty_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
    ) -> Response {
        self.route_simple_request(headers, endpoint, Method::POST)
            .await
    }

410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
    // TODO (rui): Better accommodate to the Worker abstraction
    /// Split a DP-aware worker URL of the form "http://host:port@dp_rank"
    /// into its base URL and numeric dp_rank.
    ///
    /// Errors when the '@' separator is missing or appears more than once
    /// ("invalid worker_url format"), or when the suffix is not a valid
    /// usize ("failed to parse dp_rank"). Uses `split_once` to avoid the
    /// intermediate Vec the old `split().collect()` allocated.
    fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
        let (prefix, rank_str) = worker_url
            .split_once('@')
            // A second '@' in the remainder means >2 parts: same error as before.
            .filter(|(_, rest)| !rest.contains('@'))
            .ok_or_else(|| format!("invalid worker_url format: {}", worker_url))?;

        // Parse the suffix (dp_rank) into an integer
        rank_str
            .parse::<usize>()
            .map(|dp_rank| (prefix, dp_rank))
            .map_err(|_| {
                format!(
                    "failed to parse dp_rank from worker_url: {}",
                    worker_url
                )
            })
    }

427
428
429
    // Send typed request directly without conversion
    async fn send_typed_request<T: serde::Serialize>(
        &self,
430
        headers: Option<&HeaderMap>,
431
432
433
434
435
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
        load_incremented: bool, // Whether load was incremented for this request
436
    ) -> Response {
437
438
439
440
441
442
        // Get the worker's API key if available
        let api_key = self
            .worker_registry
            .get_by_url(worker_url)
            .and_then(|w| w.api_key().clone());

443
444
445
446
447
        let mut request_builder = if self.dp_aware {
            let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
448
449
450
451
452
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
453
454
455
456
457
458
                }
            };

            let mut json_val = match serde_json::to_value(typed_req) {
                Ok(j) => j,
                Err(e) => {
459
460
461
462
463
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("Convert into serde_json::Value failed: {}", e),
                    )
                        .into_response();
464
465
466
467
468
469
470
471
472
473
474
475
476
                }
            };

            if let Some(map) = json_val.as_object_mut() {
                map.insert(
                    String::from("data_parallel_rank"),
                    serde_json::json!(dp_rank),
                );
                debug!(
                    "Modified request body: {}",
                    serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
                );
            } else {
477
478
479
480
481
                return (
                    StatusCode::BAD_REQUEST,
                    "Failed to insert the data_parallel_rank field into the request body",
                )
                    .into_response();
482
483
            }

484
            self.client
485
486
487
                .post(format!("{}{}", worker_url_prefix, route))
                .json(&json_val)
        } else {
488
            self.client
489
490
491
                .post(format!("{}{}", worker_url, route))
                .json(typed_req) // Use json() directly with typed request
        };
492

493
494
495
496
        if let Some(key) = api_key {
            request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
        }

497
498
499
500
        // Copy all headers from original request if provided
        if let Some(headers) = headers {
            for (name, value) in headers {
                // Skip Content-Type and Content-Length as .json() sets them
501
                if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
502
503
                    request_builder = request_builder.header(name, value);
                }
504
505
506
507
508
509
            }
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
510
511
512
513
                error!(
                    "Failed to send typed request worker_url={} route={} error={}",
                    worker_url, route, e
                );
514
515
516

                // Decrement load on error if it was incremented
                if load_incremented {
517
518
519
                    if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(worker_url, worker.load());
520
521
522
                    }
                }

523
524
525
526
527
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Request failed: {}", e),
                )
                    .into_response();
528
529
530
            }
        };

531
532
        let status = StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
533
534

        if !is_stream {
535
            // For non-streaming requests, preserve headers
536
            let response_headers = header_utils::preserve_response_headers(res.headers());
537

538
            let response = match res.bytes().await {
539
540
541
542
543
544
                Ok(body) => {
                    let mut response = Response::new(axum::body::Body::from(body));
                    *response.status_mut() = status;
                    *response.headers_mut() = response_headers;
                    response
                }
545
                Err(e) => {
546
547
                    // IMPORTANT: Decrement load on error before returning
                    if load_incremented {
548
549
550
                        if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(worker_url, worker.load());
551
552
553
                        }
                    }

554
                    let error_msg = format!("Failed to get response body: {}", e);
555
                    (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
556
557
558
559
                }
            };

            // Decrement load counter for non-streaming requests if it was incremented
560
            if load_incremented {
561
562
563
                if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                    worker.decrement_load();
                    RouterMetrics::set_running_requests(worker_url, worker.load());
564
565
566
567
568
569
                }
            }

            response
        } else if load_incremented {
            // For streaming with load tracking, we need to manually decrement when done
570
            let registry = Arc::clone(&self.worker_registry);
571
572
            let worker_url = worker_url.to_string();

573
574
575
576
577
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

578
579
580
581
582
583
            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream and detect completion
            tokio::spawn(async move {
                let mut stream = stream;
584
                let mut decremented = false;
585
586
587
588
589
590
591
592
593
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            // Check for stream end marker
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
594
595
596
597
                                if let Some(worker) = registry.get_by_url(&worker_url) {
                                    worker.decrement_load();
                                    RouterMetrics::set_running_requests(&worker_url, worker.load());
                                    decremented = true;
598
599
                                }
                            }
600
601
602
603
604
605
606
607
608
609
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
610
                if !decremented {
611
612
613
                    if let Some(worker) = registry.get_by_url(&worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
614
615
                    }
                }
616
617
618
619
620
621
622
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
623
            *response.headers_mut() = response_headers;
624
            response
625
626
        } else {
            // For requests without load tracking, just stream
627
628
629
630
631
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream
            tokio::spawn(async move {
                let mut stream = stream;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
658
            *response.headers_mut() = response_headers;
659
            response
660
661
662
        }
    }

663
    /// Query one worker's `/get_load` endpoint and parse its "load" field.
    ///
    /// Returns `None` on any failure (network error, non-2xx status, or a
    /// malformed body); failures are only logged at debug level, except the
    /// dp_rank extraction failure which is an error.
    async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
        let worker_url = if self.dp_aware {
            // Need to extract the URL from "http://host:port@dp_rank"
            let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
                    return None;
                }
            };
            worker_url_prefix
        } else {
            worker_url
        };

        let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
        if let Some(key) = api_key {
            req_builder = req_builder.bearer_auth(key);
        }

        match req_builder.send().await {
            Ok(res) if res.status().is_success() => match res.bytes().await {
                Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                    // Expected body shape: {"load": <integer>, ...}
                    Ok(data) => data
                        .get("load")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as isize),
                    Err(e) => {
                        debug!("Failed to parse load response from {}: {}", worker_url, e);
                        None
                    }
                },
                Err(e) => {
                    debug!("Failed to read load response from {}: {}", worker_url, e);
                    None
                }
            },
            Ok(res) => {
                debug!(
                    "Worker {} returned non-success status: {}",
                    worker_url,
                    res.status()
                );
                None
            }
            Err(e) => {
                debug!("Failed to get load from {}: {}", worker_url, e);
                None
            }
        }
    }

    // Background task to monitor worker loads.
    //
    // Polls each worker's /get_load on a fixed interval, pushes the collected
    // loads into the load-balancing policy, and publishes them on the watch
    // channel. Workers that fail to report are omitted from that round.
    async fn monitor_worker_loads(
        worker_urls: Vec<String>,
        // Parallel to `worker_urls`; built together in `new()`.
        worker_api_keys: Vec<Option<String>>,
        tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
        interval_secs: u64,
        policy: Arc<dyn LoadBalancingPolicy>,
        client: Client,
    ) {
        let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));

        loop {
            interval.tick().await;

            let mut loads = HashMap::new();
            for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
                if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
                    loads.insert(url.clone(), load);
                }
            }

            if !loads.is_empty() {
                // Update policy with new loads
                policy.update_loads(&loads);

                // Send to watchers
                if let Err(e) = tx.send(loads) {
                    error!("Failed to send load update: {}", e);
                }
            }
        }
    }

    // Static version of get_worker_load for use in monitoring task.
    //
    // Unlike the method, DP-awareness is inferred from the URL shape
    // ("...@rank") rather than from router config, because the spawned
    // monitor task has no access to `self`.
    async fn get_worker_load_static(
        client: &Client,
        worker_url: &str,
        api_key: &Option<String>,
    ) -> Option<isize> {
        let worker_url = if worker_url.contains("@") {
            // Need to extract the URL from "http://host:port@dp_rank"
            let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    debug!("Failed to extract dp_rank: {}", e);
                    return None;
                }
            };
            worker_url_prefix
        } else {
            worker_url
        };

        let mut req_builder = client.get(format!("{}/get_load", worker_url));
        if let Some(key) = api_key {
            req_builder = req_builder.bearer_auth(key);
        }
        match req_builder.send().await {
            Ok(res) if res.status().is_success() => match res.bytes().await {
                Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                    // Expected body shape: {"load": <integer>, ...}
                    Ok(data) => data
                        .get("load")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as isize),
                    Err(e) => {
                        debug!("Failed to parse load response from {}: {}", worker_url, e);
                        None
                    }
                },
                Err(e) => {
                    debug!("Failed to read load response from {}: {}", worker_url, e);
                    None
                }
            },
            Ok(res) => {
                debug!(
                    "Worker {} returned non-success status: {}",
                    worker_url,
                    res.status()
                );
                None
            }
            Err(e) => {
                debug!("Failed to get load from {}: {}", worker_url, e);
                None
            }
        }
    }
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821

    /// Turn a raw worker rerank reply into the final API response: parse the
    /// body, sort by score, then apply top_k and document stripping per the
    /// original request's options.
    async fn build_rerank_response(
        req: &RerankRequest,
        response: Response,
    ) -> anyhow::Result<Response> {
        let (_parts, body) = response.into_parts();
        let bytes = to_bytes(body, usize::MAX).await?;
        let results: Vec<RerankResult> = serde_json::from_slice(&bytes)?;

        let mut out = RerankResponse::new(results, req.model.clone(), req.rid.clone());
        // Sort first so top_k keeps the highest-scoring entries.
        out.sort_by_score();
        if let Some(k) = req.top_k {
            out.apply_top_k(k);
        }
        if !req.return_documents {
            out.drop_documents();
        }

        Ok(Json(out).into_response())
    }
822
823
824
825
}

use async_trait::async_trait;

826
#[async_trait]
827
828
829
830
831
impl RouterTrait for Router {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

832
    async fn health(&self, _req: Request<Body>) -> Response {
833
        let workers = self.worker_registry.get_all();
834
835
836
837
838
        let unhealthy_servers: Vec<_> = workers
            .iter()
            .filter(|w| !w.is_healthy())
            .map(|w| w.url().to_string())
            .collect();
839

840
841
        if unhealthy_servers.is_empty() {
            (StatusCode::OK, "All servers healthy").into_response()
842
        } else {
843
844
845
846
847
            (
                StatusCode::SERVICE_UNAVAILABLE,
                format!("Unhealthy servers: {:?}", unhealthy_servers),
            )
                .into_response()
848
849
850
        }
    }

851
852
    async fn health_generate(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "health_generate").await
853
854
    }

855
856
    async fn get_server_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_server_info").await
857
858
    }

859
860
    async fn get_models(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "v1/models").await
861
862
    }

863
864
    async fn get_model_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_model_info").await
865
866
867
868
    }

    async fn route_generate(
        &self,
869
870
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
871
        model_id: Option<&str>,
872
    ) -> Response {
873
874
        self.route_typed_request(headers, body, "/generate", model_id)
            .await
875
876
877
878
    }

    async fn route_chat(
        &self,
879
880
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
881
        model_id: Option<&str>,
882
    ) -> Response {
883
        self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
884
            .await
885
886
887
888
    }

    async fn route_completion(
        &self,
889
890
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
891
        model_id: Option<&str>,
892
    ) -> Response {
893
        self.route_typed_request(headers, body, "/v1/completions", model_id)
894
            .await
895
896
    }

897
898
899
900
    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
901
        model_id: Option<&str>,
902
    ) -> Response {
903
        self.route_typed_request(headers, body, "/v1/responses", model_id)
904
905
906
            .await
    }

907
908
909
910
911
912
    async fn get_response(
        &self,
        headers: Option<&HeaderMap>,
        response_id: &str,
        _params: &ResponsesGetParams,
    ) -> Response {
913
914
915
916
917
918
919
920
921
        let endpoint = format!("v1/responses/{}", response_id);
        self.route_get_request(headers, &endpoint).await
    }

    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}/cancel", response_id);
        self.route_post_empty_request(headers, &endpoint).await
    }

922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
    async fn route_embeddings(
        &self,
        headers: Option<&HeaderMap>,
        body: &EmbeddingRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Record embeddings-specific metrics in addition to general request metrics
        let start = Instant::now();
        let res = self
            .route_typed_request(headers, body, "/v1/embeddings", model_id)
            .await;

        // Embedding specific metrics
        if res.status().is_success() {
            RouterMetrics::record_embeddings_request();
            RouterMetrics::record_embeddings_duration(start.elapsed());
        } else {
            let error_type = format!("http_{}", res.status().as_u16());
            RouterMetrics::record_embeddings_error(&error_type);
        }

        res
944
945
    }

946
947
948
949
950
951
    async fn route_rerank(
        &self,
        headers: Option<&HeaderMap>,
        body: &RerankRequest,
        model_id: Option<&str>,
    ) -> Response {
952
953
954
        if let Err(e) = body.validate() {
            return (StatusCode::BAD_REQUEST, e).into_response();
        }
955
956
957
        let response = self
            .route_typed_request(headers, body, "/v1/rerank", model_id)
            .await;
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
        if response.status().is_success() {
            match Self::build_rerank_response(body, response).await {
                Ok(rerank_response) => rerank_response,
                Err(e) => {
                    error!("Failed to build rerank response: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Failed to build rerank response".to_string(),
                    )
                        .into_response();
                }
            }
        } else {
            response
        }
973
974
    }

975
    async fn flush_cache(&self) -> Response {
976
977
978
        // Get all workers
        let workers = self.worker_registry.get_all();
        let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
979
980
981
982

        // Send requests to all workers concurrently without headers
        let mut tasks = Vec::new();
        for worker_url in &worker_urls {
983
984
985
986
987
988
            // Get the worker's API key if available
            let api_key = self
                .worker_registry
                .get_by_url(worker_url)
                .and_then(|w| w.api_key().clone());

989
990
991
992
993
994
            let worker_url = if self.dp_aware {
                // Need to extract the URL from "http://host:port@dp_rank"
                let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                    Ok(tup) => tup,
                    Err(e) => {
                        error!("Failed to extract dp_rank: {}", e);
995
996
997
998
999
                        return (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Failed to extract dp_rank: {}", e),
                        )
                            .into_response();
1000
1001
1002
1003
1004
1005
                    }
                };
                worker_url_prefix
            } else {
                worker_url
            };
1006
1007
1008
1009
1010
1011
1012
            let mut request_builder = self.client.post(format!("{}/flush_cache", worker_url));

            if let Some(key) = api_key {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", key));
            }

1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
            tasks.push(request_builder.send());
        }

        // Wait for all responses
        let results = futures_util::future::join_all(tasks).await;

        // Check if all succeeded
        let all_success = results.iter().all(|r| {
            r.as_ref()
                .map(|res| res.status().is_success())
                .unwrap_or(false)
        });

        if all_success {
1027
            (StatusCode::OK, "Cache flushed on all servers").into_response()
1028
        } else {
1029
1030
1031
1032
1033
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Cache flush failed on one or more servers",
            )
                .into_response()
1034
1035
1036
        }
    }

1037
    async fn get_worker_loads(&self) -> Response {
1038
        let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
1039
1040
1041
        let mut loads = Vec::new();

        // Get loads from all workers
1042
1043
        for (url, api_key) in &urls_with_key {
            let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
1044
1045
1046
1047
1048
1049
            loads.push(serde_json::json!({
                "worker": url,
                "load": load
            }));
        }

1050
        Json(serde_json::json!({
1051
1052
            "workers": loads
        }))
1053
        .into_response()
1054
1055
1056
1057
1058
1059
    }

    fn router_type(&self) -> &'static str {
        "regular"
    }

1060
    fn readiness(&self) -> Response {
1061
        // Regular router is ready if it has at least one healthy worker
1062
1063
1064
        let workers = self.worker_registry.get_all();
        let healthy_count = workers.iter().filter(|w| w.is_healthy()).count();
        let total_workers = workers.len();
1065
1066

        if healthy_count > 0 {
1067
            Json(serde_json::json!({
1068
1069
                "status": "ready",
                "healthy_workers": healthy_count,
1070
                "total_workers": total_workers
1071
            }))
1072
            .into_response()
1073
        } else {
1074
1075
1076
1077
1078
            (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({
                    "status": "not_ready",
                    "reason": "no healthy workers available",
1079
                    "total_workers": total_workers
1080
1081
1082
                })),
            )
                .into_response()
1083
1084
1085
1086
1087
1088
1089
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorkerBuilder;
    use std::collections::HashMap;

    /// Build a router backed by two healthy regular HTTP workers.
    fn create_test_regular_router() -> Router {
        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(
            crate::config::types::PolicyConfig::RoundRobin,
        ));

        for url in ["http://worker1:8080", "http://worker2:8080"] {
            let worker = BasicWorkerBuilder::new(url)
                .worker_type(WorkerType::Regular)
                .build();
            worker_registry.register(Arc::new(worker));
        }

        // Sender is dropped immediately; the tests never push load updates.
        let (_, rx) = tokio::sync::watch::channel(HashMap::new());
        Router {
            worker_registry,
            policy_registry,
            client: Client::new(),
            dp_aware: false,
            enable_igw: false,
            retry_config: RetryConfig::default(),
            _worker_loads: Arc::new(rx),
            _load_monitor_handle: None,
        }
    }

    /// Same as the regular fixture, but with the first registered worker
    /// marked unhealthy.
    fn create_test_unhealthy_router() -> Router {
        let router = create_test_regular_router();
        router.worker_registry.get_all()[0].set_healthy(false);
        router
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let urls: Vec<String> = router
            .worker_registry
            .get_all()
            .iter()
            .map(|w| w.url().to_string())
            .collect();

        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let result = router.select_first_worker();
        assert!(result.is_ok());

        // DashMap iteration order is unspecified, so either worker may win.
        let url = result.unwrap();
        assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
    }

    #[test]
    fn test_select_first_worker_with_unhealthy_worker() {
        let router = create_test_unhealthy_router();
        let result = router.select_first_worker();
        assert!(result.is_ok());

        // Whichever worker was selected must still be healthy.
        let chosen = router.worker_registry.get_by_url(&result.unwrap()).unwrap();
        assert!(chosen.is_healthy());
    }
}