Unverified commit cb352d95, authored by atchernych and committed by GitHub
Browse files

feat: Support the reading of routing hints from the headers (#5502)


Signed-off-by: Anna Tchernych <atchernych@nvidia.com>
parent f438aee2
...@@ -118,3 +118,6 @@ profiling_results* ...@@ -118,3 +118,6 @@ profiling_results*
# Node.js # Node.js
node_modules/ node_modules/
package-lock.json package-lock.json
# Compiled static libraries
*.a
...@@ -41,6 +41,7 @@ use super::{ ...@@ -41,6 +41,7 @@ use super::{
}; };
use crate::engines::ValidateRequest; use crate::engines::ValidateRequest;
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator; use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
use crate::protocols::openai::nvext::apply_header_routing_overrides;
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::{ chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionResponse,
...@@ -287,11 +288,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin ...@@ -287,11 +288,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin
async fn handler_completions( async fn handler_completions(
State(state): State<Arc<service_v2::State>>, State(state): State<Arc<service_v2::State>>,
headers: HeaderMap, headers: HeaderMap,
Json(request): Json<NvCreateCompletionRequest>, Json(mut request): Json<NvCreateCompletionRequest>,
) -> Result<Response, ErrorResponse> { ) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers);
// create the context for the request // create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers); let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id); let request = Context::with_id(request, request_id);
...@@ -707,11 +710,13 @@ async fn embeddings( ...@@ -707,11 +710,13 @@ async fn embeddings(
async fn handler_chat_completions( async fn handler_chat_completions(
State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>, State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
headers: HeaderMap, headers: HeaderMap,
Json(request): Json<NvCreateChatCompletionRequest>, Json(mut request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, ErrorResponse> { ) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers);
// create the context for the request // create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers); let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id); let request = Context::with_id(request, request_id);
...@@ -1137,11 +1142,13 @@ pub fn validate_completion_fields_generic( ...@@ -1137,11 +1142,13 @@ pub fn validate_completion_fields_generic(
async fn handler_responses( async fn handler_responses(
State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>, State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
headers: HeaderMap, headers: HeaderMap,
Json(request): Json<NvCreateResponse>, Json(mut request): Json<NvCreateResponse>,
) -> Result<Response, ErrorResponse> { ) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers);
// create the context for the request // create the context for the request
let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers); let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
let request = Context::with_id(request, request_id); let request = Context::with_id(request, request_id);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use axum::http::HeaderMap;
use derive_builder::Builder; use derive_builder::Builder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use utoipa::ToSchema; use utoipa::ToSchema;
...@@ -8,6 +9,43 @@ use validator::{Validate, ValidationError}; ...@@ -8,6 +9,43 @@ use validator::{Validate, ValidationError};
pub use crate::protocols::common::timing::TimingInfo; pub use crate::protocols::common::timing::TimingInfo;
/// Name of the HTTP request header that carries a worker instance id.
///
/// Read by `apply_header_routing_overrides`, which copies a valid `u64` value
/// into both `backend_instance_id` and `decode_worker_id` of the request's nvext.
pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id";
/// Name of the HTTP request header that carries a prefill instance id.
///
/// Read by `apply_header_routing_overrides`, which copies a valid `u64` value
/// into `prefill_worker_id` of the request's nvext.
pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id";
/// Merge routing hints found in the HTTP headers into the request's nvext.
///
/// Recognised headers:
/// - `x-worker-instance-id` sets both `backend_instance_id` and `decode_worker_id`
/// - `x-prefill-instance-id` sets `prefill_worker_id`
///
/// A header value wins over whatever the incoming nvext already holds for the
/// corresponding field(s); fields without a matching header are left as they were.
/// A header that is missing, cannot be read as a string, or does not parse as a
/// `u64` is treated as absent. When neither header yields an id, the incoming
/// `nvext` is handed back untouched (including `None`); otherwise a missing
/// nvext is replaced by a default one before the overrides are written.
pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) -> Option<NvExt> {
    // Shared lookup for both headers: fetch, decode as text, parse as u64.
    let header_id = |name: &str| -> Option<u64> {
        headers.get(name)?.to_str().ok()?.parse::<u64>().ok()
    };

    match (
        header_id(HEADER_WORKER_INSTANCE_ID),
        header_id(HEADER_PREFILL_INSTANCE_ID),
    ) {
        // No usable routing header: pass the caller's nvext straight through.
        (None, None) => nvext,
        (decode_target, prefill_target) => {
            let mut overridden = nvext.unwrap_or_default();
            if let Some(instance) = decode_target {
                // The worker header drives both the backend and the decode worker.
                overridden.backend_instance_id = Some(instance);
                overridden.decode_worker_id = Some(instance);
            }
            if let Some(instance) = prefill_target {
                overridden.prefill_worker_id = Some(instance);
            }
            Some(overridden)
        }
    }
}
pub trait NvExtProvider { pub trait NvExtProvider {
fn nvext(&self) -> Option<&NvExt>; fn nvext(&self) -> Option<&NvExt>;
fn raw_prompt(&self) -> Option<String>; fn raw_prompt(&self) -> Option<String>;
...@@ -210,4 +248,30 @@ mod tests { ...@@ -210,4 +248,30 @@ mod tests {
assert_eq!(nv_ext.decode_worker_id, Some(200)); assert_eq!(nv_ext.decode_worker_id, Some(200));
assert!(nv_ext.validate().is_ok()); assert!(nv_ext.validate().is_ok());
} }
// apply_header_routing_overrides: the worker header is supplied, the prefill header is not.
#[test]
fn test_apply_header_routing_overrides() {
    use axum::http::HeaderMap;

    // Pre-existing nvext with all three routing fields populated.
    let original = NvExt::builder()
        .backend_instance_id(999)
        .decode_worker_id(888)
        .prefill_worker_id(777)
        .build()
        .unwrap();

    // Request headers contain HEADER_WORKER_INSTANCE_ID only;
    // HEADER_PREFILL_INSTANCE_ID is deliberately left out.
    let mut request_headers = HeaderMap::new();
    request_headers.insert(HEADER_WORKER_INSTANCE_ID, "123".parse().unwrap());

    let overridden = apply_header_routing_overrides(Some(original), &request_headers).unwrap();

    // The worker header replaces both backend_instance_id and decode_worker_id.
    assert_eq!(overridden.backend_instance_id, Some(123));
    assert_eq!(overridden.decode_worker_id, Some(123));

    // With no prefill header, prefill_worker_id keeps the value from the original nvext.
    assert_eq!(overridden.prefill_worker_id, Some(777));
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment