"lib/vscode:/vscode.git/clone" did not exist on "61c6780469c8c41ed23b308180b45ffd7778bce0"
Unverified Commit 2b906504 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: enable DP for GAIE (#7741)


Signed-off-by: default avatarAnna Tchernych <atchernych@nvidia.com>
parent bcab0304
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv"
"sync" "sync"
log "sigs.k8s.io/controller-runtime/pkg/log" log "sigs.k8s.io/controller-runtime/pkg/log"
...@@ -39,6 +40,8 @@ const ( ...@@ -39,6 +40,8 @@ const (
WorkerIDHeader = "x-worker-instance-id" WorkerIDHeader = "x-worker-instance-id"
PrefillWorkerIDHeader = "x-prefill-instance-id" PrefillWorkerIDHeader = "x-prefill-instance-id"
DpRankHeader = "x-dp-rank"
PrefillDpRankHeader = "x-prefill-dp-rank"
RoutingModeHeader = "x-dynamo-routing-mode" RoutingModeHeader = "x-dynamo-routing-mode"
// decodeStateKey is the key used to store routing state in PluginState // decodeStateKey is the key used to store routing state in PluginState
...@@ -55,6 +58,7 @@ var _ rc.ResponseComplete = &DynDecodeScorer{} ...@@ -55,6 +58,7 @@ var _ rc.ResponseComplete = &DynDecodeScorer{}
// DecodeRoutingState holds routing information passed from Score() to PreRequest(). // DecodeRoutingState holds routing information passed from Score() to PreRequest().
type DecodeRoutingState struct { type DecodeRoutingState struct {
WorkerID string WorkerID string
DpRank uint32
PrefillWorkerID string PrefillWorkerID string
TokenData []int64 TokenData []int64
} }
...@@ -66,6 +70,7 @@ func (s *DecodeRoutingState) Clone() plugins.StateData { ...@@ -66,6 +70,7 @@ func (s *DecodeRoutingState) Clone() plugins.StateData {
} }
clone := &DecodeRoutingState{ clone := &DecodeRoutingState{
WorkerID: s.WorkerID, WorkerID: s.WorkerID,
DpRank: s.DpRank,
PrefillWorkerID: s.PrefillWorkerID, PrefillWorkerID: s.PrefillWorkerID,
} }
if s.TokenData != nil { if s.TokenData != nil {
...@@ -157,8 +162,10 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl ...@@ -157,8 +162,10 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl
} }
workerIDStr := fmt.Sprintf("%d", result.WorkerID) workerIDStr := fmt.Sprintf("%d", result.WorkerID)
dpRankStr := strconv.FormatUint(uint64(result.DpRank), 10)
logger.V(logutil.DEFAULT).Info("DynDecodeScorer: decode worker selected", logger.V(logutil.DEFAULT).Info("DynDecodeScorer: decode worker selected",
"decodeWorkerID", workerIDStr, "decodeWorkerID", workerIDStr,
"decodeDpRank", result.DpRank,
"isDisaggregated", isDisaggregated, "isDisaggregated", isDisaggregated,
"tokenCount", len(result.TokenData)) "tokenCount", len(result.TokenData))
...@@ -167,6 +174,7 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl ...@@ -167,6 +174,7 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl
req.Headers = map[string]string{} req.Headers = map[string]string{}
} }
req.Headers[WorkerIDHeader] = workerIDStr req.Headers[WorkerIDHeader] = workerIDStr
req.Headers[DpRankHeader] = dpRankStr
if isDisaggregated { if isDisaggregated {
req.Headers[RoutingModeHeader] = "disaggregated" req.Headers[RoutingModeHeader] = "disaggregated"
...@@ -188,6 +196,7 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl ...@@ -188,6 +196,7 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl
if req.RequestId != "" { if req.RequestId != "" {
routingState := &DecodeRoutingState{ routingState := &DecodeRoutingState{
WorkerID: workerIDStr, WorkerID: workerIDStr,
DpRank: result.DpRank,
TokenData: result.TokenData, TokenData: result.TokenData,
} }
s.pluginState.Write(req.RequestId, plugins.StateKey(decodeStateKey), routingState) s.pluginState.Write(req.RequestId, plugins.StateKey(decodeStateKey), routingState)
...@@ -226,7 +235,7 @@ func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LL ...@@ -226,7 +235,7 @@ func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LL
return return
} }
if addErr := dynscorer.CallAddRequest(request.RequestId, state.TokenData, workerIDUint, 0); addErr != nil { if addErr := dynscorer.CallAddRequest(request.RequestId, state.TokenData, workerIDUint, state.DpRank); addErr != nil {
logger.V(logutil.DEFAULT).Error(addErr, "DynDecodeScorer PreRequest: failed to add request", logger.V(logutil.DEFAULT).Error(addErr, "DynDecodeScorer PreRequest: failed to add request",
"requestID", request.RequestId) "requestID", request.RequestId)
return return
...@@ -235,6 +244,7 @@ func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LL ...@@ -235,6 +244,7 @@ func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LL
logger.V(logutil.VERBOSE).Info("DynDecodeScorer PreRequest: registered request", logger.V(logutil.VERBOSE).Info("DynDecodeScorer PreRequest: registered request",
"requestID", request.RequestId, "requestID", request.RequestId,
"workerID", state.WorkerID, "workerID", state.WorkerID,
"dpRank", state.DpRank,
"tokenCount", len(state.TokenData)) "tokenCount", len(state.TokenData))
} }
......
...@@ -120,11 +120,13 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc ...@@ -120,11 +120,13 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
} }
prefillWorkerID := strconv.FormatUint(result.WorkerID, 10) prefillWorkerID := strconv.FormatUint(result.WorkerID, 10)
prefillDpRank := strconv.FormatUint(uint64(result.DpRank), 10)
logger.V(logutil.DEFAULT).Info("DynPrefillScorer: prefill worker selected", logger.V(logutil.DEFAULT).Info("DynPrefillScorer: prefill worker selected",
"prefillWorkerID", prefillWorkerID, "prefillWorkerID", prefillWorkerID,
"prefillDpRank", result.DpRank,
"tokenCount", len(result.TokenData)) "tokenCount", len(result.TokenData))
// Set the prefill worker ID header directly on the request. // Set the prefill worker ID and DP rank headers directly on the request.
// The request object is shared across all profile runs in the scheduling // The request object is shared across all profile runs in the scheduling
// cycle, so the decode scorer (which runs in the next profile) will see it. // cycle, so the decode scorer (which runs in the next profile) will see it.
// This is more reliable than CycleState which may be scoped per profile. // This is more reliable than CycleState which may be scoped per profile.
...@@ -132,6 +134,7 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc ...@@ -132,6 +134,7 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
req.Headers = map[string]string{} req.Headers = map[string]string{}
} }
req.Headers[PrefillWorkerIDHeader] = prefillWorkerID req.Headers[PrefillWorkerIDHeader] = prefillWorkerID
req.Headers[PrefillDpRankHeader] = prefillDpRank
// Score: 1.0 for all pods. The label-filter has already restricted to prefill workers, // Score: 1.0 for all pods. The label-filter has already restricted to prefill workers,
// and the FFI router's internal selection is authoritative. // and the FFI router's internal selection is authoritative.
......
...@@ -52,6 +52,8 @@ typedef struct { ...@@ -52,6 +52,8 @@ typedef struct {
bool is_disaggregated; bool is_disaggregated;
uint64_t prefill_worker_id; uint64_t prefill_worker_id;
uint64_t decode_worker_id; uint64_t decode_worker_id;
uint32_t prefill_dp_rank;
uint32_t decode_dp_rank;
uint32_t *token_ids; uint32_t *token_ids;
size_t token_count; size_t token_count;
} CRoutingResult; } CRoutingResult;
...@@ -411,6 +413,7 @@ func CallFreeRequest(requestID string) error { ...@@ -411,6 +413,7 @@ func CallFreeRequest(requestID string) error {
// RoutingResult holds the result of a prefill or decode routing call. // RoutingResult holds the result of a prefill or decode routing call.
type RoutingResult struct { type RoutingResult struct {
WorkerID uint64 WorkerID uint64
DpRank uint32
TokenData []int64 TokenData []int64
} }
...@@ -455,9 +458,10 @@ func CallRoutePrefillRequest(requestJSON string, podsJSON string) (*RoutingResul ...@@ -455,9 +458,10 @@ func CallRoutePrefillRequest(requestJSON string, podsJSON string) (*RoutingResul
} }
workerID := uint64(result.prefill_worker_id) workerID := uint64(result.prefill_worker_id)
dpRank := uint32(result.prefill_dp_rank)
C.free_routing_result(&result) C.free_routing_result(&result)
return &RoutingResult{WorkerID: workerID, TokenData: tokens64}, nil return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens64}, nil
} }
// CallRouteDecodeRequest routes a request to the best decode worker. // CallRouteDecodeRequest routes a request to the best decode worker.
...@@ -501,7 +505,8 @@ func CallRouteDecodeRequest(requestJSON string, podsJSON string, isDisaggregated ...@@ -501,7 +505,8 @@ func CallRouteDecodeRequest(requestJSON string, podsJSON string, isDisaggregated
} }
workerID := uint64(result.decode_worker_id) workerID := uint64(result.decode_worker_id)
dpRank := uint32(result.decode_dp_rank)
C.free_routing_result(&result) C.free_routing_result(&result)
return &RoutingResult{WorkerID: workerID, TokenData: tokens64}, nil return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens64}, nil
} }
...@@ -10,12 +10,17 @@ title: Inference Gateway (GAIE) ...@@ -10,12 +10,17 @@ title: Inference Gateway (GAIE)
Integrate Dynamo with the Gateway API Inference Extension for intelligent KV-aware request routing at the gateway layer. Integrate Dynamo with the Gateway API Inference Extension for intelligent KV-aware request routing at the gateway layer.
EPP's default kv-routing approach is not token-aware because the prompt is not tokenized. But the Dynamo plugin uses a token-aware KV algorithm. It employs the dynamo router which implements kv routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml), following the checked-in GAIE/EPP configuration layout used by this repository. ## Features
Dynamo Integration with the Inference Gateway supports Aggregated and Disaggregated Serving. A request only exercises disaggregated routing when the EPP config defines a `prefill` profile and prefill workers are available. The standalone [`epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml) currently only defines a `decode` profile, while the recipe examples use separate aggregated and disaggregated configs under `recipes/llama-3-70b/vllm/agg/gaie/` and `recipes/llama-3-70b/vllm/disagg-single-node/gaie/`. Unless `DYN_ENFORCE_DISAGG=true`, deployments without a `prefill` profile or prefill workers fall back to aggregated serving. - EPP's default kv-routing approach is not token-aware because the prompt is not tokenized. But the Dynamo plugin uses a token-aware KV algorithm. It employs the dynamo router which implements kv routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml), following the checked-in GAIE/EPP configuration layout used by this repository.
If you want to use LoRA deploy Dynamo without the Inference Gateway.
Currently, these setups are only supported with the kGateway based Inference Gateway. - Dynamo Integration with the Inference Gateway supports Aggregated and Disaggregated Serving. A request only exercises disaggregated routing when the EPP config defines a `prefill` profile and prefill workers are available. The standalone [`epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml) currently only defines a `decode` profile, while the recipe examples use separate aggregated and disaggregated configs under `recipes/llama-3-70b/vllm/agg/gaie/` and `recipes/llama-3-70b/vllm/disagg-single-node/gaie/`. Unless `DYN_ENFORCE_DISAGG=true`, deployments without a `prefill` profile or prefill workers fall back to aggregated serving.
- GAIE integration supports Data Parallelism.
- If you want to use LoRA deploy Dynamo without the Inference Gateway.
- Currently, these setups are only tested with the kGateway Inference Gateway.
## Prerequisites ## Prerequisites
......
...@@ -404,6 +404,10 @@ pub struct CRoutingResult { ...@@ -404,6 +404,10 @@ pub struct CRoutingResult {
pub prefill_worker_id: u64, pub prefill_worker_id: u64,
/// Decode worker ID /// Decode worker ID
pub decode_worker_id: u64, pub decode_worker_id: u64,
/// Data parallel rank selected for the prefill worker
pub prefill_dp_rank: u32,
/// Data parallel rank selected for the decode worker
pub decode_dp_rank: u32,
/// Token IDs (needed for add_request callback) /// Token IDs (needed for add_request callback)
pub token_ids: *mut u32, pub token_ids: *mut u32,
/// Number of tokens in the request /// Number of tokens in the request
...@@ -416,6 +420,8 @@ impl Default for CRoutingResult { ...@@ -416,6 +420,8 @@ impl Default for CRoutingResult {
is_disaggregated: false, is_disaggregated: false,
prefill_worker_id: 0, prefill_worker_id: 0,
decode_worker_id: 0, decode_worker_id: 0,
prefill_dp_rank: 0,
decode_dp_rank: 0,
token_ids: ptr::null_mut(), token_ids: ptr::null_mut(),
token_count: 0, token_count: 0,
} }
...@@ -449,7 +455,7 @@ impl RouterHandles { ...@@ -449,7 +455,7 @@ impl RouterHandles {
lora_name: Option<String>, lora_name: Option<String>,
priority_jump: f64, priority_jump: f64,
allowed_worker_ids: Option<HashSet<WorkerId>>, allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<u64, QueryRouterResult> { ) -> Result<(u64, u32), QueryRouterResult> {
if let Some(ref ids) = allowed_worker_ids { if let Some(ref ids) = allowed_worker_ids {
self.prefill_router.register_workers(ids); self.prefill_router.register_workers(ids);
} }
...@@ -464,7 +470,6 @@ impl RouterHandles { ...@@ -464,7 +470,6 @@ impl RouterHandles {
allowed_worker_ids, allowed_worker_ids,
) )
.await .await
.map(|(worker_id, _dp_rank)| worker_id)
.map_err(|e| { .map_err(|e| {
tracing::error!(error = ?e, "Prefill query failed"); tracing::error!(error = ?e, "Prefill query failed");
QueryRouterResult::ErrQueryFailed QueryRouterResult::ErrQueryFailed
...@@ -1203,25 +1208,27 @@ pub unsafe extern "C" fn route_prefill_request( ...@@ -1203,25 +1208,27 @@ pub unsafe extern "C" fn route_prefill_request(
let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) }; let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };
let result = handles.runtime.secondary().block_on(async { let result = handles.runtime.secondary().block_on(async {
let prefill_worker_id = handles let (prefill_worker_id, prefill_dp_rank) = handles
.query_prefill_worker(&tokens, None, false, None, 0.0, allowed_worker_ids) .query_prefill_worker(&tokens, None, false, None, 0.0, allowed_worker_ids)
.await?; .await?;
tracing::info!( tracing::info!(
prefill_worker_id = prefill_worker_id, prefill_worker_id = prefill_worker_id,
prefill_dp_rank = prefill_dp_rank,
token_count = tokens.len(), token_count = tokens.len(),
"Routed prefill request" "Routed prefill request"
); );
Ok(prefill_worker_id) Ok((prefill_worker_id, prefill_dp_rank))
}); });
match result { match result {
Ok(prefill_worker_id) => { Ok((prefill_worker_id, prefill_dp_rank)) => {
let out = unsafe { &mut *out_result }; let out = unsafe { &mut *out_result };
*out = CRoutingResult::default(); *out = CRoutingResult::default();
out.is_disaggregated = true; out.is_disaggregated = true;
out.prefill_worker_id = prefill_worker_id; out.prefill_worker_id = prefill_worker_id;
out.prefill_dp_rank = prefill_dp_rank;
write_tokens_to_result(&tokens, out); write_tokens_to_result(&tokens, out);
QueryRouterResult::Ok QueryRouterResult::Ok
} }
...@@ -1290,6 +1297,7 @@ pub unsafe extern "C" fn route_decode_request( ...@@ -1290,6 +1297,7 @@ pub unsafe extern "C" fn route_decode_request(
*out = CRoutingResult::default(); *out = CRoutingResult::default();
out.is_disaggregated = is_disaggregated; out.is_disaggregated = is_disaggregated;
out.decode_worker_id = decode_worker.worker_id; out.decode_worker_id = decode_worker.worker_id;
out.decode_dp_rank = decode_worker.dp_rank;
write_tokens_to_result(&tokens, out); write_tokens_to_result(&tokens, out);
QueryRouterResult::Ok QueryRouterResult::Ok
} }
......
...@@ -40,7 +40,11 @@ impl PrefillRouter { ...@@ -40,7 +40,11 @@ impl PrefillRouter {
// Worker selection // Worker selection
let (worker_id, dp_rank) = if let Some(id) = preselected_worker { let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0); let dp_rank = req
.routing
.as_ref()
.and_then(|r| r.prefill_dp_rank.or(r.dp_rank))
.unwrap_or(0);
tracing::debug!( tracing::debug!(
worker_id = id, worker_id = id,
dp_rank = dp_rank, dp_rank = dp_rank,
......
...@@ -321,7 +321,8 @@ impl OpenAIPreprocessor { ...@@ -321,7 +321,8 @@ impl OpenAIPreprocessor {
backend_instance_id: nvext.backend_instance_id, backend_instance_id: nvext.backend_instance_id,
prefill_worker_id: nvext.prefill_worker_id, prefill_worker_id: nvext.prefill_worker_id,
decode_worker_id: nvext.decode_worker_id, decode_worker_id: nvext.decode_worker_id,
dp_rank: None, // dp_rank is set later in the pipeline dp_rank: nvext.dp_rank,
prefill_dp_rank: nvext.prefill_dp_rank,
expected_output_tokens: hints.and_then(|h| h.osl), expected_output_tokens: hints.and_then(|h| h.osl),
priority_jump: hints.and_then(|h| { priority_jump: hints.and_then(|h| {
h.priority h.priority
......
...@@ -34,10 +34,14 @@ pub struct RoutingHints { ...@@ -34,10 +34,14 @@ pub struct RoutingHints {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>, pub decode_worker_id: Option<u64>,
/// Data parallel rank for the request /// Data parallel rank for the decode worker
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>, pub dp_rank: Option<u32>,
/// Data parallel rank for the prefill worker in disaggregated serving
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
/// Expected number of output tokens for this request. /// Expected number of output tokens for this request.
/// Used as a hint for routing decisions to estimate resource requirements. /// Used as a hint for routing decisions to estimate resource requirements.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
......
...@@ -11,12 +11,16 @@ pub use crate::protocols::common::timing::TimingInfo; ...@@ -11,12 +11,16 @@ pub use crate::protocols::common::timing::TimingInfo;
pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id"; pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id";
pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id"; pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id";
pub const HEADER_DP_RANK: &str = "x-dp-rank";
pub const HEADER_PREFILL_DP_RANK: &str = "x-prefill-dp-rank";
/// Apply routing overrides from HTTP headers to nvext. /// Apply routing overrides from HTTP headers to nvext.
/// ///
/// Header mappings: /// Header mappings:
/// - `x-worker-instance-id` -> `backend_instance_id` and `decode_worker_id` /// - `x-worker-instance-id` -> `backend_instance_id` and `decode_worker_id`
/// - `x-prefill-instance-id` -> `prefill_worker_id` /// - `x-prefill-instance-id` -> `prefill_worker_id`
/// - `x-dp-rank` -> `dp_rank` (decode worker's DP rank)
/// - `x-prefill-dp-rank` -> `prefill_dp_rank`
/// ///
/// Headers take priority over existing nvext values when present. /// Headers take priority over existing nvext values when present.
/// If no headers are present, returns the original nvext unchanged. /// If no headers are present, returns the original nvext unchanged.
...@@ -31,7 +35,18 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) ...@@ -31,7 +35,18 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok()); .and_then(|s| s.parse::<u64>().ok());
if worker_id.is_none() && prefill_id.is_none() { let dp_rank = headers
.get(HEADER_DP_RANK)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u32>().ok());
let prefill_dp_rank = headers
.get(HEADER_PREFILL_DP_RANK)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u32>().ok());
if worker_id.is_none() && prefill_id.is_none() && dp_rank.is_none() && prefill_dp_rank.is_none()
{
return nvext; return nvext;
} }
...@@ -43,6 +58,12 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) ...@@ -43,6 +58,12 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap)
if let Some(id) = prefill_id { if let Some(id) = prefill_id {
ext.prefill_worker_id = Some(id); ext.prefill_worker_id = Some(id);
} }
if let Some(rank) = dp_rank {
ext.dp_rank = Some(rank);
}
if let Some(rank) = prefill_dp_rank {
ext.prefill_dp_rank = Some(rank);
}
Some(ext) Some(ext)
} }
...@@ -157,6 +178,19 @@ pub struct NvExt { ...@@ -157,6 +178,19 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>, pub decode_worker_id: Option<u64>,
/// Data parallel rank for the decode worker, set by the EPP via the
/// `x-dp-rank` header. When a worker hosts multiple DP engines,
/// this steers the request to the correct engine instance.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>,
/// Data parallel rank for the prefill worker in disaggregated serving,
/// set by the EPP via the `x-prefill-dp-rank` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
/// Agent-provided hints for request handling. /// Agent-provided hints for request handling.
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
...@@ -288,29 +322,22 @@ mod tests { ...@@ -288,29 +322,22 @@ mod tests {
assert!(nv_ext.validate().is_ok()); assert!(nv_ext.validate().is_ok());
} }
// Test apply_header_routing_overrides - worker header present, prefill header absent
#[test] #[test]
fn test_apply_header_routing_overrides() { fn test_apply_header_routing_overrides() {
use axum::http::HeaderMap; use axum::http::HeaderMap;
// Only HEADER_WORKER_INSTANCE_ID is in the header
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(HEADER_WORKER_INSTANCE_ID, "123".parse().unwrap()); headers.insert(HEADER_WORKER_INSTANCE_ID, "123".parse().unwrap());
// Note: HEADER_PREFILL_INSTANCE_ID is NOT in the header headers.insert(HEADER_PREFILL_INSTANCE_ID, "456".parse().unwrap());
headers.insert(HEADER_DP_RANK, "3".parse().unwrap());
let nvext = NvExt::builder() headers.insert(HEADER_PREFILL_DP_RANK, "5".parse().unwrap());
.backend_instance_id(999)
.decode_worker_id(888)
.prefill_worker_id(777)
.build()
.unwrap();
let result = apply_header_routing_overrides(Some(nvext), &headers).unwrap(); let result = apply_header_routing_overrides(None, &headers).unwrap();
// Header should override backend_instance_id and decode_worker_id
assert_eq!(result.backend_instance_id, Some(123)); assert_eq!(result.backend_instance_id, Some(123));
assert_eq!(result.decode_worker_id, Some(123)); assert_eq!(result.decode_worker_id, Some(123));
// prefill_worker_id should remain from original nvext (not overwritten by header) assert_eq!(result.prefill_worker_id, Some(456));
assert_eq!(result.prefill_worker_id, Some(777)); assert_eq!(result.dp_rank, Some(3));
assert_eq!(result.prefill_dp_rank, Some(5));
} }
} }
...@@ -2161,6 +2161,8 @@ def _test_disagg_direct_mode( ...@@ -2161,6 +2161,8 @@ def _test_disagg_direct_mode(
headers = { headers = {
"x-worker-instance-id": str(target_decode), "x-worker-instance-id": str(target_decode),
"x-prefill-instance-id": str(target_prefill), "x-prefill-instance-id": str(target_prefill),
"x-dp-rank": "0",
"x-prefill-dp-rank": "0",
} }
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment