//! Regular HTTP router: forwards requests to workers using injected
//! load-balancing policies, with retries and optional DP-aware routing.
1
use crate::config::types::RetryConfig;
2
use crate::core::{
3
    is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
4
};
5
use crate::metrics::RouterMetrics;
6
use crate::policies::PolicyRegistry;
7
8
9
10
11
12
13
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::common::GenerationRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::{RerankRequest, RerankResponse, RerankResult};
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
14
use crate::routers::header_utils;
15
use crate::routers::RouterTrait;
16
use axum::body::to_bytes;
17
18
19
use axum::{
    body::Body,
    extract::Request,
20
21
22
    http::{
        header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode,
    },
23
24
25
26
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
27
use reqwest::Client;
28
use std::sync::Arc;
29
use std::time::Instant;
30
use tokio_stream::wrappers::UnboundedReceiverStream;
31
use tracing::{debug, error};
32
33
34
35

/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
    /// Shared registry of all workers known to the application.
    worker_registry: Arc<WorkerRegistry>,
    /// Per-model load-balancing policies plus a default policy.
    policy_registry: Arc<PolicyRegistry>,
    /// HTTP client used to proxy requests to workers.
    client: Client,
    /// When true, worker URLs carry an `@<dp_rank>` suffix that is stripped
    /// before proxying and injected into the JSON request body.
    dp_aware: bool,
    /// When true, worker lookup is filtered by the requested model id.
    enable_igw: bool,
    /// Retry/backoff settings applied to typed generation requests.
    retry_config: RetryConfig,
}

impl Router {
45
    /// Create a new router with injected policy and client
46
47
48
49
50
51
52
    pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
        let workers = ctx.worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // include all workers
        );
53

54
        RouterMetrics::set_active_workers(workers.len());
55
56

        Ok(Router {
57
58
            worker_registry: ctx.worker_registry.clone(),
            policy_registry: ctx.policy_registry.clone(),
59
60
            client: ctx.client.clone(),
            dp_aware: ctx.router_config.dp_aware,
61
            enable_igw: ctx.router_config.enable_igw,
62
            retry_config: ctx.router_config.effective_retry_config(),
63
64
65
66
        })
    }

    fn select_first_worker(&self) -> Result<String, String> {
67
        let workers = self.worker_registry.get_all();
68
69
        let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();
        if healthy_workers.is_empty() {
70
71
            Err("No workers are available".to_string())
        } else {
72
            Ok(healthy_workers[0].url().to_string())
73
74
75
        }
    }

76
    // Helper method to proxy GET requests to the first available worker.
    //
    // Copies the incoming request headers (minus content-type/content-length,
    // which don't apply to a bodiless GET), forwards the call, and mirrors
    // the backend's status, headers, and body back to the caller.
    async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
        let headers = header_utils::copy_request_headers(&req);

        match self.select_first_worker() {
            Ok(worker_url) => {
                let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
                for (name, value) in headers {
                    let name_lc = name.to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }

                match request_builder.send().await {
                    Ok(res) => {
                        // Rebuild the status code from its numeric value;
                        // fall back to 500 if it is out of range.
                        let status = StatusCode::from_u16(res.status().as_u16())
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

                        // Preserve headers from backend
                        let response_headers =
                            header_utils::preserve_response_headers(res.headers());

                        match res.bytes().await {
                            Ok(body) => {
                                let mut response = Response::new(Body::from(body));
                                *response.status_mut() = status;
                                *response.headers_mut() = response_headers;
                                response
                            }
                            Err(e) => (
                                StatusCode::INTERNAL_SERVER_ERROR,
                                format!("Failed to read response: {}", e),
                            )
                                .into_response(),
                        }
                    }
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
                }
            }
            // No healthy worker to forward to.
            Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
        }
    }

124
125
126
127
128
129
    /// Select worker for a specific model considering circuit breaker state.
    ///
    /// Returns `None` when no worker is available or the policy declines to
    /// pick one.
    fn select_worker_for_model(
        &self,
        model_id: Option<&str>,
        text: Option<&str>,
    ) -> Option<Arc<dyn Worker>> {
        // Without IGW all models share one worker pool, so drop the model filter.
        let lookup_model = if self.enable_igw { model_id } else { None };

        // O(1) registry lookup of regular HTTP workers for that model.
        let candidates = self.worker_registry.get_workers_filtered(
            lookup_model,
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // fetch everything; availability is filtered just below
        );

        let available: Vec<Arc<dyn Worker>> = candidates
            .into_iter()
            .filter(|w| w.is_available())
            .collect();
        if available.is_empty() {
            return None;
        }

        // Model-specific policy when one is registered, otherwise the default.
        let policy = model_id.map_or_else(
            || self.policy_registry.get_default_policy(),
            |m| self.policy_registry.get_policy_or_default(m),
        );

        let chosen = policy.select_worker(&available, text)?;
        Some(available[chosen].clone())
    }

159
    /// Route a typed generation-style request (generate/chat/completions/…)
    /// with retries.
    ///
    /// Each attempt selects a worker via the model's policy; cache-aware
    /// policies additionally track per-worker load around the request.
    /// Records per-route request, latency, retry, and error metrics.
    pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        model_id: Option<&str>,
    ) -> Response {
        let start = Instant::now();
        let is_stream = typed_req.is_stream();
        // Routing text feeds cache-aware / content-based policies.
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
                let worker = match self.select_worker_for_model(model_id, Some(&text)) {
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };

                // Optional load tracking for cache-aware policy
                // Get the policy for this model to check if it's cache-aware
                let policy = match model_id {
                    Some(model) => self.policy_registry.get_policy_or_default(model),
                    None => self.policy_registry.get_default_policy(),
                };

                let load_incremented = if policy.name() == "cache_aware" {
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
                } else {
                    false
                };

                // Keep a clone for potential cleanup on retry
                let worker_for_cleanup = if load_incremented {
                    Some(worker.clone())
                } else {
                    None
                };

                let response = self
                    .send_typed_request(
                        headers,
                        typed_req,
                        route,
                        worker.url(),
                        is_stream,
                        load_incremented,
                    )
                    .await;

                // Feed the worker's circuit breaker / health accounting.
                worker.record_outcome(response.status().is_success());

                // For retryable failures, we need to decrement load since send_typed_request
                // won't have done it (it only decrements on success or non-retryable failures)
                if is_retryable_status(response.status()) && load_incremented {
                    if let Some(cleanup_worker) = worker_for_cleanup {
                        cleanup_worker.decrement_load();
                        RouterMetrics::set_running_requests(
                            cleanup_worker.url(),
                            cleanup_worker.load(),
                        );
                    }
                }

                response
            },
            // should_retry predicate
            |res, _attempt| is_retryable_status(res.status()),
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
        )
        .await;

        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
        } else if !is_retryable_status(response.status()) {
            // Retryable failures were already counted by the retry hooks.
            RouterMetrics::record_request_error(route, "non_retryable_error");
        }

        response
    }

258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
    // Helper: return base worker URL (strips DP suffix when enabled)
    fn worker_base_url(&self, worker_url: &str) -> String {
        if !self.dp_aware {
            return worker_url.to_string();
        }
        // DP-aware URLs look like "<base>@<rank>"; fall back to the raw URL
        // when the suffix can't be parsed.
        match Self::extract_dp_rank(worker_url) {
            Ok((base, _rank)) => base.to_string(),
            Err(_) => worker_url.to_string(),
        }
    }

    // Generic simple routing for GET/POST without JSON body.
    //
    // Fans out to every registered worker in turn and returns the first
    // successful response; otherwise returns the last response (or error)
    // seen. Only GET and POST are supported.
    async fn route_simple_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
        method: Method,
    ) -> Response {
        // TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
        // Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
            return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
        }

        // Last non-success response, kept so we can still answer the caller
        // when every worker fails.
        let mut last_response: Option<Response> = None;
        for worker in workers {
            let worker_url = worker.url();
            // Strip the DP-rank suffix (when dp_aware) before building the URL.
            let base = self.worker_base_url(worker_url);

            let url = format!("{}/{}", base, endpoint);
            let mut request_builder = match method {
                Method::GET => self.client.get(url),
                Method::POST => self.client.post(url),
                _ => {
                    return (
                        StatusCode::METHOD_NOT_ALLOWED,
                        "Unsupported method for simple routing",
                    )
                        .into_response()
                }
            };

            // Per-worker API key, if configured.
            if let Some(api_key) = worker.api_key() {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", api_key));
            }

            // Forward caller headers except the body-descriptive ones.
            if let Some(hdrs) = headers {
                for (name, value) in hdrs {
                    let name_lc = name.as_str().to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }
            }

            match request_builder.send().await {
                Ok(res) => {
                    let status = StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                    let response_headers = header_utils::preserve_response_headers(res.headers());
                    match res.bytes().await {
                        Ok(body) => {
                            let mut response = Response::new(Body::from(body));
                            *response.status_mut() = status;
                            *response.headers_mut() = response_headers;
                            if status.is_success() {
                                return response;
                            }
                            last_response = Some(response);
                        }
                        Err(e) => {
                            last_response = Some(
                                (
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    format!("Failed to read response: {}", e),
                                )
                                    .into_response(),
                            );
                        }
                    }
                }
                Err(e) => {
                    last_response = Some(
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Request failed: {}", e),
                        )
                            .into_response(),
                    );
                }
            }
        }

        last_response
            .unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response())
    }

    // Route a GET request with provided headers to a specific endpoint
    // (fans out to all workers via `route_simple_request`).
    async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
        self.route_simple_request(headers, endpoint, Method::GET)
            .await
    }

    // Route a POST request with empty body to a specific endpoint
    // (fans out to all workers via `route_simple_request`).
    async fn route_post_empty_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
    ) -> Response {
        self.route_simple_request(headers, endpoint, Method::POST)
            .await
    }

372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
    // TODO (rui): Better accommodate to the Worker abstraction
    /// Split a DP-aware worker URL of the form `<url>@<dp_rank>` into its
    /// base URL and parsed rank.
    ///
    /// # Errors
    /// Returns `Err` when the URL does not contain exactly one `@` separator,
    /// or when the rank suffix does not parse as `usize`.
    fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
        let mut pieces = worker_url.split('@');
        match (pieces.next(), pieces.next(), pieces.next()) {
            // Exactly two pieces: "<prefix>@<rank>"
            (Some(prefix), Some(rank_text), None) => rank_text
                .parse::<usize>()
                .map(|dp_rank| (prefix, dp_rank))
                .map_err(|_| format!("failed to parse dp_rank from worker_url: {}", worker_url)),
            // Zero or more than one '@' — not a DP-aware URL.
            _ => Err(format!("invalid worker_url format: {}", worker_url)),
        }
    }

