"lm_eval/datasets/bigbench_resources/__init__.py" did not exist on "f88bb82710a29057b378b747e8b9cfd33a0a80e8"
router.rs 31.8 KB
Newer Older
1
2
use std::{sync::Arc, time::Instant};

3
use axum::{
4
    body::{to_bytes, Body},
5
    extract::Request,
6
    http::{
7
8
        header::{CONTENT_LENGTH, CONTENT_TYPE},
        HeaderMap, HeaderValue, Method, StatusCode,
9
    },
10
11
12
13
    response::{IntoResponse, Response},
    Json,
};
use futures_util::StreamExt;
14
use reqwest::Client;
15
use tokio_stream::wrappers::UnboundedReceiverStream;
16
use tracing::{debug, error};
17

18
19
20
21
22
23
24
25
26
use crate::{
    config::types::RetryConfig,
    core::{
        is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
    },
    metrics::RouterMetrics,
    policies::PolicyRegistry,
    protocols::{
        chat::ChatCompletionRequest,
27
        classify::ClassifyRequest,
28
29
30
31
32
33
34
35
36
37
        common::GenerationRequest,
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::{RerankRequest, RerankResponse, RerankResult},
        responses::{ResponsesGetParams, ResponsesRequest},
    },
    routers::{header_utils, RouterTrait},
};

38
39
40
/// Regular router that uses injected load balancing policies
#[derive(Debug)]
pub struct Router {
41
42
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
43
    client: Client,
44
    dp_aware: bool,
45
    enable_igw: bool,
46
    retry_config: RetryConfig,
47
48
49
}

impl Router {
50
    /// Create a new router with injected policy and client
51
52
53
54
55
56
57
    pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
        let workers = ctx.worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // include all workers
        );
58

59
        RouterMetrics::set_active_workers(workers.len());
60
61

        Ok(Router {
62
63
            worker_registry: ctx.worker_registry.clone(),
            policy_registry: ctx.policy_registry.clone(),
64
65
            client: ctx.client.clone(),
            dp_aware: ctx.router_config.dp_aware,
66
            enable_igw: ctx.router_config.enable_igw,
67
            retry_config: ctx.router_config.effective_retry_config(),
68
69
70
71
        })
    }

    fn select_first_worker(&self) -> Result<String, String> {
72
        let workers = self.worker_registry.get_all();
73
74
        let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();
        if healthy_workers.is_empty() {
75
76
            Err("No workers are available".to_string())
        } else {
77
            Ok(healthy_workers[0].url().to_string())
78
79
80
        }
    }

81
    // Helper method to proxy GET requests to the first available worker
82
    async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
83
        let headers = header_utils::copy_request_headers(&req);
84
85
86

        match self.select_first_worker() {
            Ok(worker_url) => {
87
                let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
88
                for (name, value) in headers {
89
90
                    let name_lc = name.to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
91
92
93
                        request_builder = request_builder.header(name, value);
                    }
                }
94

95
96
97
98
                match request_builder.send().await {
                    Ok(res) => {
                        let status = StatusCode::from_u16(res.status().as_u16())
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
99
100
101
102
103

                        // Preserve headers from backend
                        let response_headers =
                            header_utils::preserve_response_headers(res.headers());

104
                        match res.bytes().await {
105
                            Ok(body) => {
106
                                let mut response = Response::new(Body::from(body));
107
108
109
110
                                *response.status_mut() = status;
                                *response.headers_mut() = response_headers;
                                response
                            }
111
112
113
114
115
                            Err(e) => (
                                StatusCode::INTERNAL_SERVER_ERROR,
                                format!("Failed to read response: {}", e),
                            )
                                .into_response(),
116
117
                        }
                    }
118
119
120
121
122
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Request failed: {}", e),
                    )
                        .into_response(),
123
124
                }
            }
125
            Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
126
127
128
        }
    }

