//! Regular router: forwards requests to regular HTTP workers selected by injected
//! load-balancing policies.
use std::{collections::HashSet, sync::Arc, time::Instant};

use axum::{
    body::{to_bytes, Body},
    extract::Request,
    http::{
        header::{CONTENT_LENGTH, CONTENT_TYPE},
        HeaderMap, HeaderValue, Method, StatusCode,
    },
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
use reqwest::Client;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error};

use crate::{
    config::types::RetryConfig,
    core::{
        is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
    },
    metrics::RouterMetrics,
    policies::PolicyRegistry,
    protocols::{
        chat::ChatCompletionRequest,
        common::GenerationRequest,
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::{RerankRequest, RerankResponse, RerankResult},
        responses::{ResponsesGetParams, ResponsesRequest},
    },
    routers::{header_utils, RouterTrait},
};

37
38
39
/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
40
41
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
42
    client: Client,
43
    dp_aware: bool,
44
    enable_igw: bool,
45
    retry_config: RetryConfig,
46
47
48
}

impl Router {
49
    /// Create a new router with injected policy and client
50
51
52
53
54
55
56
    pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
        let workers = ctx.worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // include all workers
        );
57

58
        RouterMetrics::set_active_workers(workers.len());
59
60

        Ok(Router {
61
62
            worker_registry: ctx.worker_registry.clone(),
            policy_registry: ctx.policy_registry.clone(),
63
64
            client: ctx.client.clone(),
            dp_aware: ctx.router_config.dp_aware,
65
            enable_igw: ctx.router_config.enable_igw,
66
            retry_config: ctx.router_config.effective_retry_config(),
67
68
69
70
        })
    }

    fn select_first_worker(&self) -> Result<String, String> {
71
        let workers = self.worker_registry.get_all();
72
73
        let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();
        if healthy_workers.is_empty() {
74
75
            Err("No workers are available".to_string())
        } else {
76
            Ok(healthy_workers[0].url().to_string())
77
78
79
        }
    }

80
    // Helper method to proxy GET requests to the first available worker
81
    async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
82
        let headers = header_utils::copy_request_headers(&req);
83
84
85

        match self.select_first_worker() {
            Ok(worker_url) => {
86
                let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
87
                for (name, value) in headers {
88
89
                    let name_lc = name.to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
90
91
92
                        request_builder = request_builder.header(name, value);
                    }
                }
93

94
95
96
97
                match request_builder.send().await {
                    Ok(res) => {
                        let status = StatusCode::from_u16(res.status().as_u16())
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
98
99
100
101
102

                        // Preserve headers from backend
                        let response_headers =
                            header_utils::preserve_response_headers(res.headers());

103
                        match res.bytes().await {
104
                            Ok(body) => {
105
                                let mut response = Response::new(Body::from(body));
106
107
108
109
                                *response.status_mut() = status;
                                *response.headers_mut() = response_headers;
                                response
                            }
110
111
112
113
114
                            Err(e) => (
                                StatusCode::INTERNAL_SERVER_ERROR,
                                format!("Failed to read response: {}", e),
                            )
                                .into_response(),
115
116
                        }
                    }
117
118
119
120
121
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
122
123
                }
            }
124
            Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
125
126
127
        }
    }

128
129
130
131
132
133
    /// Select worker for a specific model considering circuit breaker state
    fn select_worker_for_model(
        &self,
        model_id: Option<&str>,
        text: Option<&str>,
    ) -> Option<Arc<dyn Worker>> {
134
135
        let effective_model_id = if !self.enable_igw { None } else { model_id };

136
137
        // Get workers for the specified model O(1), filtered by connection mode
        let workers = self.worker_registry.get_workers_filtered(
138
            effective_model_id,
139
140
141
142
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // get all workers, we'll filter by is_available() next
        );
143
144

        let available: Vec<Arc<dyn Worker>> = workers
145
146
            .iter()
            .filter(|w| w.is_available())
147
            .cloned()
148
149
150
151
            .collect();
        if available.is_empty() {
            return None;
        }
152
153
154
155
156
157
158
159
160

        // Get the appropriate policy for this model
        let policy = match model_id {
            Some(model) => self.policy_registry.get_policy_or_default(model),
            None => self.policy_registry.get_default_policy(),
        };

        let idx = policy.select_worker(&available, text)?;
        Some(available[idx].clone())
161
162
    }

163
    pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
164
        &self,
165
        headers: Option<&HeaderMap>,
166
167
        typed_req: &T,
        route: &str,
168
        model_id: Option<&str>,
169
    ) -> Response {
170
        let start = Instant::now();
171
172
173
174
175
176
177
        let is_stream = typed_req.is_stream();
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
178
                let worker = match self.select_worker_for_model(model_id, Some(&text)) {
179
180
181
182
183
184
185
186
187
188
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };
189

190
                // Optional load tracking for cache-aware policy
191
192
193
194
195
196
197
                // Get the policy for this model to check if it's cache-aware
                let policy = match model_id {
                    Some(model) => self.policy_registry.get_policy_or_default(model),
                    None => self.policy_registry.get_default_policy(),
                };

                let load_incremented = if policy.name() == "cache_aware" {
198
199
200
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
201
202
203
204
                } else {
                    false
                };

205
206
                // Keep a clone for potential cleanup on retry
                let worker_for_cleanup = if load_incremented {
207
                    Some(worker.clone())
208
209
210
211
                } else {
                    None
                };

212
213
                let response = self
                    .send_typed_request(
214
                        headers,
215
216
                        typed_req,
                        route,
217
                        worker.url(),
218
219
220
221
222
                        is_stream,
                        load_incremented,
                    )
                    .await;

223
                worker.record_outcome(response.status().is_success());
224
225
226
227
228
229
230
231
232
233
234
235
236

                // For retryable failures, we need to decrement load since send_typed_request
                // won't have done it (it only decrements on success or non-retryable failures)
                if is_retryable_status(response.status()) && load_incremented {
                    if let Some(cleanup_worker) = worker_for_cleanup {
                        cleanup_worker.decrement_load();
                        RouterMetrics::set_running_requests(
                            cleanup_worker.url(),
                            cleanup_worker.load(),
                        );
                    }
                }

237
238
239
                response
            },
            // should_retry predicate
240
            |res, _attempt| is_retryable_status(res.status()),
241
242
243
244
245
246
247
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
248
        )
249
250
251
252
253
254
        .await;

        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
255
        } else if !is_retryable_status(response.status()) {
256
            RouterMetrics::record_request_error(route, "non_retryable_error");
257
        }
258
259

        response
260
261
    }

262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
    // Helper: return base worker URL (strips DP suffix when enabled)
    fn worker_base_url(&self, worker_url: &str) -> String {
        if self.dp_aware {
            if let Ok((prefix, _)) = Self::extract_dp_rank(worker_url) {
                return prefix.to_string();
            }
        }
        worker_url.to_string()
    }

    // Generic simple routing for GET/POST without JSON body
    async fn route_simple_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
        method: Method,
    ) -> Response {
        // TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
        // Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
281
282
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
283
284
285
286
            return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
        }

        let mut last_response: Option<Response> = None;
287
288
289
        for worker in workers {
            let worker_url = worker.url();
            let base = self.worker_base_url(worker_url);
290
291
292
293
294
295
296
297
298
299
300
301
302
303

            let url = format!("{}/{}", base, endpoint);
            let mut request_builder = match method {
                Method::GET => self.client.get(url),
                Method::POST => self.client.post(url),
                _ => {
                    return (
                        StatusCode::METHOD_NOT_ALLOWED,
                        "Unsupported method for simple routing",
                    )
                        .into_response()
                }
            };

304
305
306
307
308
            if let Some(api_key) = worker.api_key() {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", api_key));
            }

309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
            if let Some(hdrs) = headers {
                for (name, value) in hdrs {
                    let name_lc = name.as_str().to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }
            }

            match request_builder.send().await {
                Ok(res) => {
                    let status = StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                    let response_headers = header_utils::preserve_response_headers(res.headers());
                    match res.bytes().await {
                        Ok(body) => {
325
                            let mut response = Response::new(Body::from(body));
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
                            *response.status_mut() = status;
                            *response.headers_mut() = response_headers;
                            if status.is_success() {
                                return response;
                            }
                            last_response = Some(response);
                        }
                        Err(e) => {
                            last_response = Some(
                                (
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    format!("Failed to read response: {}", e),
                                )
                                    .into_response(),
                            );
                        }
                    }
                }
                Err(e) => {
                    last_response = Some(
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Request failed: {}", e),
                        )
                            .into_response(),
                    );
                }
            }
        }

        last_response
            .unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response())
    }

    // Route a GET request with provided headers to a specific endpoint
    async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
        self.route_simple_request(headers, endpoint, Method::GET)
            .await
    }

    // Route a POST request with empty body to a specific endpoint
    async fn route_post_empty_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
    ) -> Response {
        self.route_simple_request(headers, endpoint, Method::POST)
            .await
    }

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
    // TODO (rui): Better accommodate to the Worker abstraction
    fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
        let parts: Vec<&str> = worker_url.split('@').collect();
        if parts.len() != 2 {
            return Err(format!("invalid worker_url format: {}", worker_url));
        }

        // Parse the second part (dp_rank) into an integer
        match parts[1].parse::<usize>() {
            Ok(dp_rank) => Ok((parts[0], dp_rank)),
            Err(_) => Err(format!(
                "failed to parse dp_rank from worker_url: {}",
                worker_url
            )),
        }
    }

393
394
395
    // Send typed request directly without conversion
    async fn send_typed_request<T: serde::Serialize>(
        &self,
396
        headers: Option<&HeaderMap>,
397
398
399
400
401
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
        load_incremented: bool, // Whether load was incremented for this request
402
    ) -> Response {
403
404
405
406
407
408
        // Get the worker's API key if available
        let api_key = self
            .worker_registry
            .get_by_url(worker_url)
            .and_then(|w| w.api_key().clone());

409
410
411
412
413
        let mut request_builder = if self.dp_aware {
            let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
414
415
416
417
418
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
419
420
421
422
423
424
                }
            };

            let mut json_val = match serde_json::to_value(typed_req) {
                Ok(j) => j,
                Err(e) => {
425
426
427
428
429
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("Convert into serde_json::Value failed: {}", e),
                    )
                        .into_response();
430
431
432
433
434
435
436
437
438
439
440
441
442
                }
            };

            if let Some(map) = json_val.as_object_mut() {
                map.insert(
                    String::from("data_parallel_rank"),
                    serde_json::json!(dp_rank),
                );
                debug!(
                    "Modified request body: {}",
                    serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
                );
            } else {
443
444
445
446
447
                return (
                    StatusCode::BAD_REQUEST,
                    "Failed to insert the data_parallel_rank field into the request body",
                )
                    .into_response();
448
449
            }

450
            self.client
451
452
453
                .post(format!("{}{}", worker_url_prefix, route))
                .json(&json_val)
        } else {
454
            self.client
455
456
457
                .post(format!("{}{}", worker_url, route))
                .json(typed_req) // Use json() directly with typed request
        };
458

459
460
461
462
        if let Some(key) = api_key {
            request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
        }

463
464
465
466
        // Copy all headers from original request if provided
        if let Some(headers) = headers {
            for (name, value) in headers {
                // Skip Content-Type and Content-Length as .json() sets them
467
                if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
468
469
                    request_builder = request_builder.header(name, value);
                }
470
471
472
473
474
475
            }
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
476
477
478
479
                error!(
                    "Failed to send typed request worker_url={} route={} error={}",
                    worker_url, route, e
                );
480
481
482

                // Decrement load on error if it was incremented
                if load_incremented {
483
484
485
                    if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(worker_url, worker.load());
486
487
488
                    }
                }

489
490
491
492
493
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Request failed: {}", e),
                )
                    .into_response();
494
495
496
            }
        };

497
498
        let status = StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
499
500

        if !is_stream {
501
            // For non-streaming requests, preserve headers
502
            let response_headers = header_utils::preserve_response_headers(res.headers());
503

504
            let response = match res.bytes().await {
505
                Ok(body) => {
506
                    let mut response = Response::new(Body::from(body));
507
508
509
510
                    *response.status_mut() = status;
                    *response.headers_mut() = response_headers;
                    response
                }
511
                Err(e) => {
512
513
                    // IMPORTANT: Decrement load on error before returning
                    if load_incremented {
514
515
516
                        if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(worker_url, worker.load());
517
518
519
                        }
                    }

520
                    let error_msg = format!("Failed to get response body: {}", e);
521
                    (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
522
523
524
525
                }
            };

            // Decrement load counter for non-streaming requests if it was incremented
526
            if load_incremented {
527
528
529
                if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                    worker.decrement_load();
                    RouterMetrics::set_running_requests(worker_url, worker.load());
530
531
532
533
534
535
                }
            }

            response
        } else if load_incremented {
            // For streaming with load tracking, we need to manually decrement when done
536
            let registry = Arc::clone(&self.worker_registry);
537
538
            let worker_url = worker_url.to_string();

539
540
541
542
543
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

544
545
546
547
548
549
            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream and detect completion
            tokio::spawn(async move {
                let mut stream = stream;
550
                let mut decremented = false;
551
552
553
554
555
556
557
558
559
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            // Check for stream end marker
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
560
561
562
563
                                if let Some(worker) = registry.get_by_url(&worker_url) {
                                    worker.decrement_load();
                                    RouterMetrics::set_running_requests(&worker_url, worker.load());
                                    decremented = true;
564
565
                                }
                            }
566
567
568
569
570
571
572
573
574
575
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
576
                if !decremented {
577
578
579
                    if let Some(worker) = registry.get_by_url(&worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
580
581
                    }
                }
582
583
584
585
586
587
588
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
589
            *response.headers_mut() = response_headers;
590
            response
591
592
        } else {
            // For requests without load tracking, just stream
593
594
595
596
597
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream
            tokio::spawn(async move {
                let mut stream = stream;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
624
            *response.headers_mut() = response_headers;
625
            response
626
627
628
        }
    }

629
630
631
632
633
634
635
636
637
    async fn build_rerank_response(
        req: &RerankRequest,
        response: Response,
    ) -> anyhow::Result<Response> {
        let (_, response_body) = response.into_parts();
        let body_bytes = to_bytes(response_body, usize::MAX).await?;
        let rerank_results = serde_json::from_slice::<Vec<RerankResult>>(&body_bytes)?;
        let mut rerank_response =
            RerankResponse::new(rerank_results, req.model.clone(), req.rid.clone());
638
        // Sorting is handled by Python worker (serving_rerank.py)
639
640
641
642
643
644
645
646
        if let Some(top_k) = req.top_k {
            rerank_response.apply_top_k(top_k);
        }
        if !req.return_documents {
            rerank_response.drop_documents();
        }
        Ok(Json(rerank_response).into_response())
    }
647
648
649
650
}

use async_trait::async_trait;

651
#[async_trait]
652
653
654
655
656
impl RouterTrait for Router {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

657
658
    async fn health_generate(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "health_generate").await
659
660
    }

661
662
    async fn get_server_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_server_info").await
663
664
    }

665
666
    async fn get_models(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "v1/models").await
667
668
    }

669
670
    async fn get_model_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_model_info").await
671
672
673
674
    }

    async fn route_generate(
        &self,
675
676
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
677
        model_id: Option<&str>,
678
    ) -> Response {
679
680
        self.route_typed_request(headers, body, "/generate", model_id)
            .await
681
682
683
684
    }

    async fn route_chat(
        &self,
685
686
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
687
        model_id: Option<&str>,
688
    ) -> Response {
689
        self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
690
            .await
691
692
693
694
    }

    async fn route_completion(
        &self,
695
696
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
697
        model_id: Option<&str>,
698
    ) -> Response {
699
        self.route_typed_request(headers, body, "/v1/completions", model_id)
700
            .await
701
702
    }

703
704
705
706
    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
707
        model_id: Option<&str>,
708
    ) -> Response {
709
        self.route_typed_request(headers, body, "/v1/responses", model_id)
710
711
712
            .await
    }

713
714
715
716
717
718
    async fn get_response(
        &self,
        headers: Option<&HeaderMap>,
        response_id: &str,
        _params: &ResponsesGetParams,
    ) -> Response {
719
720
721
722
723
724
725
726
727
        let endpoint = format!("v1/responses/{}", response_id);
        self.route_get_request(headers, &endpoint).await
    }

    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}/cancel", response_id);
        self.route_post_empty_request(headers, &endpoint).await
    }

