"lm_eval/tasks/arabicmmlu/utils.py" did not exist on "78a54e1448ccb8c05ab3e83c110e0386316206a2"
router.rs 34 KB
Newer Older
1
use crate::config::types::RetryConfig;
2
use crate::core::{
3
4
    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerManager, WorkerRegistry,
    WorkerType,
5
};
6
use crate::metrics::RouterMetrics;
7
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
8
use crate::protocols::spec::{
9
    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
10
    RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest,
11
};
12
use crate::routers::header_utils;
13
use crate::routers::RouterTrait;
14
use axum::body::to_bytes;
15
16
17
use axum::{
    body::Body,
    extract::Request,
18
19
20
    http::{
        header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode,
    },
21
22
23
24
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
25
use reqwest::Client;
26
use std::collections::HashMap;
27
use std::sync::Arc;
28
use std::time::{Duration, Instant};
29
use tokio_stream::wrappers::UnboundedReceiverStream;
30
use tracing::{debug, error};
31
32
33
34

/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
    // Shared registry of every worker known to the application.
    worker_registry: Arc<WorkerRegistry>,
    // Per-model load-balancing policies plus the configured default policy.
    policy_registry: Arc<PolicyRegistry>,
    // HTTP client used for all proxied requests to workers.
    client: Client,
    // When true, worker URLs carry an "@<dp_rank>" suffix and outgoing JSON
    // bodies get a "data_parallel_rank" field injected (see send_typed_request).
    dp_aware: bool,
    // When true, worker selection is scoped to the requested model id
    // (see select_worker_for_model).
    enable_igw: bool,
    // Retry/backoff settings applied by route_typed_request.
    retry_config: RetryConfig,
    // Receiver side of the load-monitor watch channel; stored to keep the
    // channel alive (leading underscore: not read directly here).
    _worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
    // Join handle of the background load-monitor task, which `new` spawns
    // only when the default policy is "power_of_two".
    _load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
}

impl Router {
46
    /// Create a new router with injected policy and client
    ///
    /// Pulls workers, policies, the shared HTTP client, and retry settings
    /// from the application context. When the default policy is
    /// `power_of_two`, additionally spawns a background task that polls
    /// worker loads and publishes them through a watch channel.
    pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
        // Snapshot of regular HTTP workers (any model), used for the active
        // worker gauge and, below, the optional load monitor.
        let workers = ctx.worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // include all workers
        );

        RouterMetrics::set_active_workers(workers.len());

        let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        // Watch channel carrying the latest per-worker load snapshot; the
        // receiver is stored on the Router so the channel stays open.
        let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
        let worker_loads = Arc::new(rx);

        let default_policy = ctx.policy_registry.get_default_policy();

        // Only the power_of_two policy consumes live load data, so the
        // monitor task is spawned just for that policy.
        let load_monitor_handle = if default_policy.name() == "power_of_two" {
            let monitor_urls = worker_urls.clone();
            // Per-URL API keys, aligned index-for-index with monitor_urls.
            let monitor_api_keys = monitor_urls
                .iter()
                .map(|url| {
                    ctx.worker_registry
                        .get_by_url(url)
                        .and_then(|w| w.api_key().clone())
                })
                .collect::<Vec<Option<String>>>();
            let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
            let policy_clone = default_policy.clone();
            let client_clone = ctx.client.clone();

            Some(Arc::new(tokio::spawn(async move {
                Self::monitor_worker_loads(
                    monitor_urls,
                    monitor_api_keys,
                    tx,
                    monitor_interval,
                    policy_clone,
                    client_clone,
                )
                .await;
            })))
        } else {
            None
        };

        Ok(Router {
            worker_registry: ctx.worker_registry.clone(),
            policy_registry: ctx.policy_registry.clone(),
            client: ctx.client.clone(),
            dp_aware: ctx.router_config.dp_aware,
            enable_igw: ctx.router_config.enable_igw,
            retry_config: ctx.router_config.effective_retry_config(),
            _worker_loads: worker_loads,
            _load_monitor_handle: load_monitor_handle,
        })
    }

    fn select_first_worker(&self) -> Result<String, String> {
106
        let workers = self.worker_registry.get_all();
107
108
        let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();
        if healthy_workers.is_empty() {
109
110
            Err("No workers are available".to_string())
        } else {
111
            Ok(healthy_workers[0].url().to_string())
112
113
114
        }
    }

