use crate::config::types::RetryConfig;
use crate::core::{
    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::spec::{
    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
    RerankRequest, RerankResponse, RerankResult, ResponsesRequest,
};
use crate::routers::header_utils;
use crate::routers::RouterTrait;
use axum::body::to_bytes;
use axum::{
    body::Body,
    extract::Request,
    http::{
        header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode,
    },
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error};

/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
    client: Client,
    dp_aware: bool,
    retry_config: RetryConfig,
    _worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
    _load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
}

impl Router {
    /// Create a new router with injected policy and client
    pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
        let workers = ctx.worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // include all workers
        );

        RouterMetrics::set_active_workers(workers.len());

        let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
        let worker_loads = Arc::new(rx);

        let default_policy = ctx.policy_registry.get_default_policy();

        let load_monitor_handle = if default_policy.name() == "power_of_two" {
            let monitor_urls = worker_urls.clone();
            let monitor_api_keys = monitor_urls
                .iter()
                .map(|url| {
                    ctx.worker_registry
                        .get_by_url(url)
                        .and_then(|w| w.api_key().clone())
                })
                .collect::<Vec<Option<String>>>();
            let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
            let policy_clone = default_policy.clone();
            let client_clone = ctx.client.clone();

            Some(Arc::new(tokio::spawn(async move {
                Self::monitor_worker_loads(
                    monitor_urls,
                    monitor_api_keys,
                    tx,
                    monitor_interval,
                    policy_clone,
                    client_clone,
                )
                .await;
            })))
        } else {
            None
        };

        Ok(Router {
            worker_registry: ctx.worker_registry.clone(),
            policy_registry: ctx.policy_registry.clone(),
            client: ctx.client.clone(),
            dp_aware: ctx.router_config.dp_aware,
            retry_config: ctx.router_config.effective_retry_config(),
            _worker_loads: worker_loads,
            _load_monitor_handle: load_monitor_handle,
        })
    }

    fn select_first_worker(&self) -> Result<String, String> {
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
            Err("No workers are available".to_string())
        } else {
            Ok(workers[0].url().to_string())
        }
    }

    // Helper method to proxy GET requests to the first available worker
    async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
        let headers = header_utils::copy_request_headers(&req);

        match self.select_first_worker() {
            Ok(worker_url) => {
                let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
                for (name, value) in headers {
                    let name_lc = name.to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }

                match request_builder.send().await {
                    Ok(res) => {
                        let status = StatusCode::from_u16(res.status().as_u16())
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

                        // Preserve headers from backend
                        let response_headers =
                            header_utils::preserve_response_headers(res.headers());

                        match res.bytes().await {
                            Ok(body) => {
                                let mut response = Response::new(axum::body::Body::from(body));
                                *response.status_mut() = status;
                                *response.headers_mut() = response_headers;
                                response
                            }
                            Err(e) => (
                                StatusCode::INTERNAL_SERVER_ERROR,
                                format!("Failed to read response: {}", e),
                            )
                                .into_response(),
                        }
                    }
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
                }
            }
            Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
        }
    }

    /// Select a worker for a specific model, considering circuit breaker state
    fn select_worker_for_model(
        &self,
        model_id: Option<&str>,
        text: Option<&str>,
    ) -> Option<Arc<dyn Worker>> {
        // Get workers for the specified model (O(1) lookup), filtered by connection mode
        let workers = self.worker_registry.get_workers_filtered(
            model_id,
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // get all workers; we'll filter by is_available() next
        );

        let available: Vec<Arc<dyn Worker>> = workers
            .iter()
            .filter(|w| w.is_available())
            .cloned()
            .collect();
        if available.is_empty() {
            return None;
        }

        // Get the appropriate policy for this model
        let policy = match model_id {
            Some(model) => self.policy_registry.get_policy_or_default(model),
            None => self.policy_registry.get_default_policy(),
        };

        let idx = policy.select_worker(&available, text)?;
        Some(available[idx].clone())
    }

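    // Route a typed request end-to-end: pick a worker via the model's policy,
    // forward the request with retries on retryable statuses (recording
    // backoff metrics), and record per-route success/error metrics.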
    pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        model_id: Option<&str>,
    ) -> Response {
        let start = Instant::now();
        let is_stream = typed_req.is_stream();
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
                let worker = match self.select_worker_for_model(model_id, Some(&text)) {
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };

                // Optional load tracking for cache-aware policy:
                // get the policy for this model to check if it's cache-aware
                let policy = match model_id {
                    Some(model) => self.policy_registry.get_policy_or_default(model),
                    None => self.policy_registry.get_default_policy(),
                };

                let load_incremented = if policy.name() == "cache_aware" {
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
                } else {
                    false
                };

                // Keep a clone for potential cleanup on retry
                let worker_for_cleanup = if load_incremented {
                    Some(worker.clone())
                } else {
                    None
                };

                let response = self
                    .send_typed_request(
                        headers,
                        typed_req,
                        route,
                        worker.url(),
                        is_stream,
                        load_incremented,
                    )
                    .await;

                worker.record_outcome(response.status().is_success());

                // For retryable failures, we need to decrement load since send_typed_request
                // won't have done it (it only decrements on success or non-retryable failures)
                if is_retryable_status(response.status()) && load_incremented {
                    if let Some(cleanup_worker) = worker_for_cleanup {
                        cleanup_worker.decrement_load();
                        RouterMetrics::set_running_requests(
                            cleanup_worker.url(),
                            cleanup_worker.load(),
                        );
                    }
                }

                response
            },
            // should_retry predicate
            |res, _attempt| is_retryable_status(res.status()),
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
        )
        .await;

        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
        } else if !is_retryable_status(response.status()) {
            RouterMetrics::record_request_error(route, "non_retryable_error");
        }

        response
    }

    // Helper: return base worker URL (strips DP suffix when enabled)
    fn worker_base_url(&self, worker_url: &str) -> String {
        if self.dp_aware {
            if let Ok((prefix, _)) = Self::extract_dp_rank(worker_url) {
                return prefix.to_string();
            }
        }
        worker_url.to_string()
    }

    // Generic simple routing for GET/POST without JSON body
    async fn route_simple_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
        method: Method,
    ) -> Response {
        // TODO: the sglang worker currently uses in-memory state management, so this implementation has to fan out to all workers.
        // Eventually the router should manage chat history with a proper database; this implementation will be updated accordingly.
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
            return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
        }

        let mut last_response: Option<Response> = None;
        for worker in workers {
            let worker_url = worker.url();
            let base = self.worker_base_url(worker_url);

            let url = format!("{}/{}", base, endpoint);
            let mut request_builder = match method {
                Method::GET => self.client.get(url),
                Method::POST => self.client.post(url),
                _ => {
                    return (
                        StatusCode::METHOD_NOT_ALLOWED,
                        "Unsupported method for simple routing",
                    )
                        .into_response()
                }
            };

            if let Some(api_key) = worker.api_key() {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", api_key));
            }

            if let Some(hdrs) = headers {
                for (name, value) in hdrs {
                    let name_lc = name.as_str().to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }
            }

            match request_builder.send().await {
                Ok(res) => {
                    let status = StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                    let response_headers = header_utils::preserve_response_headers(res.headers());
                    match res.bytes().await {
                        Ok(body) => {
                            let mut response = Response::new(axum::body::Body::from(body));
                            *response.status_mut() = status;
                            *response.headers_mut() = response_headers;
                            if status.is_success() {
                                return response;
                            }
                            last_response = Some(response);
                        }
                        Err(e) => {
                            last_response = Some(
                                (
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    format!("Failed to read response: {}", e),
                                )
                                    .into_response(),
                            );
                        }
                    }
                }
                Err(e) => {
                    last_response = Some(
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Request failed: {}", e),
                        )
                            .into_response(),
                    );
                }
            }
        }

        last_response
            .unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response())
    }

    // Route a GET request with provided headers to a specific endpoint
    async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
        self.route_simple_request(headers, endpoint, Method::GET)
            .await
    }

    // Route a POST request with empty body to a specific endpoint
    async fn route_post_empty_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
    ) -> Response {
        self.route_simple_request(headers, endpoint, Method::POST)
            .await
    }

    // TODO (rui): better accommodate the Worker abstraction
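    // Splits a DP-aware worker URL of the form "http://host:port@dp_rank" into
    // its base URL and data-parallel rank, e.g.
    // "http://worker1:8080@3" -> ("http://worker1:8080", 3).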
    fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
        let parts: Vec<&str> = worker_url.split('@').collect();
        if parts.len() != 2 {
            return Err(format!("invalid worker_url format: {}", worker_url));
        }

        // Parse the second part (dp_rank) into an integer
        match parts[1].parse::<usize>() {
            Ok(dp_rank) => Ok((parts[0], dp_rank)),
            Err(_) => Err(format!(
                "failed to parse dp_rank from worker_url: {}",
                worker_url
            )),
        }
    }

    // Send typed request directly without conversion
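    // Load accounting: when `load_incremented` is true, this function
    // decrements the worker's load on send errors, body-read errors, and
    // non-streaming completion; for streaming responses the spawned forwarder
    // task decrements it, and for retryable failures route_typed_request
    // handles the decrement before retrying.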
    async fn send_typed_request<T: serde::Serialize>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
        load_incremented: bool, // Whether load was incremented for this request
    ) -> Response {
        // Get the worker's API key if available
        let api_key = self
            .worker_registry
            .get_by_url(worker_url)
            .and_then(|w| w.api_key().clone());

        let mut request_builder = if self.dp_aware {
            let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
                }
            };

            let mut json_val = match serde_json::to_value(typed_req) {
                Ok(j) => j,
                Err(e) => {
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("Failed to convert request into serde_json::Value: {}", e),
                    )
                        .into_response();
                }
            };

            if let Some(map) = json_val.as_object_mut() {
                map.insert(
                    String::from("data_parallel_rank"),
                    serde_json::json!(dp_rank),
                );
                debug!(
                    "Modified request body: {}",
                    serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
                );
            } else {
                return (
                    StatusCode::BAD_REQUEST,
                    "Failed to insert the data_parallel_rank field into the request body",
                )
                    .into_response();
            }

            self.client
                .post(format!("{}{}", worker_url_prefix, route))
                .json(&json_val)
        } else {
            self.client
                .post(format!("{}{}", worker_url, route))
                .json(typed_req) // Use json() directly with typed request
        };

        if let Some(key) = api_key {
            request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
        }

        // Copy all headers from original request if provided,
        // skipping Content-Type and Content-Length as .json() sets them
        if let Some(headers) = headers {
            for (name, value) in headers {
                if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
                    request_builder = request_builder.header(name, value);
                }
            }
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
                error!(
                    "Failed to send typed request worker_url={} route={} error={}",
                    worker_url, route, e
                );

                // Decrement load on error if it was incremented
                if load_incremented {
                    if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(worker_url, worker.load());
                    }
                }

                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Request failed: {}", e),
                )
                    .into_response();
            }
        };

        let status = StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, preserve headers
            let response_headers = header_utils::preserve_response_headers(res.headers());

            let response = match res.bytes().await {
                Ok(body) => {
                    let mut response = Response::new(axum::body::Body::from(body));
                    *response.status_mut() = status;
                    *response.headers_mut() = response_headers;
                    response
                }
                Err(e) => {
                    // IMPORTANT: Decrement load on error before returning
                    if load_incremented {
                        if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(worker_url, worker.load());
                        }
                    }

                    let error_msg = format!("Failed to get response body: {}", e);
                    (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
                }
            };

            // Decrement load counter for non-streaming requests if it was incremented
            if load_incremented {
                if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                    worker.decrement_load();
                    RouterMetrics::set_running_requests(worker_url, worker.load());
                }
            }

            response
        } else if load_incremented {
            // For streaming with load tracking, we need to manually decrement when done
            let registry = Arc::clone(&self.worker_registry);
            let worker_url = worker_url.to_string();

            // Preserve headers for streaming response and
            // ensure we set the correct content-type for SSE
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream and detect completion
            tokio::spawn(async move {
                let mut stream = stream;
                let mut decremented = false;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            // Check for stream end marker
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                if let Some(worker) = registry.get_by_url(&worker_url) {
                                    worker.decrement_load();
                                    RouterMetrics::set_running_requests(&worker_url, worker.load());
                                    decremented = true;
                                }
                            }
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
                if !decremented {
                    if let Some(worker) = registry.get_by_url(&worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            *response.headers_mut() = response_headers;
            response
        } else {
            // For requests without load tracking, just stream.
            // Preserve headers and ensure the correct content-type for SSE.
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream
            tokio::spawn(async move {
                let mut stream = stream;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            *response.headers_mut() = response_headers;
            response
        }
    }

    async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
        let worker_url = if self.dp_aware {
            // Need to extract the URL from "http://host:port@dp_rank"
            let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
                    return None;
                }
            };
            worker_url_prefix
        } else {
            worker_url
        };

        let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
        if let Some(key) = api_key {
            req_builder = req_builder.bearer_auth(key);
        }

        match req_builder.send().await {
            Ok(res) if res.status().is_success() => match res.bytes().await {
                Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                    Ok(data) => data
                        .get("load")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as isize),
                    Err(e) => {
                        debug!("Failed to parse load response from {}: {}", worker_url, e);
                        None
                    }
                },
                Err(e) => {
                    debug!("Failed to read load response from {}: {}", worker_url, e);
                    None
                }
            },
            Ok(res) => {
                debug!(
                    "Worker {} returned non-success status: {}",
                    worker_url,
                    res.status()
                );
                None
            }
            Err(e) => {
                debug!("Failed to get load from {}: {}", worker_url, e);
                None
            }
        }
    }

    // Background task to monitor worker loads
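    // (spawned only when the default policy is "power_of_two"): polls each
    // worker's /get_load endpoint on a fixed interval and publishes the load
    // map to the policy and to watch subscribers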
    async fn monitor_worker_loads(
        worker_urls: Vec<String>,
        worker_api_keys: Vec<Option<String>>,
        tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
        interval_secs: u64,
        policy: Arc<dyn LoadBalancingPolicy>,
        client: Client,
    ) {
        let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));

        loop {
            interval.tick().await;

            let mut loads = HashMap::new();
            for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
                if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
                    loads.insert(url.clone(), load);
                }
            }

            if !loads.is_empty() {
                // Update policy with new loads
                policy.update_loads(&loads);

                // Send to watchers
                if let Err(e) = tx.send(loads) {
                    error!("Failed to send load update: {}", e);
                }
            }
        }
    }

    // Static version of get_worker_load for use in monitoring task
    async fn get_worker_load_static(
        client: &Client,
        worker_url: &str,
        api_key: &Option<String>,
    ) -> Option<isize> {
        let worker_url = if worker_url.contains('@') {
            // Need to extract the URL from "http://host:port@dp_rank"
            let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    debug!("Failed to extract dp_rank: {}", e);
                    return None;
                }
            };
            worker_url_prefix
        } else {
            worker_url
        };

        let mut req_builder = client.get(format!("{}/get_load", worker_url));
        if let Some(key) = api_key {
            req_builder = req_builder.bearer_auth(key);
        }
        match req_builder.send().await {
            Ok(res) if res.status().is_success() => match res.bytes().await {
                Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                    Ok(data) => data
                        .get("load")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as isize),
                    Err(e) => {
                        debug!("Failed to parse load response from {}: {}", worker_url, e);
                        None
                    }
                },
                Err(e) => {
                    debug!("Failed to read load response from {}: {}", worker_url, e);
                    None
                }
            },
            Ok(res) => {
                debug!(
                    "Worker {} returned non-success status: {}",
                    worker_url,
                    res.status()
                );
                None
            }
            Err(e) => {
                debug!("Failed to get load from {}: {}", worker_url, e);
                None
            }
        }
    }

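    /// Build the final rerank response from a worker's raw output:
    /// parse the results, sort by score, then apply the request's `top_k`
    /// and `return_documents` options.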
    async fn build_rerank_response(
        req: &RerankRequest,
        response: Response,
    ) -> anyhow::Result<Response> {
        let (_, response_body) = response.into_parts();
        let body_bytes = to_bytes(response_body, usize::MAX).await?;
        let rerank_results = serde_json::from_slice::<Vec<RerankResult>>(&body_bytes)?;
        let mut rerank_response =
            RerankResponse::new(rerank_results, req.model.clone(), req.rid.clone());
        rerank_response.sort_by_score();
        if let Some(top_k) = req.top_k {
            rerank_response.apply_top_k(top_k);
        }
        if !req.return_documents {
            rerank_response.drop_documents();
        }
        Ok(Json(rerank_response).into_response())
    }
}

use async_trait::async_trait;

#[async_trait]
impl RouterTrait for Router {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health(&self, _req: Request<Body>) -> Response {
        let workers = self.worker_registry.get_all();
        let unhealthy_servers: Vec<_> = workers
            .iter()
            .filter(|w| !w.is_healthy())
            .map(|w| w.url().to_string())
            .collect();

        if unhealthy_servers.is_empty() {
            (StatusCode::OK, "All servers healthy").into_response()
        } else {
            (
                StatusCode::SERVICE_UNAVAILABLE,
                format!("Unhealthy servers: {:?}", unhealthy_servers),
            )
                .into_response()
        }
    }

    async fn health_generate(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "health_generate").await
    }

    async fn get_server_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_server_info").await
    }

    async fn get_models(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "v1/models").await
    }

    async fn get_model_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_model_info").await
    }

    async fn route_generate(
        &self,
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/generate", model_id)
            .await
    }

    async fn route_chat(
        &self,
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
            .await
    }

    async fn route_completion(
        &self,
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/completions", model_id)
            .await
    }

    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/responses", model_id)
            .await
    }

    async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}", response_id);
        self.route_get_request(headers, &endpoint).await
    }

    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}/cancel", response_id);
        self.route_post_empty_request(headers, &endpoint).await
    }

    async fn route_embeddings(
        &self,
        headers: Option<&HeaderMap>,
        body: &EmbeddingRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Record embeddings-specific metrics in addition to general request metrics
        let start = Instant::now();
        let res = self
            .route_typed_request(headers, body, "/v1/embeddings", model_id)
            .await;

        // Embedding specific metrics
        if res.status().is_success() {
            RouterMetrics::record_embeddings_request();
            RouterMetrics::record_embeddings_duration(start.elapsed());
        } else {
            let error_type = format!("http_{}", res.status().as_u16());
            RouterMetrics::record_embeddings_error(&error_type);
        }

        res
    }

    async fn route_rerank(
        &self,
        headers: Option<&HeaderMap>,
        body: &RerankRequest,
        model_id: Option<&str>,
    ) -> Response {
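        // Validate the request, route it like any other typed request, then
        // post-process the worker output via build_rerank_response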
        if let Err(e) = body.validate() {
            return (StatusCode::BAD_REQUEST, e).into_response();
        }
        let response = self
            .route_typed_request(headers, body, "/v1/rerank", model_id)
            .await;
        if response.status().is_success() {
            match Self::build_rerank_response(body, response).await {
                Ok(rerank_response) => rerank_response,
                Err(e) => {
                    error!("Failed to build rerank response: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Failed to build rerank response".to_string(),
                    )
                        .into_response();
                }
            }
        } else {
            response
        }
    }

    async fn flush_cache(&self) -> Response {
        // Get all workers
        let workers = self.worker_registry.get_all();
        let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        // Send requests to all workers concurrently without headers
        let mut tasks = Vec::new();
        for worker_url in &worker_urls {
            // Get the worker's API key if available
            let api_key = self
                .worker_registry
                .get_by_url(worker_url)
                .and_then(|w| w.api_key().clone());

            let worker_url = if self.dp_aware {
                // Need to extract the URL from "http://host:port@dp_rank"
                let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                    Ok(tup) => tup,
                    Err(e) => {
                        error!("Failed to extract dp_rank: {}", e);
                        return (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Failed to extract dp_rank: {}", e),
                        )
                            .into_response();
                    }
                };
                worker_url_prefix
            } else {
                worker_url
            };

            let mut request_builder = self.client.post(format!("{}/flush_cache", worker_url));

            if let Some(key) = api_key {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", key));
            }

            tasks.push(request_builder.send());
        }

        // Wait for all responses
        let results = futures_util::future::join_all(tasks).await;

        // Check if all succeeded
        let all_success = results.iter().all(|r| {
            r.as_ref()
                .map(|res| res.status().is_success())
                .unwrap_or(false)
        });

        if all_success {
            (StatusCode::OK, "Cache flushed on all servers").into_response()
        } else {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Cache flush failed on one or more servers",
            )
                .into_response()
        }
    }

    async fn get_worker_loads(&self) -> Response {
        let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
        let mut loads = Vec::new();

        // Get loads from all workers
        for (url, api_key) in &urls_with_key {
            let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
            loads.push(serde_json::json!({
                "worker": url,
                "load": load
            }));
        }

        Json(serde_json::json!({
            "workers": loads
        }))
        .into_response()
    }

    fn router_type(&self) -> &'static str {
        "regular"
    }

    fn readiness(&self) -> Response {
        // Regular router is ready if it has at least one healthy worker
        let workers = self.worker_registry.get_all();
        let healthy_count = workers.iter().filter(|w| w.is_healthy()).count();
        let total_workers = workers.len();

        if healthy_count > 0 {
            Json(serde_json::json!({
                "status": "ready",
                "healthy_workers": healthy_count,
                "total_workers": total_workers
            }))
            .into_response()
        } else {
            (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({
                    "status": "not_ready",
                    "reason": "no healthy workers available",
                    "total_workers": total_workers
                })),
            )
                .into_response()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorkerBuilder;
    use std::collections::HashMap;

    fn create_test_regular_router() -> Router {
        // Create registries
        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(
            crate::config::types::PolicyConfig::RoundRobin,
        ));

        // Register test workers
        let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
            .worker_type(WorkerType::Regular)
            .build();
        let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
            .worker_type(WorkerType::Regular)
            .build();
        worker_registry.register(Arc::new(worker1));
        worker_registry.register(Arc::new(worker2));

        let (_, rx) = tokio::sync::watch::channel(HashMap::new());
        Router {
            worker_registry,
            policy_registry,
            dp_aware: false,
            client: Client::new(),
            retry_config: RetryConfig::default(),
            _worker_loads: Arc::new(rx),
            _load_monitor_handle: None,
        }
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let workers = router.worker_registry.get_all();
        let urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();
        // DashMap doesn't guarantee order, so just check we get one of the workers
        assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
    }
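
    // Added test sketch for the DP-aware URL format ("http://host:port@dp_rank")
    // parsed by extract_dp_rank; the cases follow its parsing rules above.
    #[test]
    fn test_extract_dp_rank() {
        let (prefix, dp_rank) = Router::extract_dp_rank("http://worker1:8080@3").unwrap();
        assert_eq!(prefix, "http://worker1:8080");
        assert_eq!(dp_rank, 3);

        // A URL without the "@dp_rank" suffix is rejected
        assert!(Router::extract_dp_rank("http://worker1:8080").is_err());
        // A non-numeric rank is rejected
        assert!(Router::extract_dp_rank("http://worker1:8080@abc").is_err());
    }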
}