728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
    async fn route_embeddings(
        &self,
        headers: Option<&HeaderMap>,
        body: &EmbeddingRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Record embeddings-specific metrics in addition to general request metrics
        let start = Instant::now();
        let res = self
            .route_typed_request(headers, body, "/v1/embeddings", model_id)
            .await;

        // Embedding specific metrics
        if res.status().is_success() {
            RouterMetrics::record_embeddings_request();
            RouterMetrics::record_embeddings_duration(start.elapsed());
        } else {
            let error_type = format!("http_{}", res.status().as_u16());
            RouterMetrics::record_embeddings_error(&error_type);
        }

        res
750
751
    }

752
753
754
755
756
757
758
759
760
    async fn route_rerank(
        &self,
        headers: Option<&HeaderMap>,
        body: &RerankRequest,
        model_id: Option<&str>,
    ) -> Response {
        let response = self
            .route_typed_request(headers, body, "/v1/rerank", model_id)
            .await;
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
        if response.status().is_success() {
            match Self::build_rerank_response(body, response).await {
                Ok(rerank_response) => rerank_response,
                Err(e) => {
                    error!("Failed to build rerank response: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Failed to build rerank response".to_string(),
                    )
                        .into_response();
                }
            }
        } else {
            response
        }
776
777
    }

778
779
780
781
782
783
784
785
    fn router_type(&self) -> &'static str {
        "regular"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorkerBuilder;

    /// Router over two regular workers (`worker1`, `worker2`) with a round-robin default policy.
    fn create_test_regular_router() -> Router {
        // Create registries
        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(
            crate::config::types::PolicyConfig::RoundRobin,
        ));

        // Register test workers
        let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
            .worker_type(WorkerType::Regular)
            .build();
        let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
            .worker_type(WorkerType::Regular)
            .build();
        worker_registry.register(Arc::new(worker1));
        worker_registry.register(Arc::new(worker2));

        Router {
            worker_registry,
            policy_registry,
            dp_aware: false,
            client: Client::new(),
            retry_config: RetryConfig::default(),
            enable_igw: false,
        }
    }

    /// Same as `create_test_regular_router`, with the first worker from `get_all()` unhealthy.
    fn create_test_unhealthy_router() -> Router {
        let router = create_test_regular_router();
        let workers = router.worker_registry.get_all();
        workers[0].set_healthy(false);
        router
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let workers = router.worker_registry.get_all();
        let urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();

        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();
        // DashMap doesn't guarantee order, so just check we get one of the workers
        assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
    }

    #[test]
    fn test_select_first_worker_with_unhealthy_worker() {
        let router = create_test_unhealthy_router();
        let result = router.select_first_worker();

        assert!(result.is_ok());
        let url = result.unwrap();

        let worker = router.worker_registry.get_by_url(&url).unwrap();
        assert!(worker.is_healthy());
    }

    #[test]
    fn test_select_first_worker_all_unhealthy() {
        let router = create_test_regular_router();
        for worker in router.worker_registry.get_all() {
            worker.set_healthy(false);
        }

        assert_eq!(
            router.select_first_worker(),
            Err("No workers are available".to_string())
        );
    }

    #[test]
    fn test_extract_dp_rank() {
        assert_eq!(
            Router::extract_dp_rank("http://worker1:8080@3"),
            Ok(("http://worker1:8080", 3))
        );
        assert!(Router::extract_dp_rank("http://worker1:8080").is_err());
        assert!(Router::extract_dp_rank("http://worker1:8080@rank").is_err());
    }

    #[test]
    fn test_worker_base_url_strips_dp_suffix_only_when_dp_aware() {
        let router = create_test_regular_router();
        assert_eq!(
            router.worker_base_url("http://worker1:8080@3"),
            "http://worker1:8080@3"
        );

        let dp_router = Router {
            dp_aware: true,
            ..create_test_regular_router()
        };
        assert_eq!(
            dp_router.worker_base_url("http://worker1:8080@3"),
            "http://worker1:8080"
        );
        // Without a parseable rank suffix the URL is returned unchanged.
        assert_eq!(
            dp_router.worker_base_url("http://worker1:8080"),
            "http://worker1:8080"
        );
    }
}