389
390
391
    // Send typed request directly without conversion.
    //
    // Posts `typed_req` as JSON to `worker_url` + `route`. When `dp_aware`,
    // the worker URL's `@<dp_rank>` suffix is stripped and the rank is
    // injected into the body as `data_parallel_rank`. When `load_incremented`
    // is set, this method is responsible for decrementing the worker's load
    // on completion (non-streaming) or at end-of-stream / stream teardown
    // (streaming) — except for retryable failures, which the caller cleans up.
    async fn send_typed_request<T: serde::Serialize>(
        &self,
        headers: Option<&HeaderMap>,
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
        load_incremented: bool, // Whether load was incremented for this request
    ) -> Response {
        // Get the worker's API key if available
        let api_key = self
            .worker_registry
            .get_by_url(worker_url)
            .and_then(|w| w.api_key().clone());

        let mut request_builder = if self.dp_aware {
            let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
                }
            };

            // Re-serialize so we can splice the rank into the JSON object.
            let mut json_val = match serde_json::to_value(typed_req) {
                Ok(j) => j,
                Err(e) => {
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("Convert into serde_json::Value failed: {}", e),
                    )
                        .into_response();
                }
            };

            if let Some(map) = json_val.as_object_mut() {
                map.insert(
                    String::from("data_parallel_rank"),
                    serde_json::json!(dp_rank),
                );
                debug!(
                    "Modified request body: {}",
                    serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
                );
            } else {
                // Request did not serialize to a JSON object; nowhere to put the rank.
                return (
                    StatusCode::BAD_REQUEST,
                    "Failed to insert the data_parallel_rank field into the request body",
                )
                    .into_response();
            }

            self.client
                .post(format!("{}{}", worker_url_prefix, route))
                .json(&json_val)
        } else {
            self.client
                .post(format!("{}{}", worker_url, route))
                .json(typed_req) // Use json() directly with typed request
        };

        if let Some(key) = api_key {
            request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
        }

        // Copy all headers from original request if provided
        if let Some(headers) = headers {
            for (name, value) in headers {
                // Skip Content-Type and Content-Length as .json() sets them
                if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
                    request_builder = request_builder.header(name, value);
                }
            }
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
                error!(
                    "Failed to send typed request worker_url={} route={} error={}",
                    worker_url, route, e
                );

                // Decrement load on error if it was incremented
                if load_incremented {
                    if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(worker_url, worker.load());
                    }
                }

                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Request failed: {}", e),
                )
                    .into_response();
            }
        };

        // Rebuild the status code from its numeric value; fall back to 500.
        let status = StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // For non-streaming requests, preserve headers
            let response_headers = header_utils::preserve_response_headers(res.headers());

            let response = match res.bytes().await {
                Ok(body) => {
                    let mut response = Response::new(Body::from(body));
                    *response.status_mut() = status;
                    *response.headers_mut() = response_headers;
                    response
                }
                Err(e) => {
                    // IMPORTANT: Decrement load on error before returning
                    if load_incremented {
                        if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(worker_url, worker.load());
                        }
                    }

                    let error_msg = format!("Failed to get response body: {}", e);
                    (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
                }
            };

            // Decrement load counter for non-streaming requests if it was incremented
            if load_incremented {
                if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                    worker.decrement_load();
                    RouterMetrics::set_running_requests(worker_url, worker.load());
                }
            }

            response
        } else if load_incremented {
            // For streaming with load tracking, we need to manually decrement when done
            let registry = Arc::clone(&self.worker_registry);
            let worker_url = worker_url.to_string();

            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream and detect completion
            tokio::spawn(async move {
                let mut stream = stream;
                // Tracks whether the DONE marker already triggered the decrement,
                // so teardown below doesn't decrement twice.
                let mut decremented = false;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            // Check for stream end marker
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
                                if let Some(worker) = registry.get_by_url(&worker_url) {
                                    worker.decrement_load();
                                    RouterMetrics::set_running_requests(&worker_url, worker.load());
                                    decremented = true;
                                }
                            }
                            if tx.send(Ok(bytes)).is_err() {
                                // Receiver dropped (client went away) — stop forwarding.
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
                // Stream ended without a DONE marker (error/disconnect) — still release the slot.
                if !decremented {
                    if let Some(worker) = registry.get_by_url(&worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            *response.headers_mut() = response_headers;
            response
        } else {
            // For requests without load tracking, just stream
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream
            tokio::spawn(async move {
                let mut stream = stream;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
            *response.headers_mut() = response_headers;
            response
        }
    }

625
626
627
628
629
630
631
632
633
    /// Convert a successful worker reply (a JSON array of `RerankResult`)
    /// into a `RerankResponse`, applying the request's `top_k` and
    /// `return_documents` options.
    ///
    /// # Errors
    /// Fails when the body cannot be read or does not deserialize into
    /// `Vec<RerankResult>`.
    async fn build_rerank_response(
        req: &RerankRequest,
        response: Response,
    ) -> anyhow::Result<Response> {
        let (_, response_body) = response.into_parts();
        let body_bytes = to_bytes(response_body, usize::MAX).await?;
        let rerank_results = serde_json::from_slice::<Vec<RerankResult>>(&body_bytes)?;
        let mut rerank_response =
            RerankResponse::new(rerank_results, req.model.clone(), req.rid.clone());
        // Sorting is handled by Python worker (serving_rerank.py)
        if let Some(top_k) = req.top_k {
            rerank_response.apply_top_k(top_k);
        }
        if !req.return_documents {
            rerank_response.drop_documents();
        }
        Ok(Json(rerank_response).into_response())
    }
643
644
645
646
}

use async_trait::async_trait;

#[async_trait]
impl RouterTrait for Router {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    // Health and metadata endpoints are proxied to the first healthy worker.
    async fn health_generate(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "health_generate").await
    }

    async fn get_server_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_server_info").await
    }

    async fn get_models(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "v1/models").await
    }

    async fn get_model_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_model_info").await
    }

    // Generation-style endpoints go through the retrying typed-request path.
    async fn route_generate(
        &self,
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/generate", model_id)
            .await
    }

    async fn route_chat(
        &self,
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
            .await
    }

    async fn route_completion(
        &self,
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/completions", model_id)
            .await
    }

    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_typed_request(headers, body, "/v1/responses", model_id)
            .await
    }

    // NOTE(review): the query params are currently ignored when fetching a
    // stored response — confirm whether they should be forwarded.
    async fn get_response(
        &self,
        headers: Option<&HeaderMap>,
        response_id: &str,
        _params: &ResponsesGetParams,
    ) -> Response {
        let endpoint = format!("v1/responses/{}", response_id);
        self.route_get_request(headers, &endpoint).await
    }

    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}/cancel", response_id);
        self.route_post_empty_request(headers, &endpoint).await
    }

    async fn route_embeddings(
        &self,
        headers: Option<&HeaderMap>,
        body: &EmbeddingRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Record embeddings-specific metrics in addition to general request metrics
        let start = Instant::now();
        let res = self
            .route_typed_request(headers, body, "/v1/embeddings", model_id)
            .await;

        // Embedding specific metrics
        if res.status().is_success() {
            RouterMetrics::record_embeddings_request();
            RouterMetrics::record_embeddings_duration(start.elapsed());
        } else {
            let error_type = format!("http_{}", res.status().as_u16());
            RouterMetrics::record_embeddings_error(&error_type);
        }

        res
    }

    async fn route_rerank(
        &self,
        headers: Option<&HeaderMap>,
        body: &RerankRequest,
        model_id: Option<&str>,
    ) -> Response {
        let response = self
            .route_typed_request(headers, body, "/v1/rerank", model_id)
            .await;

        // On success, reshape the worker's raw result list into a RerankResponse.
        if response.status().is_success() {
            match Self::build_rerank_response(body, response).await {
                Ok(rerank_response) => rerank_response,
                Err(e) => {
                    error!("Failed to build rerank response: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Failed to build rerank response".to_string(),
                    )
                        .into_response();
                }
            }
        } else {
            response
        }
    }

    fn router_type(&self) -> &'static str {
        "regular"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorkerBuilder;

    // Build a router over two regular test workers with round-robin policy.
    fn create_test_regular_router() -> Router {
        // Create registries
        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(
            crate::config::types::PolicyConfig::RoundRobin,
        ));

        // Register test workers
        let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
            .worker_type(WorkerType::Regular)
            .build();
        let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
            .worker_type(WorkerType::Regular)
            .build();
        worker_registry.register(Arc::new(worker1));
        worker_registry.register(Arc::new(worker2));

        Router {
            worker_registry,
            policy_registry,
            dp_aware: false,
            client: Client::new(),
            retry_config: RetryConfig::default(),
            enable_igw: false,
        }
    }

    // Same as above, but with one of the two workers marked unhealthy.
    fn create_test_unhealthy_router() -> Router {
        let router = create_test_regular_router();
        let workers = router.worker_registry.get_all();
        workers[0].set_healthy(false);
        router
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let workers = router.worker_registry.get_all();
        let urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();
        // DashMap doesn't guarantee order, so just check we get one of the workers
        assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
    }

    // The unhealthy worker must never be selected.
    #[test]
    fn test_select_first_worker_with_unhealthy_worker() {
        let router = create_test_unhealthy_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();

        let worker = router.worker_registry.get_by_url(&url).unwrap();
        assert!(worker.is_healthy());
    }
}