129
130
131
132
133
134
    /// Select a worker for a specific model among the currently available ones.
    ///
    /// Candidates are the regular HTTP workers from the registry, narrowed to
    /// `model_id` only when `enable_igw` is set, and then restricted to those whose
    /// `is_available()` returns true. The policy registered for `model_id` (or the
    /// default policy when `model_id` is `None`) picks one of them, optionally
    /// using `text` as routing input.
    ///
    /// Returns `None` when no candidate is available, when the policy declines to
    /// pick a worker, or when the policy returns an index outside the candidate
    /// list.
    fn select_worker_for_model(
        &self,
        model_id: Option<&str>,
        text: Option<&str>,
    ) -> Option<Arc<dyn Worker>> {
        // Without IGW the model id is not used to narrow the worker set.
        let effective_model_id = if self.enable_igw { model_id } else { None };

        // Get workers for the specified model, filtered by connection mode
        let candidates = self.worker_registry.get_workers_filtered(
            effective_model_id,
            Some(WorkerType::Regular),
            Some(ConnectionMode::Http),
            false, // get all workers, we'll filter by is_available() next
        );

        let available: Vec<Arc<dyn Worker>> = candidates
            .iter()
            .filter(|w| w.is_available())
            .cloned()
            .collect();
        if available.is_empty() {
            return None;
        }

        // Get the appropriate policy for this model.
        // NOTE(review): the policy lookup uses `model_id`, not `effective_model_id`,
        // so a per-model policy may apply even when IGW is disabled — confirm intent.
        let policy = match model_id {
            Some(model) => self.policy_registry.get_policy_or_default(model),
            None => self.policy_registry.get_default_policy(),
        };

        let idx = policy.select_worker(&available, text)?;
        // Checked lookup: a misbehaving policy must not panic the request task.
        available.get(idx).cloned()
    }

164
    pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
165
        &self,
166
        headers: Option<&HeaderMap>,
167
168
        typed_req: &T,
        route: &str,
169
        model_id: Option<&str>,
170
    ) -> Response {
171
        let start = Instant::now();
172
173
174
175
176
177
178
        let is_stream = typed_req.is_stream();
        let text = typed_req.extract_text_for_routing();

        let response = RetryExecutor::execute_response_with_retry(
            &self.retry_config,
            // operation per attempt
            |_: u32| async {
179
                let worker = match self.select_worker_for_model(model_id, Some(&text)) {
180
181
182
183
184
185
186
187
188
189
                    Some(w) => w,
                    None => {
                        RouterMetrics::record_request_error(route, "no_available_workers");
                        return (
                            StatusCode::SERVICE_UNAVAILABLE,
                            "No available workers (all circuits open or unhealthy)",
                        )
                            .into_response();
                    }
                };
190

191
                // Optional load tracking for cache-aware policy
192
193
194
195
196
197
198
                // Get the policy for this model to check if it's cache-aware
                let policy = match model_id {
                    Some(model) => self.policy_registry.get_policy_or_default(model),
                    None => self.policy_registry.get_default_policy(),
                };

                let load_incremented = if policy.name() == "cache_aware" {
199
200
201
                    worker.increment_load();
                    RouterMetrics::set_running_requests(worker.url(), worker.load());
                    true
202
203
204
205
                } else {
                    false
                };

206
207
                // Keep a clone for potential cleanup on retry
                let worker_for_cleanup = if load_incremented {
208
                    Some(worker.clone())
209
210
211
212
                } else {
                    None
                };

213
214
                let response = self
                    .send_typed_request(
215
                        headers,
216
217
                        typed_req,
                        route,
218
                        worker.url(),
219
220
221
222
223
                        is_stream,
                        load_incremented,
                    )
                    .await;

224
                worker.record_outcome(response.status().is_success());
225
226
227
228
229
230
231
232
233
234
235
236
237

                // For retryable failures, we need to decrement load since send_typed_request
                // won't have done it (it only decrements on success or non-retryable failures)
                if is_retryable_status(response.status()) && load_incremented {
                    if let Some(cleanup_worker) = worker_for_cleanup {
                        cleanup_worker.decrement_load();
                        RouterMetrics::set_running_requests(
                            cleanup_worker.url(),
                            cleanup_worker.load(),
                        );
                    }
                }

238
239
240
                response
            },
            // should_retry predicate
241
            |res, _attempt| is_retryable_status(res.status()),
242
243
244
245
246
247
248
            // on_backoff hook
            |delay, attempt| {
                RouterMetrics::record_retry(route);
                RouterMetrics::record_retry_backoff_duration(delay, attempt);
            },
            // on_exhausted hook
            || RouterMetrics::record_retries_exhausted(route),
249
        )
250
251
252
253
254
255
        .await;

        if response.status().is_success() {
            let duration = start.elapsed();
            RouterMetrics::record_request(route);
            RouterMetrics::record_generate_duration(duration);
256
        } else if !is_retryable_status(response.status()) {
257
            RouterMetrics::record_request_error(route, "non_retryable_error");
258
        }
259
260

        response
261
262
    }

263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
    /// Helper: base URL of a worker, without the `@<dp_rank>` suffix.
    ///
    /// The suffix is only stripped when the router is DP-aware and the URL parses
    /// via `extract_dp_rank`; otherwise `worker_url` is returned unchanged.
    fn worker_base_url(&self, worker_url: &str) -> String {
        if !self.dp_aware {
            return worker_url.to_string();
        }

        match Self::extract_dp_rank(worker_url) {
            Ok((prefix, _rank)) => prefix.to_string(),
            Err(_) => worker_url.to_string(),
        }
    }

    // Generic simple routing for GET/POST without JSON body
    async fn route_simple_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
        method: Method,
    ) -> Response {
        // TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
        // Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
282
283
        let workers = self.worker_registry.get_all();
        if workers.is_empty() {
284
285
286
287
            return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
        }

        let mut last_response: Option<Response> = None;
288
289
290
        for worker in workers {
            let worker_url = worker.url();
            let base = self.worker_base_url(worker_url);
291
292
293
294
295
296
297
298
299
300
301
302
303
304

            let url = format!("{}/{}", base, endpoint);
            let mut request_builder = match method {
                Method::GET => self.client.get(url),
                Method::POST => self.client.post(url),
                _ => {
                    return (
                        StatusCode::METHOD_NOT_ALLOWED,
                        "Unsupported method for simple routing",
                    )
                        .into_response()
                }
            };

305
306
307
308
309
            if let Some(api_key) = worker.api_key() {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", api_key));
            }

310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
            if let Some(hdrs) = headers {
                for (name, value) in hdrs {
                    let name_lc = name.as_str().to_lowercase();
                    if name_lc != "content-type" && name_lc != "content-length" {
                        request_builder = request_builder.header(name, value);
                    }
                }
            }

            match request_builder.send().await {
                Ok(res) => {
                    let status = StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                    let response_headers = header_utils::preserve_response_headers(res.headers());
                    match res.bytes().await {
                        Ok(body) => {
326
                            let mut response = Response::new(Body::from(body));
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
                            *response.status_mut() = status;
                            *response.headers_mut() = response_headers;
                            if status.is_success() {
                                return response;
                            }
                            last_response = Some(response);
                        }
                        Err(e) => {
                            last_response = Some(
                                (
                                    StatusCode::INTERNAL_SERVER_ERROR,
                                    format!("Failed to read response: {}", e),
                                )
                                    .into_response(),
                            );
                        }
                    }
                }
                Err(e) => {
                    last_response = Some(
                        (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            format!("Request failed: {}", e),
                        )
                            .into_response(),
                    );
                }
            }
        }

        last_response
            .unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response())
    }

    // Route a GET request with provided headers to a specific endpoint
    async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
        self.route_simple_request(headers, endpoint, Method::GET)
            .await
    }

    // Route a POST request with empty body to a specific endpoint
    async fn route_post_empty_request(
        &self,
        headers: Option<&HeaderMap>,
        endpoint: &str,
    ) -> Response {
        self.route_simple_request(headers, endpoint, Method::POST)
            .await
    }

377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
    // TODO (rui): Better accommodate to the Worker abstraction
    /// Split a worker URL of the form `<prefix>@<dp_rank>` into its two parts.
    ///
    /// # Errors
    ///
    /// Fails when the URL does not contain exactly one `@`, or when the text after
    /// it is not a valid `usize`.
    fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
        let mut pieces = worker_url.split('@');

        // Exactly two pieces are required: a third one means more than one `@`.
        let (prefix, rank_text) = match (pieces.next(), pieces.next(), pieces.next()) {
            (Some(prefix), Some(rank_text), None) => (prefix, rank_text),
            _ => return Err(format!("invalid worker_url format: {}", worker_url)),
        };

        // Parse the second part (dp_rank) into an integer
        rank_text
            .parse::<usize>()
            .map(|dp_rank| (prefix, dp_rank))
            .map_err(|_| format!("failed to parse dp_rank from worker_url: {}", worker_url))
    }

