Unverified Commit cb352d95 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: Support the reading of routing hints from the headers (#5502)


Signed-off-by: default avatarAnna Tchernych <atchernych@nvidia.com>
parent f438aee2
......@@ -118,3 +118,6 @@ profiling_results*
# Node.js
node_modules/
package-lock.json
# Compiled static libraries
*.a
......@@ -41,6 +41,7 @@ use super::{
};
use crate::engines::ValidateRequest;
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
use crate::protocols::openai::nvext::apply_header_routing_overrides;
use crate::protocols::openai::{
chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionResponse,
......@@ -287,11 +288,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin
async fn handler_completions(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateCompletionRequest>,
Json(mut request): Json<NvCreateCompletionRequest>,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready
check_ready(&state)?;
request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers);
// create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
......@@ -707,11 +710,13 @@ async fn embeddings(
async fn handler_chat_completions(
State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
headers: HeaderMap,
Json(request): Json<NvCreateChatCompletionRequest>,
Json(mut request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready
check_ready(&state)?;
request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers);
// create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
......@@ -1137,11 +1142,13 @@ pub fn validate_completion_fields_generic(
async fn handler_responses(
State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
headers: HeaderMap,
Json(request): Json<NvCreateResponse>,
Json(mut request): Json<NvCreateResponse>,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready
check_ready(&state)?;
request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers);
// create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use axum::http::HeaderMap;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
......@@ -8,6 +9,43 @@ use validator::{Validate, ValidationError};
pub use crate::protocols::common::timing::TimingInfo;
pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id";
pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id";
/// Apply routing overrides from HTTP headers to nvext.
///
/// Header mappings:
/// - `x-worker-instance-id` -> `backend_instance_id` and `decode_worker_id`
/// - `x-prefill-instance-id` -> `prefill_worker_id`
///
/// Headers take priority over existing nvext values when present.
/// If no headers are present, returns the original nvext unchanged.
pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) -> Option<NvExt> {
let worker_id = headers
.get(HEADER_WORKER_INSTANCE_ID)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let prefill_id = headers
.get(HEADER_PREFILL_INSTANCE_ID)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
if worker_id.is_none() && prefill_id.is_none() {
return nvext;
}
let mut ext = nvext.unwrap_or_default();
if let Some(id) = worker_id {
ext.backend_instance_id = Some(id);
ext.decode_worker_id = Some(id);
}
if let Some(id) = prefill_id {
ext.prefill_worker_id = Some(id);
}
Some(ext)
}
pub trait NvExtProvider {
fn nvext(&self) -> Option<&NvExt>;
fn raw_prompt(&self) -> Option<String>;
......@@ -210,4 +248,30 @@ mod tests {
assert_eq!(nv_ext.decode_worker_id, Some(200));
assert!(nv_ext.validate().is_ok());
}
// Test apply_header_routing_overrides - worker header present, prefill header absent
#[test]
fn test_apply_header_routing_overrides() {
use axum::http::HeaderMap;
// Only HEADER_WORKER_INSTANCE_ID is in the header
let mut headers = HeaderMap::new();
headers.insert(HEADER_WORKER_INSTANCE_ID, "123".parse().unwrap());
// Note: HEADER_PREFILL_INSTANCE_ID is NOT in the header
let nvext = NvExt::builder()
.backend_instance_id(999)
.decode_worker_id(888)
.prefill_worker_id(777)
.build()
.unwrap();
let result = apply_header_routing_overrides(Some(nvext), &headers).unwrap();
// Header should override backend_instance_id and decode_worker_id
assert_eq!(result.backend_instance_id, Some(123));
assert_eq!(result.decode_worker_id, Some(123));
// prefill_worker_id should remain from original nvext (not overwritten by header)
assert_eq!(result.prefill_worker_id, Some(777));
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment