router.rs 31 KB
Newer Older
1
use crate::config::types::RetryConfig;
2
use crate::core::{
3
    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
4
};
5
use crate::metrics::RouterMetrics;
6
use crate::policies::PolicyRegistry;
7
use crate::protocols::spec::{
8
    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
9
    RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest,
10
};
11
use crate::routers::header_utils;
12
use crate::routers::RouterTrait;
13
use axum::body::to_bytes;
14
15
16
use axum::{
    body::Body,
    extract::Request,
17
18
19
    http::{
        header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode,
    },
20
21
22
23
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
24
use reqwest::Client;
25
use std::sync::Arc;
26
use std::time::Instant;
27
use tokio_stream::wrappers::UnboundedReceiverStream;
28
use tracing::{debug, error};
29
30
31
32

/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
    // Shared registry of all known workers.
    worker_registry: Arc<WorkerRegistry>,
    // Per-model load-balancing policies plus a default fallback policy.
    policy_registry: Arc<PolicyRegistry>,
    // Shared HTTP client used for all upstream requests.
    client: Client,
    // When true, worker URLs carry an "@<dp_rank>" suffix that must be
    // stripped before sending; the rank is injected into request bodies
    // as `data_parallel_rank` (see send_typed_request).
    dp_aware: bool,
    // When true, worker selection is filtered by the requested model id.
    enable_igw: bool,
    // Retry/backoff settings applied in route_typed_request.
    retry_config: RetryConfig,
}

impl Router {
42
    /// Create a new router with injected policy and client
    ///
    /// Pulls the worker/policy registries, HTTP client, and router
    /// configuration from the shared application context, and publishes the
    /// initial active-worker count metric for regular HTTP workers.
    pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
        let workers = ctx.worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // include all workers
        );

        RouterMetrics::set_active_workers(workers.len());

        Ok(Router {
            worker_registry: ctx.worker_registry.clone(),
            policy_registry: ctx.policy_registry.clone(),
            client: ctx.client.clone(),
            dp_aware: ctx.router_config.dp_aware,
            enable_igw: ctx.router_config.enable_igw,
            retry_config: ctx.router_config.effective_retry_config(),
        })
    }

    fn select_first_worker(&self) -> Result<String, String> {
64
        let workers = self.worker_registry.get_all();
65
66
        let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();
        if healthy_workers.is_empty() {
67
68
            Err("No workers are available".to_string())
        } else {
69
            Ok(healthy_workers[0].url().to_string())
70
71
72
        }
    }

73
    // Helper method to proxy GET requests to the first available worker
74
    async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
75
        let headers = header_utils::copy_request_headers(&req);
76
77
78

        match self.select_first_worker() {
            Ok(worker_url) => {
79
                let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
80
                for (name, value) in headers {
81
82
                    let name_lc = name.to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
83
84
85
                        request_builder = request_builder.header(name, value);
                    }
                }
86

87
88
89
90
                match request_builder.send().await {
                    Ok(res) => {
                        let status = StatusCode::from_u16(res.status().as_u16())
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
91
92
93
94
95

                        // Preserve headers from backend
                        let response_headers =
                            header_utils::preserve_response_headers(res.headers());

96
                        match res.bytes().await {
97
                            Ok(body) => {
98
                                let mut response = Response::new(Body::from(body));
99
100
101
102
                                *response.status_mut() = status;
                                *response.headers_mut() = response_headers;
                                response
                            }
103
104
105
106
107
                            Err(e) => (
                                StatusCode::INTERNAL_SERVER_ERROR,
                                format!("Failed to read response: {}", e),
                            )
                                .into_response(),
108
109
                        }
                    }
110
111
112
113
114
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
115
116
                }
            }
117
            Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
118
119
120
        }
    }

121
122
123
124
125
126
    /// Select worker for a specific model considering circuit breaker state
    ///
    /// Returns `None` when no worker serving the model is currently available.
    fn select_worker_for_model(
        &self,
        model_id: Option<&str>,
        text: Option<&str>,
    ) -> Option<Arc<dyn Worker>> {
        // Without IGW every regular worker serves every model, so the model
        // filter is dropped in that mode.
        let effective_model_id = if self.enable_igw { model_id } else { None };

        // O(1) registry lookup, restricted to regular HTTP workers.
        let candidates = self.worker_registry.get_workers_filtered(
            effective_model_id,
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // get all workers, we'll filter by is_available() next
        );

        let available: Vec<Arc<dyn Worker>> = candidates
            .into_iter()
            .filter(|w| w.is_available())
            .collect();
        if available.is_empty() {
            return None;
        }

        // Resolve the load-balancing policy for this model, falling back to
        // the default policy when no model id is given.
        let policy = match model_id {
            Some(model) => self.policy_registry.get_policy_or_default(model),
            None => self.policy_registry.get_default_policy(),
        };

        policy
            .select_worker(&available, text)
            .map(|idx| available[idx].clone())
    }

156
    /// Route a typed (serializable) generation request with retries.
    ///
    /// Per attempt: selects a worker via the model's policy, optionally
    /// tracks in-flight load (cache-aware policy only), forwards the request,
    /// and records the outcome on the worker's circuit breaker. Retryable
    /// HTTP statuses trigger backoff/retry per `retry_config`; success and
    /// non-retryable failures are recorded in router metrics.
    pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        model_id: Option<&str>,
    ) -> Response {
        let start = Instant::now();
        let is_stream = typed_req.is_stream();
        // Routing text is extracted once and reused across attempts.
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
                let worker = match self.select_worker_for_model(model_id, Some(&text)) {
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };

                // Optional load tracking for cache-aware policy
                // Get the policy for this model to check if it's cache-aware
                let policy = match model_id {
                    Some(model) => self.policy_registry.get_policy_or_default(model),
                    None => self.policy_registry.get_default_policy(),
                };

                let load_incremented = if policy.name() == "cache_aware" {
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
                } else {
                    false
                };

                // Keep a clone for potential cleanup on retry
                let worker_for_cleanup = if load_incremented {
                    Some(worker.clone())
                } else {
                    None
                };

                let response = self
                    .send_typed_request(
                        headers,
                        typed_req,
                        route,
                        worker.url(),
                        is_stream,
                        load_incremented,
                    )
                    .await;

                // Feed the circuit breaker with this attempt's outcome.
                worker.record_outcome(response.status().is_success());

                // For retryable failures, we need to decrement load since send_typed_request
                // won't have done it (it only decrements on success or non-retryable failures)
                if is_retryable_status(response.status()) && load_incremented {
                    if let Some(cleanup_worker) = worker_for_cleanup {
                        cleanup_worker.decrement_load();
                        RouterMetrics::set_running_requests(
                            cleanup_worker.url(),
                            cleanup_worker.load(),
                        );
                    }
                }

                response
            },
            // should_retry predicate
            |res, _attempt| is_retryable_status(res.status()),
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
        )
        .await;

        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
        } else if !is_retryable_status(response.status()) {
            // Retryable failures that exhausted retries are counted via the
            // on_exhausted hook; only non-retryable errors are recorded here.
            RouterMetrics::record_request_error(route, "non_retryable_error");
        }

        response
    }

255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
    // Helper: return base worker URL (strips DP suffix when enabled)
    fn worker_base_url(&self, worker_url: &str) -> String {
        if !self.dp_aware {
            return worker_url.to_string();
        }
        // In dp-aware mode, drop the "@<rank>" suffix; fall back to the raw
        // URL when it does not carry one.
        match Self::extract_dp_rank(worker_url) {
            Ok((prefix, _)) => prefix.to_string(),
            Err(_) => worker_url.to_string(),
        }
    }

    // Generic simple routing for GET/POST without JSON body
    //
    // Fans the request out to every registered worker (sequentially) and
    // returns the first successful response. When no worker succeeds, the
    // last failure is returned, or 502 when no worker produced any response.
    async fn route_simple_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
        method: Method,
    ) -> Response {
        // TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
        // Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
            return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
        }

        let mut last_response: Option<Response> = None;
        for worker in workers {
            let worker_url = worker.url();
            // Strip the dp-aware "@<rank>" suffix when enabled.
            let base = self.worker_base_url(worker_url);

            let url = format!("{}/{}", base, endpoint);
            let mut request_builder = match method {
                Method::GET => self.client.get(url),
                Method::POST => self.client.post(url),
                _ => {
                    return (
                        StatusCode::METHOD_NOT_ALLOWED,
                        "Unsupported method for simple routing",
                    )
                        .into_response()
                }
            };

            // Attach the per-worker bearer token, when configured.
            if let Some(api_key) = worker.api_key() {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", api_key));
            }

            // Forward caller headers except content-type/content-length
            // (no body is sent with these simple requests).
            if let Some(hdrs) = headers {
                for (name, value) in hdrs {
                    let name_lc = name.as_str().to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }
            }

            match request_builder.send().await {
                Ok(res) => {
                    let status = StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                    let response_headers = header_utils::preserve_response_headers(res.headers());
                    match res.bytes().await {
                        Ok(body) => {
                            let mut response = Response::new(Body::from(body));
                            *response.status_mut() = status;
                            *response.headers_mut() = response_headers;
                            // First success short-circuits the fan-out.
                            if status.is_success() {
                                return response;
                            }
                            last_response = Some(response);
                        }
                        Err(e) => {
                            last_response = Some(
                                (
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    format!("Failed to read response: {}", e),
                                )
                                    .into_response(),
                            );
                        }
                    }
                }
                Err(e) => {
                    last_response = Some(
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Request failed: {}", e),
                        )
                            .into_response(),
                    );
                }
            }
        }

        last_response
            .unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response())
    }

    // Route a GET request with provided headers to a specific endpoint
    // (thin wrapper over route_simple_request).
    async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
        self.route_simple_request(headers, endpoint, Method::GET)
            .await
    }

    // Route a POST request with empty body to a specific endpoint
    // (thin wrapper over route_simple_request).
    async fn route_post_empty_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
    ) -> Response {
        self.route_simple_request(headers, endpoint, Method::POST)
            .await
    }

369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
    // TODO (rui): Better accommodate to the Worker abstraction
    //
    // Split a dp-aware worker URL of the form "<base>@<dp_rank>" into its
    // base URL and numeric rank. Errors when the URL does not contain
    // exactly one '@' separator, or when the rank is not a valid usize.
    fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
        let mut pieces = worker_url.split('@');
        // split('@') yields at least one piece; (Some, Some, None) means the
        // URL contained exactly one '@'.
        match (pieces.next(), pieces.next(), pieces.next()) {
            (Some(prefix), Some(rank_str), None) => rank_str
                .parse::<usize>()
                .map(|dp_rank| (prefix, dp_rank))
                .map_err(|_| {
                    format!("failed to parse dp_rank from worker_url: {}", worker_url)
                }),
            _ => Err(format!("invalid worker_url format: {}", worker_url)),
        }
    }

386
387
388
    // Send typed request directly without conversion
    async fn send_typed_request<T: serde::Serialize>(
        &self,
389
        headers: Option<&HeaderMap>,
390
391
392
393
394
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
        load_incremented: bool, // Whether load was incremented for this request
395
    ) -> Response {
396
397
398
399
400
401
        // Get the worker's API key if available
        let api_key = self
            .worker_registry
            .get_by_url(worker_url)
            .and_then(|w| w.api_key().clone());

402
403
404
405
406
        let mut request_builder = if self.dp_aware {
            let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
407
408
409
410
411
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
412
413
414
415
416
417
                }
            };

            let mut json_val = match serde_json::to_value(typed_req) {
                Ok(j) => j,
                Err(e) => {
418
419
420
421
422
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("Convert into serde_json::Value failed: {}", e),
                    )
                        .into_response();
423
424
425
426
427
428
429
430
431
432
433
434
435
                }
            };

            if let Some(map) = json_val.as_object_mut() {
                map.insert(
                    String::from("data_parallel_rank"),
                    serde_json::json!(dp_rank),
                );
                debug!(
                    "Modified request body: {}",
                    serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
                );
            } else {
436
437
438
439
440
                return (
                    StatusCode::BAD_REQUEST,
                    "Failed to insert the data_parallel_rank field into the request body",
                )
                    .into_response();
441
442
            }