115
    // Helper method to proxy GET requests to the first available worker
116
    async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
117
        let headers = header_utils::copy_request_headers(&req);
118
119
120

        match self.select_first_worker() {
            Ok(worker_url) => {
121
                let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
122
                for (name, value) in headers {
123
124
                    let name_lc = name.to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
125
126
127
                        request_builder = request_builder.header(name, value);
                    }
                }
128

129
130
131
132
                match request_builder.send().await {
                    Ok(res) => {
                        let status = StatusCode::from_u16(res.status().as_u16())
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
133
134
135
136
137

                        // Preserve headers from backend
                        let response_headers =
                            header_utils::preserve_response_headers(res.headers());

138
                        match res.bytes().await {
139
140
141
142
143
144
                            Ok(body) => {
                                let mut response = Response::new(axum::body::Body::from(body));
                                *response.status_mut() = status;
                                *response.headers_mut() = response_headers;
                                response
                            }
145
146
147
148
149
                            Err(e) => (
                                StatusCode::INTERNAL_SERVER_ERROR,
                                format!("Failed to read response: {}", e),
                            )
                                .into_response(),
150
151
                        }
                    }
152
153
154
155
156
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
157
158
                }
            }
159
            Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
160
161
162
        }
    }

163
164
165
166
167
168
    /// Select worker for a specific model considering circuit breaker state
    fn select_worker_for_model(
        &self,
        model_id: Option<&str>,
        text: Option<&str>,
    ) -> Option<Arc<dyn Worker>> {
        // Model scoping only applies in IGW mode; otherwise match any model.
        let scoped_model = if self.enable_igw { model_id } else { None };

        // O(1) registry lookup of regular HTTP workers for that model.
        let candidates = self.worker_registry.get_workers_filtered(
            scoped_model,
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // fetch everything; availability is filtered just below
        );

        let available: Vec<Arc<dyn Worker>> = candidates
            .iter()
            .filter(|w| w.is_available())
            .cloned()
            .collect();
        if available.is_empty() {
            return None;
        }

        // Use the per-model policy when a model is named, else the default.
        let policy = model_id.map_or_else(
            || self.policy_registry.get_default_policy(),
            |model| self.policy_registry.get_policy_or_default(model),
        );

        policy
            .select_worker(&available, text)
            .map(|idx| available[idx].clone())
    }

198
    /// Route an already-deserialized generation request to a worker,
    /// retrying on retryable HTTP statuses per `self.retry_config`.
    ///
    /// `route` is the worker-side path (e.g. "/generate"); `model_id`, when
    /// present, scopes worker and policy selection. Records per-route
    /// success/duration metrics on success and an error metric for
    /// non-retryable failures.
    pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        model_id: Option<&str>,
    ) -> Response {
        let start = Instant::now();
        let is_stream = typed_req.is_stream();
        // Routing text feeds text-aware policies (e.g. cache-aware selection).
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
                // Each attempt re-selects a worker, so a retry can land on a
                // different (healthy) worker.
                let worker = match self.select_worker_for_model(model_id, Some(&text)) {
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };

                // Optional load tracking for cache-aware policy
                // Get the policy for this model to check if it's cache-aware
                let policy = match model_id {
                    Some(model) => self.policy_registry.get_policy_or_default(model),
                    None => self.policy_registry.get_default_policy(),
                };

                // Only the cache_aware policy uses per-worker load counters.
                let load_incremented = if policy.name() == "cache_aware" {
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
                } else {
                    false
                };

                // Keep a clone for potential cleanup on retry
                let worker_for_cleanup = if load_incremented {
                    Some(worker.clone())
                } else {
                    None
                };

                let response = self
                    .send_typed_request(
                        headers,
                        typed_req,
                        route,
                        worker.url(),
                        is_stream,
                        load_incremented,
                    )
                    .await;

                // Feed the circuit breaker with this attempt's outcome.
                worker.record_outcome(response.status().is_success());

                // For retryable failures, we need to decrement load since send_typed_request
                // won't have done it (it only decrements on success or non-retryable failures)
                if is_retryable_status(response.status()) && load_incremented {
                    if let Some(cleanup_worker) = worker_for_cleanup {
                        cleanup_worker.decrement_load();
                        RouterMetrics::set_running_requests(
                            cleanup_worker.url(),
                            cleanup_worker.load(),
                        );
                    }
                }

                response
            },
            // should_retry predicate
            |res, _attempt| is_retryable_status(res.status()),
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
        )
        .await;

        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
        } else if !is_retryable_status(response.status()) {
            // Retryable-but-exhausted failures are recorded by the
            // on_exhausted hook; this covers only non-retryable statuses.
            RouterMetrics::record_request_error(route, "non_retryable_error");
        }

        response
    }

297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
    /// Base URL of a worker: with DP-awareness enabled, drop the "@<rank>"
    /// suffix; otherwise (or when the suffix does not parse) return the URL
    /// unchanged.
    fn worker_base_url(&self, worker_url: &str) -> String {
        match self.dp_aware {
            true => Self::extract_dp_rank(worker_url)
                .map(|(prefix, _rank)| prefix.to_string())
                .unwrap_or_else(|_| worker_url.to_string()),
            false => worker_url.to_string(),
        }
    }

    // Generic simple routing for GET/POST without JSON body
    //
    // Fans the request out to every registered worker, returning the first
    // successful response; if none succeeds, returns the last response (or
    // error) observed, falling back to 502 if no worker answered at all.
    async fn route_simple_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
        method: Method,
    ) -> Response {
        // TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
        // Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
            return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
        }

        // Last non-success response/error, kept as the fallback reply.
        let mut last_response: Option<Response> = None;
        for worker in workers {
            let worker_url = worker.url();
            // Strip any DP suffix so the URL is a plain HTTP base.
            let base = self.worker_base_url(worker_url);

            let url = format!("{}/{}", base, endpoint);
            let mut request_builder = match method {
                Method::GET => self.client.get(url),
                Method::POST => self.client.post(url),
                _ => {
                    return (
                        StatusCode::METHOD_NOT_ALLOWED,
                        "Unsupported method for simple routing",
                    )
                        .into_response()
                }
            };

            // Per-worker bearer auth when the worker has an API key.
            if let Some(api_key) = worker.api_key() {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", api_key));
            }

            // Forward caller headers, minus body-related ones (no body sent).
            if let Some(hdrs) = headers {
                for (name, value) in hdrs {
                    let name_lc = name.as_str().to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }
            }

            match request_builder.send().await {
                Ok(res) => {
                    let status = StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                    let response_headers = header_utils::preserve_response_headers(res.headers());
                    match res.bytes().await {
                        Ok(body) => {
                            let mut response = Response::new(axum::body::Body::from(body));
                            *response.status_mut() = status;
                            *response.headers_mut() = response_headers;
                            // First success short-circuits the fan-out.
                            if status.is_success() {
                                return response;
                            }
                            last_response = Some(response);
                        }
                        Err(e) => {
                            last_response = Some(
                                (
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    format!("Failed to read response: {}", e),
                                )
                                    .into_response(),
                            );
                        }
                    }
                }
                Err(e) => {
                    last_response = Some(
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Request failed: {}", e),
                        )
                            .into_response(),
                    );
                }
            }
        }

        last_response
            .unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response())
    }

    // Route a GET request with provided headers to a specific endpoint
    // (thin wrapper over the fan-out in `route_simple_request`).
    async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
        self.route_simple_request(headers, endpoint, Method::GET)
            .await
    }

    // Route a POST request with empty body to a specific endpoint
    // (thin wrapper over the fan-out in `route_simple_request`).
    async fn route_post_empty_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
    ) -> Response {
        self.route_simple_request(headers, endpoint, Method::POST)
            .await
    }

411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
    // TODO (rui): Better accommodate to the Worker abstraction
    /// Split a DP-aware worker URL of the form "<base>@<rank>" into its base
    /// URL and numeric rank. Errors when there is not exactly one '@' or the
    /// rank is not a valid `usize`.
    fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
        let mut pieces = worker_url.split('@');
        match (pieces.next(), pieces.next(), pieces.next()) {
            // Exactly two pieces: "<prefix>@<rank>".
            (Some(prefix), Some(rank_str), None) => rank_str
                .parse::<usize>()
                .map(|dp_rank| (prefix, dp_rank))
                .map_err(|_| {
                    format!("failed to parse dp_rank from worker_url: {}", worker_url)
                }),
            // Zero or more than one '@' separator.
            _ => Err(format!("invalid worker_url format: {}", worker_url)),
        }
    }

428
429
430
    // Send typed request directly without conversion
    //
    // Posts `typed_req` to `worker_url` + `route`. In DP-aware mode the
    // request is re-serialized so a "data_parallel_rank" field (parsed from
    // the URL's "@<rank>" suffix) can be injected. Non-streaming responses
    // are buffered; streaming responses are forwarded chunk-by-chunk as SSE.
    // When `load_incremented` is true, this function owns decrementing the
    // worker's load counter on success and on non-retryable failure paths
    // (the caller handles retryable statuses).
    async fn send_typed_request<T: serde::Serialize>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
        load_incremented: bool, // Whether load was incremented for this request
    ) -> Response {
        // Get the worker's API key if available
        let api_key = self
            .worker_registry
            .get_by_url(worker_url)
            .and_then(|w| w.api_key().clone());

        let mut request_builder = if self.dp_aware {
            // DP mode: the registry URL is "<base>@<rank>"; split it so the
            // HTTP request targets the base and the rank rides in the body.
            let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
                }
            };

            // Round-trip through serde_json::Value to mutate the body.
            let mut json_val = match serde_json::to_value(typed_req) {
                Ok(j) => j,
                Err(e) => {
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("Convert into serde_json::Value failed: {}", e),
                    )
                        .into_response();
                }
            };

            if let Some(map) = json_val.as_object_mut() {
                map.insert(
                    String::from("data_parallel_rank"),
                    serde_json::json!(dp_rank),
                );
                debug!(
                    "Modified request body: {}",
                    serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
                );
            } else {
                // Non-object bodies (e.g. arrays) cannot carry the rank field.
                return (
                    StatusCode::BAD_REQUEST,
                    "Failed to insert the data_parallel_rank field into the request body",
                )
                    .into_response();
            }

            self.client
                .post(format!("{}{}", worker_url_prefix, route))
                .json(&json_val)
        } else {
            self.client
                .post(format!("{}{}", worker_url, route))
                .json(typed_req) // Use json() directly with typed request
        };

        if let Some(key) = api_key {
            request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
        }

        // Copy all headers from original request if provided
        if let Some(headers) = headers {
            for (name, value) in headers {
                // Skip Content-Type and Content-Length as .json() sets them
                if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
                    request_builder = request_builder.header(name, value);
                }
            }
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
                error!(
                    "Failed to send typed request worker_url={} route={} error={}",
                    worker_url, route, e
                );

                // Decrement load on error if it was incremented
                if load_incremented {
                    if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(worker_url, worker.load());
                    }
                }

                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Request failed: {}", e),
                )
                    .into_response();
            }
        };

        let status = StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, preserve headers
            let response_headers = header_utils::preserve_response_headers(res.headers());

            let response = match res.bytes().await {
                Ok(body) => {
                    let mut response = Response::new(axum::body::Body::from(body));
                    *response.status_mut() = status;
                    *response.headers_mut() = response_headers;
                    response
                }
                Err(e) => {
                    // IMPORTANT: Decrement load on error before returning
                    if load_incremented {
                        if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(worker_url, worker.load());
                        }
                    }

                    let error_msg = format!("Failed to get response body: {}", e);
                    (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
                }
            };

            // Decrement load counter for non-streaming requests if it was incremented
            if load_incremented {
                if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                    worker.decrement_load();
                    RouterMetrics::set_running_requests(worker_url, worker.load());
                }
            }

            response
        } else if load_incremented {
            // For streaming with load tracking, we need to manually decrement when done
            let registry = Arc::clone(&self.worker_registry);
            let worker_url = worker_url.to_string();

            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream and detect completion
            tokio::spawn(async move {
                let mut stream = stream;
                let mut decremented = false;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            // Check for stream end marker
                            // (12-byte window == b"data: [DONE]"; NOTE(review):
                            // assumes the marker is never split across chunks)
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                if let Some(worker) = registry.get_by_url(&worker_url) {
                                    worker.decrement_load();
                                    RouterMetrics::set_running_requests(&worker_url, worker.load());
                                    decremented = true;
                                }
                            }
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
                // Safety net: decrement even if the stream ended (or the
                // client disconnected) without a [DONE] marker.
                if !decremented {
                    if let Some(worker) = registry.get_by_url(&worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            *response.headers_mut() = response_headers;
            response
        } else {
            // For requests without load tracking, just stream
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream
            tokio::spawn(async move {
                let mut stream = stream;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            *response.headers_mut() = response_headers;
            response
        }
    }

    // Background task to monitor worker loads
    async fn monitor_worker_loads(
        worker_urls: Vec<String>,
667
        worker_api_keys: Vec<Option<String>>,
668
669
670
        tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
        interval_secs: u64,
        policy: Arc<dyn LoadBalancingPolicy>,
671
        client: Client,
672
673
674
675
676
677
678
    ) {
        let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));

        loop {
            interval.tick().await;

            let mut loads = HashMap::new();
679
            for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
680
681
682
683
                // Use WorkerManager for consistent load fetching
                if let Some(load) =
                    WorkerManager::get_worker_load(url, api_key.as_deref(), &client).await
                {
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
                    loads.insert(url.clone(), load);
                }
            }

            if !loads.is_empty() {
                // Update policy with new loads
                policy.update_loads(&loads);

                // Send to watchers
                if let Err(e) = tx.send(loads) {
                    error!("Failed to send load update: {}", e);
                }
            }
        }
    }

700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
    /// Turn a raw worker rerank reply into a JSON `RerankResponse`:
    /// parse the body, sort by score, then apply the request's optional
    /// top-k cut and document stripping.
    async fn build_rerank_response(
        req: &RerankRequest,
        response: Response,
    ) -> anyhow::Result<Response> {
        let body = response.into_parts().1;
        let bytes = to_bytes(body, usize::MAX).await?;
        let results: Vec<RerankResult> = serde_json::from_slice(&bytes)?;

        let mut reply = RerankResponse::new(results, req.model.clone(), req.rid.clone());
        reply.sort_by_score();
        if let Some(k) = req.top_k {
            reply.apply_top_k(k);
        }
        if !req.return_documents {
            reply.drop_documents();
        }

        Ok(Json(reply).into_response())
    }
718
719
720
721
}

use async_trait::async_trait;

722
#[async_trait]
723
724
725
726
727
impl RouterTrait for Router {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

728
729
    async fn health_generate(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "health_generate").await
730
731
    }

732
733
    async fn get_server_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_server_info").await
734
735
    }

736
737
    async fn get_models(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "v1/models").await
738
739
    }

740
741
    async fn get_model_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_model_info").await
742
743
744
745
    }

    async fn route_generate(
        &self,
746
747
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
748
        model_id: Option<&str>,
749
    ) -> Response {
750
751
        self.route_typed_request(headers, body, "/generate", model_id)
            .await
752
753
754
755
    }

    async fn route_chat(
        &self,
756
757
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
758
        model_id: Option<&str>,
759
    ) -> Response {
760
        self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
761
            .await
762
763
764
765
    }

    async fn route_completion(
        &self,
766
767
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
768
        model_id: Option<&str>,
769
    ) -> Response {
770
        self.route_typed_request(headers, body, "/v1/completions", model_id)
771
            .await
772
773
    }

774
775
776
777
    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
778
        model_id: Option<&str>,
779
    ) -> Response {
780
        self.route_typed_request(headers, body, "/v1/responses", model_id)
781
782
783
            .await
    }

784
785
786
787
788
789
    async fn get_response(
        &self,
        headers: Option<&HeaderMap>,
        response_id: &str,
        _params: &ResponsesGetParams,
    ) -> Response {
790
791
792
793
794
795
796
797
798
        let endpoint = format!("v1/responses/{}", response_id);
        self.route_get_request(headers, &endpoint).await
    }

    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}/cancel", response_id);
        self.route_post_empty_request(headers, &endpoint).await
    }

799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
    async fn route_embeddings(
        &self,
        headers: Option<&HeaderMap>,
        body: &EmbeddingRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Record embeddings-specific metrics in addition to general request metrics
        let start = Instant::now();
        let res = self
            .route_typed_request(headers, body, "/v1/embeddings", model_id)
            .await;

        // Embedding specific metrics
        if res.status().is_success() {
            RouterMetrics::record_embeddings_request();
            RouterMetrics::record_embeddings_duration(start.elapsed());
        } else {
            let error_type = format!("http_{}", res.status().as_u16());
            RouterMetrics::record_embeddings_error(&error_type);
        }

        res
821
822
    }

823
824
825
826
827
828
    async fn route_rerank(
        &self,
        headers: Option<&HeaderMap>,
        body: &RerankRequest,
        model_id: Option<&str>,
    ) -> Response {
829
830
831
        if let Err(e) = body.validate() {
            return (StatusCode::BAD_REQUEST, e).into_response();
        }
832
833
834
        let response = self
            .route_typed_request(headers, body, "/v1/rerank", model_id)
            .await;
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
        if response.status().is_success() {
            match Self::build_rerank_response(body, response).await {
                Ok(rerank_response) => rerank_response,
                Err(e) => {
                    error!("Failed to build rerank response: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Failed to build rerank response".to_string(),
                    )
                        .into_response();
                }
            }
        } else {
            response
        }
850
851
    }

852
853
854
855
856
857
858
859
    fn router_type(&self) -> &'static str {
        "regular"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorkerBuilder;
    use std::collections::HashMap;

    /// Build a router with two healthy regular workers and default policies.
    fn create_test_regular_router() -> Router {
        // Create registries
        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(
            crate::config::types::PolicyConfig::RoundRobin,
        ));

        // Register test workers
        let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
            .worker_type(WorkerType::Regular)
            .build();
        let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
            .worker_type(WorkerType::Regular)
            .build();
        worker_registry.register(Arc::new(worker1));
        worker_registry.register(Arc::new(worker2));

        let (_, rx) = tokio::sync::watch::channel(HashMap::new());
        Router {
            worker_registry,
            policy_registry,
            dp_aware: false,
            client: Client::new(),
            retry_config: RetryConfig::default(),
            _worker_loads: Arc::new(rx),
            _load_monitor_handle: None,
            enable_igw: false,
        }
    }

    /// Same as `create_test_regular_router`, but marks the first registered
    /// worker unhealthy so selection must skip it.
    fn create_test_unhealthy_router() -> Router {
        let router = create_test_regular_router();
        let workers = router.worker_registry.get_all();
        workers[0].set_healthy(false);
        router
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let workers = router.worker_registry.get_all();
        let urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();
        // DashMap doesn't guarantee order, so just check we get one of the workers
        assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
    }

    #[test]
    fn test_select_first_worker_with_unhealthy_worker() {
        let router = create_test_unhealthy_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();

        // The selected worker must be one that is still healthy.
        let worker = router.worker_registry.get_by_url(&url).unwrap();
        assert!(worker.is_healthy());
    }
}