"tasks/vision/segmentation/utils.py" did not exist on "53f3efc45d65f51d08e763d8f73c5283b6015fd0"
router.rs 41.2 KB
Newer Older
1
use crate::config::types::RetryConfig;
2
use crate::core::{
3
    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
4
};
5
use crate::metrics::RouterMetrics;
6
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
7
use crate::protocols::spec::{
8
9
    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
    RerankRequest, RerankResponse, RerankResult, ResponsesRequest,
10
};
11
use crate::routers::header_utils;
12
use crate::routers::RouterTrait;
13
use axum::body::to_bytes;
14
15
16
use axum::{
    body::Body,
    extract::Request,
17
18
19
    http::{
        header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode,
    },
20
21
22
23
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
24
use reqwest::Client;
25
use std::collections::HashMap;
26
use std::sync::Arc;
27
use std::time::{Duration, Instant};
28
use tokio_stream::wrappers::UnboundedReceiverStream;
29
use tracing::{debug, error};
30
31
32
33

/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
    // Shared registry of all known workers.
    worker_registry: Arc<WorkerRegistry>,
    // Per-model load-balancing policies plus the process-wide default.
    policy_registry: Arc<PolicyRegistry>,
    // HTTP client used for all upstream worker requests.
    client: Client,
    // When true, worker URLs carry an "@dp_rank" suffix that is stripped
    // before dispatch and the rank is injected into request bodies.
    dp_aware: bool,
    // When true, workers are filtered per model id (inference-gateway mode);
    // otherwise the model filter is ignored when selecting a worker.
    enable_igw: bool,
    // Retry budget/backoff settings applied by route_typed_request.
    retry_config: RetryConfig,
    // Held to keep the watch channel alive for any subscribers; not read here.
    _worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
    // Background load-monitor task (spawned only for the power_of_two policy).
    _load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
}

impl Router {
45
    /// Create a new router with injected policy and client
    ///
    /// Reads the regular HTTP workers from the shared registry, publishes the
    /// initial active-worker gauge, and — only when the default policy is
    /// "power_of_two" — spawns a background task that periodically polls each
    /// worker's load and feeds it back into the policy.
    pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
        let workers = ctx.worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // include all workers
        );

        RouterMetrics::set_active_workers(workers.len());

        let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        // Watch channel through which the load monitor publishes per-URL loads.
        let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
        let worker_loads = Arc::new(rx);

        let default_policy = ctx.policy_registry.get_default_policy();

        // Only power-of-two choices needs live load data, so the monitor task
        // is spawned solely for that policy.
        let load_monitor_handle = if default_policy.name() == "power_of_two" {
            let monitor_urls = worker_urls.clone();
            // Snapshot each worker's API key up front so the monitor task does
            // not need registry access.
            let monitor_api_keys = monitor_urls
                .iter()
                .map(|url| {
                    ctx.worker_registry
                        .get_by_url(url)
                        .and_then(|w| w.api_key().clone())
                })
                .collect::<Vec<Option<String>>>();
            let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
            let policy_clone = default_policy.clone();
            let client_clone = ctx.client.clone();

            Some(Arc::new(tokio::spawn(async move {
                Self::monitor_worker_loads(
                    monitor_urls,
                    monitor_api_keys,
                    tx,
                    monitor_interval,
                    policy_clone,
                    client_clone,
                )
                .await;
            })))
        } else {
            None
        };

        Ok(Router {
            worker_registry: ctx.worker_registry.clone(),
            policy_registry: ctx.policy_registry.clone(),
            client: ctx.client.clone(),
            dp_aware: ctx.router_config.dp_aware,
            enable_igw: ctx.router_config.enable_igw,
            retry_config: ctx.router_config.effective_retry_config(),
            _worker_loads: worker_loads,
            _load_monitor_handle: load_monitor_handle,
        })
    }

    fn select_first_worker(&self) -> Result<String, String> {
105
106
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
107
108
            Err("No workers are available".to_string())
        } else {
109
110
111
112
            Ok(workers[0].url().to_string())
        }
    }

113
    // Helper method to proxy GET requests to the first available worker
114
    async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
115
        let headers = header_utils::copy_request_headers(&req);
116
117
118

        match self.select_first_worker() {
            Ok(worker_url) => {
119
                let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
120
                for (name, value) in headers {
121
122
                    let name_lc = name.to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
123
124
125
                        request_builder = request_builder.header(name, value);
                    }
                }
126

127
128
129
130
                match request_builder.send().await {
                    Ok(res) => {
                        let status = StatusCode::from_u16(res.status().as_u16())
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
131
132
133
134
135

                        // Preserve headers from backend
                        let response_headers =
                            header_utils::preserve_response_headers(res.headers());

136
                        match res.bytes().await {
137
138
139
140
141
142
                            Ok(body) => {
                                let mut response = Response::new(axum::body::Body::from(body));
                                *response.status_mut() = status;
                                *response.headers_mut() = response_headers;
                                response
                            }
143
144
145
146
147
                            Err(e) => (
                                StatusCode::INTERNAL_SERVER_ERROR,
                                format!("Failed to read response: {}", e),
                            )
                                .into_response(),
148
149
                        }
                    }
150
151
152
153
154
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
155
156
                }
            }
157
            Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
158
159
160
        }
    }

161
162
163
164
165
166
    /// Select worker for a specific model considering circuit breaker state
    fn select_worker_for_model(
        &self,
        model_id: Option<&str>,
        text: Option<&str>,
    ) -> Option<Arc<dyn Worker>> {
        // Outside IGW mode every worker serves the configured model, so the
        // model filter is dropped.
        let effective_model_id = if self.enable_igw { model_id } else { None };

        // O(1) lookup of regular HTTP workers for the model.
        let candidates = self.worker_registry.get_workers_filtered(
            effective_model_id,
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // get all workers, we'll filter by is_available() next
        );

        // Keep only workers whose circuit breaker / health check allows traffic.
        let available: Vec<Arc<dyn Worker>> = candidates
            .iter()
            .filter(|w| w.is_available())
            .cloned()
            .collect();
        if available.is_empty() {
            return None;
        }

        // Per-model policy when one is registered, otherwise the default.
        let policy = model_id
            .map(|model| self.policy_registry.get_policy_or_default(model))
            .unwrap_or_else(|| self.policy_registry.get_default_policy());

        policy
            .select_worker(&available, text)
            .map(|idx| available[idx].clone())
    }

196
    /// Route a typed generation-style request with retries.
    ///
    /// Per attempt: select an available worker for `model_id`, optionally bump
    /// its load counter (cache-aware policy only), forward the request via
    /// `send_typed_request`, and record the outcome for the worker's circuit
    /// breaker. Retry behavior is driven by `retry_config` and
    /// `is_retryable_status`; request/latency metrics are recorded once at the
    /// end.
    pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        model_id: Option<&str>,
    ) -> Response {
        let start = Instant::now();
        // Captured once so every retry attempt uses the same stream flag and
        // routing text.
        let is_stream = typed_req.is_stream();
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
                let worker = match self.select_worker_for_model(model_id, Some(&text)) {
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };

                // Optional load tracking for cache-aware policy
                // Get the policy for this model to check if it's cache-aware
                let policy = match model_id {
                    Some(model) => self.policy_registry.get_policy_or_default(model),
                    None => self.policy_registry.get_default_policy(),
                };

                let load_incremented = if policy.name() == "cache_aware" {
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
                } else {
                    false
                };

                // Keep a clone for potential cleanup on retry
                let worker_for_cleanup = if load_incremented {
                    Some(worker.clone())
                } else {
                    None
                };

                let response = self
                    .send_typed_request(
                        headers,
                        typed_req,
                        route,
                        worker.url(),
                        is_stream,
                        load_incremented,
                    )
                    .await;

                // Feed success/failure into the worker's circuit breaker.
                worker.record_outcome(response.status().is_success());

                // For retryable failures, we need to decrement load since send_typed_request
                // won't have done it (it only decrements on success or non-retryable failures)
                // NOTE(review): send_typed_request's non-streaming path appears to
                // decrement for *any* status once the body is read — confirm a
                // retryable non-stream response is not decremented twice here.
                if is_retryable_status(response.status()) && load_incremented {
                    if let Some(cleanup_worker) = worker_for_cleanup {
                        cleanup_worker.decrement_load();
                        RouterMetrics::set_running_requests(
                            cleanup_worker.url(),
                            cleanup_worker.load(),
                        );
                    }
                }

                response
            },
            // should_retry predicate
            |res, _attempt| is_retryable_status(res.status()),
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
        )
        .await;

        // Metrics: success latency, or a terminal (non-retryable) error.
        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
        } else if !is_retryable_status(response.status()) {
            RouterMetrics::record_request_error(route, "non_retryable_error");
        }

        response
    }

295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
    // Helper: return base worker URL (strips DP suffix when enabled).
    // Falls back to the raw URL when the suffix cannot be parsed.
    fn worker_base_url(&self, worker_url: &str) -> String {
        if !self.dp_aware {
            return worker_url.to_string();
        }
        Self::extract_dp_rank(worker_url)
            .map(|(prefix, _)| prefix.to_string())
            .unwrap_or_else(|_| worker_url.to_string())
    }

    // Generic simple routing for GET/POST without JSON body.
    //
    // Fans out sequentially to every registered worker and returns the first
    // successful response; otherwise returns the last response (or error)
    // observed. Only GET and POST are supported.
    async fn route_simple_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
        method: Method,
    ) -> Response {
        // TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
        // Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
            return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
        }

        // Last non-success response/error, returned if no worker succeeds.
        let mut last_response: Option<Response> = None;
        for worker in workers {
            let worker_url = worker.url();
            // Strip the DP-rank suffix (if any) to get a dialable base URL.
            let base = self.worker_base_url(worker_url);

            let url = format!("{}/{}", base, endpoint);
            let mut request_builder = match method {
                Method::GET => self.client.get(url),
                Method::POST => self.client.post(url),
                _ => {
                    return (
                        StatusCode::METHOD_NOT_ALLOWED,
                        "Unsupported method for simple routing",
                    )
                        .into_response()
                }
            };

            // Per-worker bearer token, when configured.
            if let Some(api_key) = worker.api_key() {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", api_key));
            }

            // Forward caller headers except content framing (no body is sent).
            if let Some(hdrs) = headers {
                for (name, value) in hdrs {
                    let name_lc = name.as_str().to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }
            }

            match request_builder.send().await {
                Ok(res) => {
                    let status = StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                    let response_headers = header_utils::preserve_response_headers(res.headers());
                    match res.bytes().await {
                        Ok(body) => {
                            let mut response = Response::new(axum::body::Body::from(body));
                            *response.status_mut() = status;
                            *response.headers_mut() = response_headers;
                            // First success short-circuits the fan-out.
                            if status.is_success() {
                                return response;
                            }
                            last_response = Some(response);
                        }
                        Err(e) => {
                            last_response = Some(
                                (
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    format!("Failed to read response: {}", e),
                                )
                                    .into_response(),
                            );
                        }
                    }
                }
                Err(e) => {
                    last_response = Some(
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Request failed: {}", e),
                        )
                            .into_response(),
                    );
                }
            }
        }

        last_response
            .unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response())
    }

    // Route a GET request with provided headers to a specific endpoint.
    // Thin convenience wrapper over route_simple_request.
    async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
        let method = Method::GET;
        self.route_simple_request(headers, endpoint, method).await
    }

    // Route a POST request with empty body to a specific endpoint.
    // Thin convenience wrapper over route_simple_request.
    async fn route_post_empty_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
    ) -> Response {
        let method = Method::POST;
        self.route_simple_request(headers, endpoint, method).await
    }

409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
    // TODO (rui): Better accommodate to the Worker abstraction
    //
    // Split a DP-aware worker URL of the form "http://host:port@rank" into its
    // base URL and numeric data-parallel rank. Exactly one '@' is required.
    fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
        let mut pieces = worker_url.split('@');
        let (prefix, rank_str) = match (pieces.next(), pieces.next(), pieces.next()) {
            // Exactly two pieces: base URL and rank suffix.
            (Some(prefix), Some(rank), None) => (prefix, rank),
            _ => return Err(format!("invalid worker_url format: {}", worker_url)),
        };

        rank_str
            .parse::<usize>()
            .map(|dp_rank| (prefix, dp_rank))
            .map_err(|_| {
                format!(
                    "failed to parse dp_rank from worker_url: {}",
                    worker_url
                )
            })
    }

426
427
428
    // Send typed request directly without conversion.
    //
    // Serializes `typed_req` to the worker at `worker_url` on `route`. In
    // dp-aware mode the URL's "@rank" suffix is stripped and the rank is
    // injected into the JSON body as "data_parallel_rank". When
    // `load_incremented` is true, this function decrements the worker's load
    // counter on the send-error, body-error, and non-streaming completion
    // paths; for streaming responses the decrement happens in the spawned
    // forwarding task.
    async fn send_typed_request<T: serde::Serialize>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
        load_incremented: bool, // Whether load was incremented for this request
    ) -> Response {
        // Get the worker's API key if available
        let api_key = self
            .worker_registry
            .get_by_url(worker_url)
            .and_then(|w| w.api_key().clone());

        let mut request_builder = if self.dp_aware {
            // Split "http://host:port@rank" into the dialable URL and the rank.
            let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
                }
            };

            // Round-trip through serde_json::Value so the rank field can be
            // injected into the body.
            let mut json_val = match serde_json::to_value(typed_req) {
                Ok(j) => j,
                Err(e) => {
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("Convert into serde_json::Value failed: {}", e),
                    )
                        .into_response();
                }
            };

            if let Some(map) = json_val.as_object_mut() {
                map.insert(
                    String::from("data_parallel_rank"),
                    serde_json::json!(dp_rank),
                );
                debug!(
                    "Modified request body: {}",
                    serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
                );
            } else {
                // Body serialized to a non-object (e.g. array); cannot inject.
                return (
                    StatusCode::BAD_REQUEST,
                    "Failed to insert the data_parallel_rank field into the request body",
                )
                    .into_response();
            }

            self.client
                .post(format!("{}{}", worker_url_prefix, route))
                .json(&json_val)
        } else {
            self.client
                .post(format!("{}{}", worker_url, route))
                .json(typed_req) // Use json() directly with typed request
        };

        if let Some(key) = api_key {
            request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
        }

        // Copy all headers from original request if provided
        if let Some(headers) = headers {
            for (name, value) in headers {
                // Skip Content-Type and Content-Length as .json() sets them
                if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
                    request_builder = request_builder.header(name, value);
                }
            }
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
                error!(
                    "Failed to send typed request worker_url={} route={} error={}",
                    worker_url, route, e
                );

                // Decrement load on error if it was incremented
                if load_incremented {
                    if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(worker_url, worker.load());
                    }
                }

                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Request failed: {}", e),
                )
                    .into_response();
            }
        };

        let status = StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, preserve headers
            let response_headers = header_utils::preserve_response_headers(res.headers());

            let response = match res.bytes().await {
                Ok(body) => {
                    let mut response = Response::new(axum::body::Body::from(body));
                    *response.status_mut() = status;
                    *response.headers_mut() = response_headers;
                    response
                }
                Err(e) => {
                    // IMPORTANT: Decrement load on error before returning
                    if load_incremented {
                        if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(worker_url, worker.load());
                        }
                    }

                    let error_msg = format!("Failed to get response body: {}", e);
                    (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
                }
            };

            // Decrement load counter for non-streaming requests if it was incremented
            if load_incremented {
                if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                    worker.decrement_load();
                    RouterMetrics::set_running_requests(worker_url, worker.load());
                }
            }

            response
        } else if load_incremented {
            // For streaming with load tracking, we need to manually decrement when done
            let registry = Arc::clone(&self.worker_registry);
            let worker_url = worker_url.to_string();

            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream and detect completion
            tokio::spawn(async move {
                let mut stream = stream;
                let mut decremented = false;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            // Check for stream end marker
                            // ("data: [DONE]" is 12 bytes, hence windows(12)).
                            // NOTE(review): assumes the marker is not split
                            // across chunk boundaries — confirm upstream
                            // chunking guarantees this.
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                if let Some(worker) = registry.get_by_url(&worker_url) {
                                    worker.decrement_load();
                                    RouterMetrics::set_running_requests(&worker_url, worker.load());
                                    decremented = true;
                                }
                            }
                            if tx.send(Ok(bytes)).is_err() {
                                // Client side dropped the stream; stop forwarding.
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
                // Stream ended without a [DONE] marker (error, disconnect):
                // make sure the load counter is still released exactly once.
                if !decremented {
                    if let Some(worker) = registry.get_by_url(&worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            *response.headers_mut() = response_headers;
            response
        } else {
            // For requests without load tracking, just stream
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream
            tokio::spawn(async move {
                let mut stream = stream;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            *response.headers_mut() = response_headers;
            response
        }
    }

662
    /// Query a worker's `/get_load` endpoint and return its reported load.
    ///
    /// In dp-aware mode the "@rank" suffix is stripped from `worker_url`
    /// first. Returns `None` on any transport, status, or parse failure
    /// (failures are logged, not propagated).
    async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
        let worker_url = if self.dp_aware {
            // Need to extract the URL from "http://host:port@dp_rank"
            let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
                    return None;
                }
            };
            worker_url_prefix
        } else {
            worker_url
        };

        let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
        if let Some(key) = api_key {
            req_builder = req_builder.bearer_auth(key);
        }

        match req_builder.send().await {
            // Expected shape: a JSON object with an integer "load" field.
            Ok(res) if res.status().is_success() => match res.bytes().await {
                Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                    Ok(data) => data
                        .get("load")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as isize),
                    Err(e) => {
                        debug!("Failed to parse load response from {}: {}", worker_url, e);
                        None
                    }
                },
                Err(e) => {
                    debug!("Failed to read load response from {}: {}", worker_url, e);
                    None
                }
            },
            Ok(res) => {
                debug!(
                    "Worker {} returned non-success status: {}",
                    worker_url,
                    res.status()
                );
                None
            }
            Err(e) => {
                debug!("Failed to get load from {}: {}", worker_url, e);
                None
            }
        }
    }

    // Background task to monitor worker loads
    async fn monitor_worker_loads(
        worker_urls: Vec<String>,
717
        worker_api_keys: Vec<Option<String>>,
718
719
720
        tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
        interval_secs: u64,
        policy: Arc<dyn LoadBalancingPolicy>,
721
        client: Client,
722
723
724
725
726
727
728
    ) {
        let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));

        loop {
            interval.tick().await;

            let mut loads = HashMap::new();
729
730
            for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
                if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
                    loads.insert(url.clone(), load);
                }
            }

            if !loads.is_empty() {
                // Update policy with new loads
                policy.update_loads(&loads);

                // Send to watchers
                if let Err(e) = tx.send(loads) {
                    error!("Failed to send load update: {}", e);
                }
            }
        }
    }

    // Static version of get_worker_load for use in monitoring task.
    //
    // Unlike the instance method, dp-suffix stripping is triggered by the
    // presence of '@' in the URL (the monitor task has no dp_aware flag), and
    // extraction failures are logged at debug level rather than error.
    async fn get_worker_load_static(
        client: &Client,
        worker_url: &str,
        api_key: &Option<String>,
    ) -> Option<isize> {
        let worker_url = if worker_url.contains("@") {
            // Need to extract the URL from "http://host:port@dp_rank"
            let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    debug!("Failed to extract dp_rank: {}", e);
                    return None;
                }
            };
            worker_url_prefix
        } else {
            worker_url
        };

        let mut req_builder = client.get(format!("{}/get_load", worker_url));
        if let Some(key) = api_key {
            req_builder = req_builder.bearer_auth(key);
        }
        match req_builder.send().await {
            // Expected shape: a JSON object with an integer "load" field.
            Ok(res) if res.status().is_success() => match res.bytes().await {
                Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                    Ok(data) => data
                        .get("load")
                        .and_then(|v| v.as_i64())
                        .map(|v| v as isize),
                    Err(e) => {
                        debug!("Failed to parse load response from {}: {}", worker_url, e);
                        None
                    }
                },
                Err(e) => {
                    debug!("Failed to read load response from {}: {}", worker_url, e);
                    None
                }
            },
            Ok(res) => {
                debug!(
                    "Worker {} returned non-success status: {}",
                    worker_url,
                    res.status()
                );
                None
            }
            Err(e) => {
                debug!("Failed to get load from {}: {}", worker_url, e);
                None
            }
        }
    }
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820

    /// Parse a worker's raw rerank output and post-process it per the request:
    /// sort by score, optionally truncate to `top_k`, and optionally strip
    /// document text before returning a JSON response.
    async fn build_rerank_response(
        req: &RerankRequest,
        response: Response,
    ) -> anyhow::Result<Response> {
        let (_, raw_body) = response.into_parts();
        let raw_bytes = to_bytes(raw_body, usize::MAX).await?;
        let results: Vec<RerankResult> = serde_json::from_slice(&raw_bytes)?;

        let mut reranked = RerankResponse::new(results, req.model.clone(), req.rid.clone());
        reranked.sort_by_score();
        if let Some(top_k) = req.top_k {
            reranked.apply_top_k(top_k);
        }
        if !req.return_documents {
            reranked.drop_documents();
        }

        Ok(Json(reranked).into_response())
    }
821
822
823
824
}

use async_trait::async_trait;

825
#[async_trait]
impl RouterTrait for Router {
    /// Allow callers holding a `dyn RouterTrait` to downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Aggregate health check: 200 when every registered worker reports healthy,
    /// otherwise 503 with the list of unhealthy worker URLs.
    async fn health(&self, _req: Request<Body>) -> Response {
        let workers = self.worker_registry.get_all();
        let unhealthy_servers: Vec<_> = workers
            .iter()
            .filter(|w| !w.is_healthy())
            .map(|w| w.url().to_string())
            .collect();

        if unhealthy_servers.is_empty() {
            (StatusCode::OK, "All servers healthy").into_response()
        } else {
            (
                StatusCode::SERVICE_UNAVAILABLE,
                format!("Unhealthy servers: {:?}", unhealthy_servers),
            )
                .into_response()
        }
    }

    /// Proxy the generation health probe to a worker via `proxy_get_request`.
    async fn health_generate(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "health_generate").await
    }

    /// Proxy a worker's server-info endpoint.
    async fn get_server_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_server_info").await
    }

    /// Proxy the OpenAI-compatible model listing endpoint.
    async fn get_models(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "v1/models").await
    }

    /// Proxy a worker's model-info endpoint.
    async fn get_model_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_model_info").await
    }

    /// Route a native `/generate` request through the typed request path.
    async fn route_generate(
        &self,
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/generate", model_id)
            .await
    }

    /// Route an OpenAI-compatible chat completion request.
    async fn route_chat(
        &self,
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
            .await
    }

    /// Route an OpenAI-compatible text completion request.
    async fn route_completion(
        &self,
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/completions", model_id)
            .await
    }

    /// Route an OpenAI-compatible responses-API request.
    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/responses", model_id)
            .await
    }

    /// Fetch a previously created response by id (GET /v1/responses/{id}).
    async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}", response_id);
        self.route_get_request(headers, &endpoint).await
    }

    /// Cancel an in-flight response by id (POST /v1/responses/{id}/cancel with empty body).
    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}/cancel", response_id);
        self.route_post_empty_request(headers, &endpoint).await
    }

    /// Route an embeddings request, recording embeddings-specific success/error
    /// metrics on top of the general request metrics.
    async fn route_embeddings(
        &self,
        headers: Option<&HeaderMap>,
        body: &EmbeddingRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Record embeddings-specific metrics in addition to general request metrics
        let start = Instant::now();
        let res = self
            .route_typed_request(headers, body, "/v1/embeddings", model_id)
            .await;

        // Embedding specific metrics
        if res.status().is_success() {
            RouterMetrics::record_embeddings_request();
            RouterMetrics::record_embeddings_duration(start.elapsed());
        } else {
            let error_type = format!("http_{}", res.status().as_u16());
            RouterMetrics::record_embeddings_error(&error_type);
        }

        res
    }

    /// Route a rerank request: validate it, forward to a worker, then reshape
    /// the worker's raw result list into the public `RerankResponse` format.
    async fn route_rerank(
        &self,
        headers: Option<&HeaderMap>,
        body: &RerankRequest,
        model_id: Option<&str>,
    ) -> Response {
        if let Err(e) = body.validate() {
            return (StatusCode::BAD_REQUEST, e).into_response();
        }
        let response = self
            .route_typed_request(headers, body, "/v1/rerank", model_id)
            .await;

        // Only successful upstream responses carry a result list to transform;
        // errors are passed through to the caller unchanged.
        if response.status().is_success() {
            match Self::build_rerank_response(body, response).await {
                Ok(rerank_response) => rerank_response,
                Err(e) => {
                    error!("Failed to build rerank response: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Failed to build rerank response".to_string(),
                    )
                        .into_response();
                }
            }
        } else {
            response
        }
    }

    /// Ask every registered worker to flush its cache. Succeeds (200) only when
    /// all workers respond successfully; any failure yields a 500.
    async fn flush_cache(&self) -> Response {
        // Get all workers
        let workers = self.worker_registry.get_all();
        let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        // Send requests to all workers concurrently without headers
        let mut tasks = Vec::new();
        for worker_url in &worker_urls {
            // Get the worker's API key if available
            let api_key = self
                .worker_registry
                .get_by_url(worker_url)
                .and_then(|w| w.api_key().clone());

            let worker_url = if self.dp_aware {
                // Need to extract the URL from "http://host:port@dp_rank"
                let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
                    Ok(tup) => tup,
                    Err(e) => {
                        error!("Failed to extract dp_rank: {}", e);
                        // Abort the whole flush on a malformed dp-aware URL.
                        return (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Failed to extract dp_rank: {}", e),
                        )
                            .into_response();
                    }
                };
                worker_url_prefix
            } else {
                worker_url
            };

            let mut request_builder = self.client.post(format!("{}/flush_cache", worker_url));

            if let Some(key) = api_key {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", key));
            }

            tasks.push(request_builder.send());
        }

        // Wait for all responses
        let results = futures_util::future::join_all(tasks).await;

        // Check if all succeeded
        let all_success = results.iter().all(|r| {
            r.as_ref()
                .map(|res| res.status().is_success())
                .unwrap_or(false)
        });

        if all_success {
            (StatusCode::OK, "Cache flushed on all servers").into_response()
        } else {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Cache flush failed on one or more servers",
            )
                .into_response()
        }
    }

    /// Report the current load of every worker as JSON. Workers that cannot be
    /// queried are reported with a load of -1 rather than omitted.
    async fn get_worker_loads(&self) -> Response {
        let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
        let mut loads = Vec::new();

        // Get loads from all workers
        for (url, api_key) in &urls_with_key {
            let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
            loads.push(serde_json::json!({
                "worker": url,
                "load": load
            }));
        }

        Json(serde_json::json!({
            "workers": loads
        }))
        .into_response()
    }

    /// Identifier used to distinguish this router flavor in logs/metrics.
    fn router_type(&self) -> &'static str {
        "regular"
    }

    /// Readiness probe: ready (200 + counts) when at least one worker is
    /// healthy, otherwise 503 with a JSON reason.
    fn readiness(&self) -> Response {
        // Regular router is ready if it has at least one healthy worker
        let workers = self.worker_registry.get_all();
        let healthy_count = workers.iter().filter(|w| w.is_healthy()).count();
        let total_workers = workers.len();

        if healthy_count > 0 {
            Json(serde_json::json!({
                "status": "ready",
                "healthy_workers": healthy_count,
                "total_workers": total_workers
            }))
            .into_response()
        } else {
            (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({
                    "status": "not_ready",
                    "reason": "no healthy workers available",
                    "total_workers": total_workers
                })),
            )
                .into_response()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorkerBuilder;
    use std::collections::HashMap;

    /// Build a `Router` backed by two pre-registered regular HTTP workers,
    /// with round-robin policy and defaults everywhere else. Never touches
    /// the network.
    fn create_test_regular_router() -> Router {
        let registry = Arc::new(WorkerRegistry::new());
        let policies = Arc::new(PolicyRegistry::new(
            crate::config::types::PolicyConfig::RoundRobin,
        ));

        for url in ["http://worker1:8080", "http://worker2:8080"] {
            let worker = BasicWorkerBuilder::new(url)
                .worker_type(WorkerType::Regular)
                .build();
            registry.register(Arc::new(worker));
        }

        let (_, loads_rx) = tokio::sync::watch::channel(HashMap::new());
        Router {
            worker_registry: registry,
            policy_registry: policies,
            client: Client::new(),
            dp_aware: false,
            enable_igw: false,
            retry_config: RetryConfig::default(),
            _worker_loads: Arc::new(loads_rx),
            _load_monitor_handle: None,
        }
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let urls: Vec<String> = router
            .worker_registry
            .get_all()
            .iter()
            .map(|w| w.url().to_string())
            .collect();

        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        // DashMap iteration order is unspecified; either worker is acceptable.
        let url = result.unwrap();
        assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
    }
}