443
            self.client
444
445
446
                .post(format!("{}{}", worker_url_prefix, route))
                .json(&json_val)
        } else {
447
            self.client
448
449
450
                .post(format!("{}{}", worker_url, route))
                .json(typed_req) // Use json() directly with typed request
        };
451

452
453
454
455
        if let Some(key) = api_key {
            request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
        }

456
457
458
459
        // Copy all headers from original request if provided
        if let Some(headers) = headers {
            for (name, value) in headers {
                // Skip Content-Type and Content-Length as .json() sets them
460
                if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
461
462
                    request_builder = request_builder.header(name, value);
                }
463
464
465
466
467
468
            }
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
469
470
471
472
                error!(
                    "Failed to send typed request worker_url={} route={} error={}",
                    worker_url, route, e
                );
473
474
475

                // Decrement load on error if it was incremented
                if load_incremented {
476
477
478
                    if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(worker_url, worker.load());
479
480
481
                    }
                }

482
483
484
485
486
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Request failed: {}", e),
                )
                    .into_response();
487
488
489
            }
        };

490
491
        let status = StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
492
493

        if !is_stream {
494
            // For non-streaming requests, preserve headers
495
            let response_headers = header_utils::preserve_response_headers(res.headers());
496

497
            let response = match res.bytes().await {
498
                Ok(body) => {
499
                    let mut response = Response::new(Body::from(body));
500
501
502
503
                    *response.status_mut() = status;
                    *response.headers_mut() = response_headers;
                    response
                }
504
                Err(e) => {
505
506
                    // IMPORTANT: Decrement load on error before returning
                    if load_incremented {
507
508
509
                        if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(worker_url, worker.load());
510
511
512
                        }
                    }

513
                    let error_msg = format!("Failed to get response body: {}", e);
514
                    (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
515
516
517
518
                }
            };

            // Decrement load counter for non-streaming requests if it was incremented
519
            if load_incremented {
520
521
522
                if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                    worker.decrement_load();
                    RouterMetrics::set_running_requests(worker_url, worker.load());
523
524
525
526
527
528
                }
            }

            response
        } else if load_incremented {
            // For streaming with load tracking, we need to manually decrement when done
529
            let registry = Arc::clone(&self.worker_registry);
530
531
            let worker_url = worker_url.to_string();

532
533
534
535
536
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

537
538
539
540
541
542
            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream and detect completion
            tokio::spawn(async move {
                let mut stream = stream;
543
                let mut decremented = false;
544
545
546
547
548
549
550
551
552
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            // Check for stream end marker
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
553
554
555
556
                                if let Some(worker) = registry.get_by_url(&worker_url) {
                                    worker.decrement_load();
                                    RouterMetrics::set_running_requests(&worker_url, worker.load());
                                    decremented = true;
557
558
                                }
                            }
559
560
561
562
563
564
565
566
567
568
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
569
                if !decremented {
570
571
572
                    if let Some(worker) = registry.get_by_url(&worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
573
574
                    }
                }
575
576
577
578
579
580
581
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
582
            *response.headers_mut() = response_headers;
583
            response
584
585
        } else {
            // For requests without load tracking, just stream
586
587
588
589
590
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream
            tokio::spawn(async move {
                let mut stream = stream;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
617
            *response.headers_mut() = response_headers;
618
            response
619
620
621
        }
    }

622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
    /// Convert a raw worker rerank response into the public API shape:
    /// parse the body, sort results by score, apply `top_k`, and optionally
    /// strip the echoed documents.
    async fn build_rerank_response(
        req: &RerankRequest,
        response: Response,
    ) -> anyhow::Result<Response> {
        let (_parts, body) = response.into_parts();
        let bytes = to_bytes(body, usize::MAX).await?;
        let results: Vec<RerankResult> = serde_json::from_slice(&bytes)?;

        let mut reply = RerankResponse::new(results, req.model.clone(), req.rid.clone());
        reply.sort_by_score();
        if let Some(k) = req.top_k {
            reply.apply_top_k(k);
        }
        if !req.return_documents {
            reply.drop_documents();
        }

        Ok(Json(reply).into_response())
    }
640
641
642
643
}

use async_trait::async_trait;

644
#[async_trait]
impl RouterTrait for Router {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    // Health/info endpoints are proxied to the first healthy worker.
    async fn health_generate(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "health_generate").await
    }

    async fn get_server_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_server_info").await
    }

    async fn get_models(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "v1/models").await
    }

    async fn get_model_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_model_info").await
    }

    // Generation endpoints delegate to the retrying, policy-driven router.
    async fn route_generate(
        &self,
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/generate", model_id)
            .await
    }

    async fn route_chat(
        &self,
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
            .await
    }

    async fn route_completion(
        &self,
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/completions", model_id)
            .await
    }

    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/responses", model_id)
            .await
    }

    // Response retrieval/cancellation fans out to all workers (see
    // route_simple_request) because responses live in worker memory.
    async fn get_response(
        &self,
        headers: Option<&HeaderMap>,
        response_id: &str,
        _params: &ResponsesGetParams,
    ) -> Response {
        let endpoint = format!("v1/responses/{}", response_id);
        self.route_get_request(headers, &endpoint).await
    }

    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}/cancel", response_id);
        self.route_post_empty_request(headers, &endpoint).await
    }

    async fn route_embeddings(
        &self,
        headers: Option<&HeaderMap>,
        body: &EmbeddingRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Record embeddings-specific metrics in addition to general request metrics
        let start = Instant::now();
        let res = self
            .route_typed_request(headers, body, "/v1/embeddings", model_id)
            .await;

        // Embedding specific metrics
        if res.status().is_success() {
            RouterMetrics::record_embeddings_request();
            RouterMetrics::record_embeddings_duration(start.elapsed());
        } else {
            let error_type = format!("http_{}", res.status().as_u16());
            RouterMetrics::record_embeddings_error(&error_type);
        }

        res
    }

    async fn route_rerank(
        &self,
        headers: Option<&HeaderMap>,
        body: &RerankRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Reject malformed rerank requests before touching any worker.
        if let Err(e) = body.validate() {
            return (StatusCode::BAD_REQUEST, e).into_response();
        }
        let response = self
            .route_typed_request(headers, body, "/v1/rerank", model_id)
            .await;
        // Successful worker responses are post-processed (sorted, top-k,
        // optional document stripping); failures pass through unchanged.
        if response.status().is_success() {
            match Self::build_rerank_response(body, response).await {
                Ok(rerank_response) => rerank_response,
                Err(e) => {
                    error!("Failed to build rerank response: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Failed to build rerank response".to_string(),
                    )
                        .into_response();
                }
            }
        } else {
            response
        }
    }

    fn router_type(&self) -> &'static str {
        "regular"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorkerBuilder;

    // Build a Router backed by two healthy regular workers and a
    // round-robin default policy.
    fn create_test_regular_router() -> Router {
        // Create registries
        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(
            crate::config::types::PolicyConfig::RoundRobin,
        ));

        // Register test workers
        let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
            .worker_type(WorkerType::Regular)
            .build();
        let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
            .worker_type(WorkerType::Regular)
            .build();
        worker_registry.register(Arc::new(worker1));
        worker_registry.register(Arc::new(worker2));

        Router {
            worker_registry,
            policy_registry,
            dp_aware: false,
            client: Client::new(),
            retry_config: RetryConfig::default(),
            enable_igw: false,
        }
    }

    // Same router, but with the first registered worker marked unhealthy.
    fn create_test_unhealthy_router() -> Router {
        let router = create_test_regular_router();
        let workers = router.worker_registry.get_all();
        workers[0].set_healthy(false);
        router
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let workers = router.worker_registry.get_all();
        let urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();
        // DashMap doesn't guarantee order, so just check we get one of the workers
        assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
    }

    #[test]
    fn test_select_first_worker_with_unhealthy_worker() {
        let router = create_test_unhealthy_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();

        // Whichever worker was selected, it must be a healthy one.
        let worker = router.worker_registry.get_by_url(&url).unwrap();
        assert!(worker.is_healthy());
    }
}