394
395
396
    // Send typed request directly without conversion
    async fn send_typed_request<T: serde::Serialize>(
        &self,
397
        headers: Option<&HeaderMap>,
398
399
400
401
402
        typed_req: &T,
        route: &str,
        worker_url: &str,
        is_stream: bool,
        load_incremented: bool, // Whether load was incremented for this request
403
    ) -> Response {
404
405
406
407
408
409
        // Get the worker's API key if available
        let api_key = self
            .worker_registry
            .get_by_url(worker_url)
            .and_then(|w| w.api_key().clone());

410
411
412
413
414
        let mut request_builder = if self.dp_aware {
            let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
                Ok(tup) => tup,
                Err(e) => {
                    error!("Failed to extract dp_rank: {}", e);
415
416
417
418
419
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to extract dp_rank: {}", e),
                    )
                        .into_response();
420
421
422
423
424
425
                }
            };

            let mut json_val = match serde_json::to_value(typed_req) {
                Ok(j) => j,
                Err(e) => {
426
427
428
429
430
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("Convert into serde_json::Value failed: {}", e),
                    )
                        .into_response();
431
432
433
434
435
436
437
438
439
440
441
442
443
                }
            };

            if let Some(map) = json_val.as_object_mut() {
                map.insert(
                    String::from("data_parallel_rank"),
                    serde_json::json!(dp_rank),
                );
                debug!(
                    "Modified request body: {}",
                    serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
                );
            } else {
444
445
446
447
448
                return (
                    StatusCode::BAD_REQUEST,
                    "Failed to insert the data_parallel_rank field into the request body",
                )
                    .into_response();
449
450
            }

451
            self.client
452
453
454
                .post(format!("{}{}", worker_url_prefix, route))
                .json(&json_val)
        } else {
455
            self.client
456
457
458
                .post(format!("{}{}", worker_url, route))
                .json(typed_req) // Use json() directly with typed request
        };
459

460
461
462
463
        if let Some(key) = api_key {
            request_builder = request_builder.header("Authorization", format!("Bearer {}", key));
        }

464
465
466
467
        // Copy all headers from original request if provided
        if let Some(headers) = headers {
            for (name, value) in headers {
                // Skip Content-Type and Content-Length as .json() sets them
468
                if *name != CONTENT_TYPE && *name != CONTENT_LENGTH {
469
470
                    request_builder = request_builder.header(name, value);
                }
471
472
473
474
475
476
            }
        }

        let res = match request_builder.send().await {
            Ok(res) => res,
            Err(e) => {
477
478
479
480
                error!(
                    "Failed to send typed request worker_url={} route={} error={}",
                    worker_url, route, e
                );
481
482
483

                // Decrement load on error if it was incremented
                if load_incremented {
484
485
486
                    if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(worker_url, worker.load());
487
488
489
                    }
                }

490
491
492
493
494
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Request failed: {}", e),
                )
                    .into_response();
495
496
497
            }
        };

498
499
        let status = StatusCode::from_u16(res.status().as_u16())
            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
500
501

        if !is_stream {
502
            // For non-streaming requests, preserve headers
503
            let response_headers = header_utils::preserve_response_headers(res.headers());
504

505
            let response = match res.bytes().await {
506
                Ok(body) => {
507
                    let mut response = Response::new(Body::from(body));
508
509
510
511
                    *response.status_mut() = status;
                    *response.headers_mut() = response_headers;
                    response
                }
512
                Err(e) => {
513
514
                    // IMPORTANT: Decrement load on error before returning
                    if load_incremented {
515
516
517
                        if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                            worker.decrement_load();
                            RouterMetrics::set_running_requests(worker_url, worker.load());
518
519
520
                        }
                    }

521
                    let error_msg = format!("Failed to get response body: {}", e);
522
                    (StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
523
524
525
526
                }
            };

            // Decrement load counter for non-streaming requests if it was incremented
527
            if load_incremented {
528
529
530
                if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
                    worker.decrement_load();
                    RouterMetrics::set_running_requests(worker_url, worker.load());
531
532
533
534
535
536
                }
            }

            response
        } else if load_incremented {
            // For streaming with load tracking, we need to manually decrement when done
537
            let registry = Arc::clone(&self.worker_registry);
538
539
            let worker_url = worker_url.to_string();

540
541
542
543
544
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

545
546
547
548
549
550
            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream and detect completion
            tokio::spawn(async move {
                let mut stream = stream;
551
                let mut decremented = false;
552
553
554
555
556
557
558
559
560
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            // Check for stream end marker
                            if bytes
                                .as_ref()
                                .windows(12)
                                .any(|window| window == b"data: [DONE]")
                            {
561
562
563
564
                                if let Some(worker) = registry.get_by_url(&worker_url) {
                                    worker.decrement_load();
                                    RouterMetrics::set_running_requests(&worker_url, worker.load());
                                    decremented = true;
565
566
                                }
                            }
567
568
569
570
571
572
573
574
575
576
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
577
                if !decremented {
578
579
580
                    if let Some(worker) = registry.get_by_url(&worker_url) {
                        worker.decrement_load();
                        RouterMetrics::set_running_requests(&worker_url, worker.load());
581
582
                    }
                }
583
584
585
586
587
588
589
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
590
            *response.headers_mut() = response_headers;
591
            response
592
593
        } else {
            // For requests without load tracking, just stream
594
595
596
597
598
            // Preserve headers for streaming response
            let mut response_headers = header_utils::preserve_response_headers(res.headers());
            // Ensure we set the correct content-type for SSE
            response_headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));

599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
            let stream = res.bytes_stream();
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

            // Spawn task to forward stream
            tokio::spawn(async move {
                let mut stream = stream;
                while let Some(chunk) = stream.next().await {
                    match chunk {
                        Ok(bytes) => {
                            if tx.send(Ok(bytes)).is_err() {
                                break;
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(format!("Stream error: {}", e)));
                            break;
                        }
                    }
                }
            });

            let stream = UnboundedReceiverStream::new(rx);
            let body = Body::from_stream(stream);

            let mut response = Response::new(body);
            *response.status_mut() = status;
625
            *response.headers_mut() = response_headers;
626
            response
627
628
629
        }
    }

630
631
632
633
634
635
636
637
638
    /// Convert a worker's raw rerank reply into a `RerankResponse` JSON response.
    ///
    /// The worker body is read in full and parsed as a JSON list of `RerankResult`.
    /// The results are wrapped together with the request's `model` and `rid`,
    /// truncated via `apply_top_k` when `req.top_k` is set, and stripped of
    /// documents when `req.return_documents` is false. The status and headers of
    /// the incoming `response` are discarded.
    ///
    /// # Errors
    ///
    /// Fails when the body cannot be read or is not a valid list of
    /// `RerankResult`.
    async fn build_rerank_response(
        req: &RerankRequest,
        response: Response,
    ) -> anyhow::Result<Response> {
        // NOTE(review): `usize::MAX` puts no cap on the buffered body size.
        let body_bytes = to_bytes(response.into_body(), usize::MAX).await?;
        let rerank_results: Vec<RerankResult> = serde_json::from_slice(&body_bytes)?;

        let mut rerank_response =
            RerankResponse::new(rerank_results, req.model.clone(), req.rid.clone());

        // Sorting is handled by Python worker (serving_rerank.py)
        if let Some(top_k) = req.top_k {
            rerank_response.apply_top_k(top_k);
        }
        if !req.return_documents {
            rerank_response.drop_documents();
        }

        Ok(Json(rerank_response).into_response())
    }
648
649
650
651
}

use async_trait::async_trait;

652
#[async_trait]
653
654
655
656
657
impl RouterTrait for Router {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

658
659
    async fn health_generate(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "health_generate").await
660
661
    }

662
663
    async fn get_server_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_server_info").await
664
665
    }

666
667
    async fn get_models(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "v1/models").await
668
669
    }

670
671
    async fn get_model_info(&self, req: Request<Body>) -> Response {
        self.proxy_get_request(req, "get_model_info").await
672
673
674
675
    }

    async fn route_generate(
        &self,
676
677
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
678
        model_id: Option<&str>,
679
    ) -> Response {
680
681
        self.route_typed_request(headers, body, "/generate", model_id)
            .await
682
683
684
685
    }

    async fn route_chat(
        &self,
686
687
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
688
        model_id: Option<&str>,
689
    ) -> Response {
690
        self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
691
            .await
692
693
694
695
    }

    async fn route_completion(
        &self,
696
697
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
698
        model_id: Option<&str>,
699
    ) -> Response {
700
        self.route_typed_request(headers, body, "/v1/completions", model_id)
701
            .await
702
703
    }

704
705
706
707
    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
708
        model_id: Option<&str>,
709
    ) -> Response {
710
        self.route_typed_request(headers, body, "/v1/responses", model_id)
711
712
713
            .await
    }

714
715
716
717
718
719
    async fn get_response(
        &self,
        headers: Option<&HeaderMap>,
        response_id: &str,
        _params: &ResponsesGetParams,
    ) -> Response {
720
721
722
723
724
725
726
727
728
        let endpoint = format!("v1/responses/{}", response_id);
        self.route_get_request(headers, &endpoint).await
    }

    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        let endpoint = format!("v1/responses/{}/cancel", response_id);
        self.route_post_empty_request(headers, &endpoint).await
    }

729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
    async fn route_embeddings(
        &self,
        headers: Option<&HeaderMap>,
        body: &EmbeddingRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Record embeddings-specific metrics in addition to general request metrics
        let start = Instant::now();
        let res = self
            .route_typed_request(headers, body, "/v1/embeddings", model_id)
            .await;

        // Embedding specific metrics
        if res.status().is_success() {
            RouterMetrics::record_embeddings_request();
            RouterMetrics::record_embeddings_duration(start.elapsed());
        } else {
            let error_type = format!("http_{}", res.status().as_u16());
            RouterMetrics::record_embeddings_error(&error_type);
        }

        res
751
752
    }

753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
    async fn route_classify(
        &self,
        headers: Option<&HeaderMap>,
        body: &ClassifyRequest,
        model_id: Option<&str>,
    ) -> Response {
        // Record classification-specific metrics in addition to general request metrics
        let start = Instant::now();
        let res = self
            .route_typed_request(headers, body, "/v1/classify", model_id)
            .await;

        // Classification specific metrics
        if res.status().is_success() {
            RouterMetrics::record_classify_request();
            RouterMetrics::record_classify_duration(start.elapsed());
        } else {
            let error_type = format!("http_{}", res.status().as_u16());
            RouterMetrics::record_classify_error(&error_type);
        }

        res
    }

777
778
779
780
781
782
783
784
785
    async fn route_rerank(
        &self,
        headers: Option<&HeaderMap>,
        body: &RerankRequest,
        model_id: Option<&str>,
    ) -> Response {
        let response = self
            .route_typed_request(headers, body, "/v1/rerank", model_id)
            .await;
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
        if response.status().is_success() {
            match Self::build_rerank_response(body, response).await {
                Ok(rerank_response) => rerank_response,
                Err(e) => {
                    error!("Failed to build rerank response: {}", e);
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Failed to build rerank response".to_string(),
                    )
                        .into_response();
                }
            }
        } else {
            response
        }
801
802
    }

803
804
805
806
807
808
809
810
    fn router_type(&self) -> &'static str {
        "regular"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorkerBuilder;

    /// Builds a router with a round-robin policy registry and two regular
    /// workers (`worker1` and `worker2`, both on port 8080).
    fn create_test_regular_router() -> Router {
        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(
            crate::config::types::PolicyConfig::RoundRobin,
        ));

        // Register the test workers in a fixed order.
        for url in ["http://worker1:8080", "http://worker2:8080"] {
            let worker = BasicWorkerBuilder::new(url)
                .worker_type(WorkerType::Regular)
                .build();
            worker_registry.register(Arc::new(worker));
        }

        Router {
            worker_registry,
            policy_registry,
            client: Client::new(),
            dp_aware: false,
            enable_igw: false,
            retry_config: RetryConfig::default(),
        }
    }

    /// Same as `create_test_regular_router`, but the first worker returned by
    /// the registry is marked unhealthy.
    fn create_test_unhealthy_router() -> Router {
        let router = create_test_regular_router();
        router.worker_registry.get_all()[0].set_healthy(false);
        router
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let urls: Vec<String> = router
            .worker_registry
            .get_all()
            .iter()
            .map(|w| w.url().to_string())
            .collect();

        assert_eq!(urls.len(), 2);
        for expected in ["http://worker1:8080", "http://worker2:8080"] {
            assert!(urls.iter().any(|u| u.as_str() == expected));
        }
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let selected = router.select_first_worker();

        assert!(selected.is_ok());
        let url = selected.unwrap();
        // DashMap doesn't guarantee order, so either worker is an acceptable pick
        let known = ["http://worker1:8080", "http://worker2:8080"];
        assert!(known.iter().any(|candidate| url == *candidate));
    }

    #[test]
    fn test_select_first_worker_with_unhealthy_worker() {
        let router = create_test_unhealthy_router();
        let selected = router.select_first_worker();

        assert!(selected.is_ok());
        let url = selected.unwrap();

        // Whichever worker was chosen must be a healthy one.
        let chosen = router.worker_registry.get_by_url(&url).unwrap();
        assert!(chosen.is_healthy());
    }
}