Unverified Commit 2b906504 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: enable DP for GAIE (#7741)


Signed-off-by: Anna Tchernych <atchernych@nvidia.com>
parent bcab0304
......@@ -20,6 +20,7 @@ import (
"context"
"encoding/json"
"fmt"
"strconv"
"sync"
log "sigs.k8s.io/controller-runtime/pkg/log"
......@@ -39,6 +40,8 @@ const (
WorkerIDHeader = "x-worker-instance-id"
PrefillWorkerIDHeader = "x-prefill-instance-id"
DpRankHeader = "x-dp-rank"
PrefillDpRankHeader = "x-prefill-dp-rank"
RoutingModeHeader = "x-dynamo-routing-mode"
// decodeStateKey is the key used to store routing state in PluginState
......@@ -55,6 +58,7 @@ var _ rc.ResponseComplete = &DynDecodeScorer{}
// DecodeRoutingState holds routing information passed from Score() to PreRequest().
// Score() writes it to the plugin state under the request ID, and PreRequest()
// uses its fields when calling CallAddRequest.
type DecodeRoutingState struct {
// WorkerID is the selected decode worker ID, formatted as a decimal string.
WorkerID string
// DpRank is the data parallel rank selected for the decode worker;
// PreRequest() forwards it to CallAddRequest.
DpRank uint32
// PrefillWorkerID is copied by Clone(); presumably the prefill worker chosen
// in disaggregated mode — confirm against the code that populates it.
PrefillWorkerID string
// TokenData is taken from the routing result's TokenData; PreRequest()
// forwards it to CallAddRequest.
TokenData []int64
}
......@@ -66,6 +70,7 @@ func (s *DecodeRoutingState) Clone() plugins.StateData {
}
clone := &DecodeRoutingState{
WorkerID: s.WorkerID,
DpRank: s.DpRank,
PrefillWorkerID: s.PrefillWorkerID,
}
if s.TokenData != nil {
......@@ -157,8 +162,10 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl
}
workerIDStr := fmt.Sprintf("%d", result.WorkerID)
dpRankStr := strconv.FormatUint(uint64(result.DpRank), 10)
logger.V(logutil.DEFAULT).Info("DynDecodeScorer: decode worker selected",
"decodeWorkerID", workerIDStr,
"decodeDpRank", result.DpRank,
"isDisaggregated", isDisaggregated,
"tokenCount", len(result.TokenData))
......@@ -167,6 +174,7 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl
req.Headers = map[string]string{}
}
req.Headers[WorkerIDHeader] = workerIDStr
req.Headers[DpRankHeader] = dpRankStr
if isDisaggregated {
req.Headers[RoutingModeHeader] = "disaggregated"
......@@ -188,6 +196,7 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl
if req.RequestId != "" {
routingState := &DecodeRoutingState{
WorkerID: workerIDStr,
DpRank: result.DpRank,
TokenData: result.TokenData,
}
s.pluginState.Write(req.RequestId, plugins.StateKey(decodeStateKey), routingState)
......@@ -226,7 +235,7 @@ func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LL
return
}
if addErr := dynscorer.CallAddRequest(request.RequestId, state.TokenData, workerIDUint, 0); addErr != nil {
if addErr := dynscorer.CallAddRequest(request.RequestId, state.TokenData, workerIDUint, state.DpRank); addErr != nil {
logger.V(logutil.DEFAULT).Error(addErr, "DynDecodeScorer PreRequest: failed to add request",
"requestID", request.RequestId)
return
......@@ -235,6 +244,7 @@ func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LL
logger.V(logutil.VERBOSE).Info("DynDecodeScorer PreRequest: registered request",
"requestID", request.RequestId,
"workerID", state.WorkerID,
"dpRank", state.DpRank,
"tokenCount", len(state.TokenData))
}
......
......@@ -120,11 +120,13 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
}
prefillWorkerID := strconv.FormatUint(result.WorkerID, 10)
prefillDpRank := strconv.FormatUint(uint64(result.DpRank), 10)
logger.V(logutil.DEFAULT).Info("DynPrefillScorer: prefill worker selected",
"prefillWorkerID", prefillWorkerID,
"prefillDpRank", result.DpRank,
"tokenCount", len(result.TokenData))
// Set the prefill worker ID header directly on the request.
// Set the prefill worker ID and DP rank headers directly on the request.
// The request object is shared across all profile runs in the scheduling
// cycle, so the decode scorer (which runs in the next profile) will see it.
// This is more reliable than CycleState which may be scoped per profile.
......@@ -132,6 +134,7 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
req.Headers = map[string]string{}
}
req.Headers[PrefillWorkerIDHeader] = prefillWorkerID
req.Headers[PrefillDpRankHeader] = prefillDpRank
// Score: 1.0 for all pods. The label-filter has already restricted to prefill workers,
// and the FFI router's internal selection is authoritative.
......
......@@ -52,6 +52,8 @@ typedef struct {
bool is_disaggregated;
uint64_t prefill_worker_id;
uint64_t decode_worker_id;
uint32_t prefill_dp_rank;
uint32_t decode_dp_rank;
uint32_t *token_ids;
size_t token_count;
} CRoutingResult;
......@@ -411,6 +413,7 @@ func CallFreeRequest(requestID string) error {
// RoutingResult holds the result of a prefill or decode routing call.
type RoutingResult struct {
// WorkerID is the selected worker: the C result's prefill_worker_id for
// prefill routing, or decode_worker_id for decode routing.
WorkerID uint64
// DpRank is the data parallel rank selected for that worker, read from the
// C result's prefill_dp_rank or decode_dp_rank respectively.
DpRank uint32
// TokenData carries the request's tokens as int64 values; presumably
// converted from the C result's token_ids — the conversion is not shown here.
TokenData []int64
}
......@@ -455,9 +458,10 @@ func CallRoutePrefillRequest(requestJSON string, podsJSON string) (*RoutingResul
}
workerID := uint64(result.prefill_worker_id)
dpRank := uint32(result.prefill_dp_rank)
C.free_routing_result(&result)
return &RoutingResult{WorkerID: workerID, TokenData: tokens64}, nil
return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens64}, nil
}
// CallRouteDecodeRequest routes a request to the best decode worker.
......@@ -501,7 +505,8 @@ func CallRouteDecodeRequest(requestJSON string, podsJSON string, isDisaggregated
}
workerID := uint64(result.decode_worker_id)
dpRank := uint32(result.decode_dp_rank)
C.free_routing_result(&result)
return &RoutingResult{WorkerID: workerID, TokenData: tokens64}, nil
return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens64}, nil
}
......@@ -10,12 +10,17 @@ title: Inference Gateway (GAIE)
Integrate Dynamo with the Gateway API Inference Extension for intelligent KV-aware request routing at the gateway layer.
EPP's default kv-routing approach is not token-aware because the prompt is not tokenized. But the Dynamo plugin uses a token-aware KV algorithm. It employs the dynamo router which implements kv routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml), following the checked-in GAIE/EPP configuration layout used by this repository.
## Features
Dynamo Integration with the Inference Gateway supports Aggregated and Disaggregated Serving. A request only exercises disaggregated routing when the EPP config defines a `prefill` profile and prefill workers are available. The standalone [`epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml) currently only defines a `decode` profile, while the recipe examples use separate aggregated and disaggregated configs under `recipes/llama-3-70b/vllm/agg/gaie/` and `recipes/llama-3-70b/vllm/disagg-single-node/gaie/`. Unless `DYN_ENFORCE_DISAGG=true`, deployments without a `prefill` profile or prefill workers fall back to aggregated serving.
If you want to use LoRA deploy Dynamo without the Inference Gateway.
- EPP's default KV-routing approach is not token-aware because the prompt is not tokenized. The Dynamo plugin, by contrast, uses a token-aware KV algorithm: it employs the Dynamo router, which implements KV routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml), following the checked-in GAIE/EPP configuration layout used by this repository.
Currently, these setups are only supported with the kGateway based Inference Gateway.
- Dynamo Integration with the Inference Gateway supports Aggregated and Disaggregated Serving. A request only exercises disaggregated routing when the EPP config defines a `prefill` profile and prefill workers are available. The standalone [`epp-config-dynamo.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/deploy/inference-gateway/standalone/helm/dynamo-gaie/epp-config-dynamo.yaml) currently only defines a `decode` profile, while the recipe examples use separate aggregated and disaggregated configs under `recipes/llama-3-70b/vllm/agg/gaie/` and `recipes/llama-3-70b/vllm/disagg-single-node/gaie/`. Unless `DYN_ENFORCE_DISAGG=true`, deployments without a `prefill` profile or prefill workers fall back to aggregated serving.
- GAIE integration supports Data Parallelism.
- If you want to use LoRA, deploy Dynamo without the Inference Gateway.
- Currently, these setups are only tested with the kGateway Inference Gateway.
## Prerequisites
......
......@@ -404,6 +404,10 @@ pub struct CRoutingResult {
pub prefill_worker_id: u64,
/// Decode worker ID
pub decode_worker_id: u64,
/// Data parallel rank selected for the prefill worker
pub prefill_dp_rank: u32,
/// Data parallel rank selected for the decode worker
pub decode_dp_rank: u32,
/// Token IDs (needed for add_request callback)
pub token_ids: *mut u32,
/// Number of tokens in the request
......@@ -416,6 +420,8 @@ impl Default for CRoutingResult {
is_disaggregated: false,
prefill_worker_id: 0,
decode_worker_id: 0,
prefill_dp_rank: 0,
decode_dp_rank: 0,
token_ids: ptr::null_mut(),
token_count: 0,
}
......@@ -449,7 +455,7 @@ impl RouterHandles {
lora_name: Option<String>,
priority_jump: f64,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<u64, QueryRouterResult> {
) -> Result<(u64, u32), QueryRouterResult> {
if let Some(ref ids) = allowed_worker_ids {
self.prefill_router.register_workers(ids);
}
......@@ -464,7 +470,6 @@ impl RouterHandles {
allowed_worker_ids,
)
.await
.map(|(worker_id, _dp_rank)| worker_id)
.map_err(|e| {
tracing::error!(error = ?e, "Prefill query failed");
QueryRouterResult::ErrQueryFailed
......@@ -1203,25 +1208,27 @@ pub unsafe extern "C" fn route_prefill_request(
let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };
let result = handles.runtime.secondary().block_on(async {
let prefill_worker_id = handles
let (prefill_worker_id, prefill_dp_rank) = handles
.query_prefill_worker(&tokens, None, false, None, 0.0, allowed_worker_ids)
.await?;
tracing::info!(
prefill_worker_id = prefill_worker_id,
prefill_dp_rank = prefill_dp_rank,
token_count = tokens.len(),
"Routed prefill request"
);
Ok(prefill_worker_id)
Ok((prefill_worker_id, prefill_dp_rank))
});
match result {
Ok(prefill_worker_id) => {
Ok((prefill_worker_id, prefill_dp_rank)) => {
let out = unsafe { &mut *out_result };
*out = CRoutingResult::default();
out.is_disaggregated = true;
out.prefill_worker_id = prefill_worker_id;
out.prefill_dp_rank = prefill_dp_rank;
write_tokens_to_result(&tokens, out);
QueryRouterResult::Ok
}
......@@ -1290,6 +1297,7 @@ pub unsafe extern "C" fn route_decode_request(
*out = CRoutingResult::default();
out.is_disaggregated = is_disaggregated;
out.decode_worker_id = decode_worker.worker_id;
out.decode_dp_rank = decode_worker.dp_rank;
write_tokens_to_result(&tokens, out);
QueryRouterResult::Ok
}
......
......@@ -40,7 +40,11 @@ impl PrefillRouter {
// Worker selection
let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0);
let dp_rank = req
.routing
.as_ref()
.and_then(|r| r.prefill_dp_rank.or(r.dp_rank))
.unwrap_or(0);
tracing::debug!(
worker_id = id,
dp_rank = dp_rank,
......
......@@ -321,7 +321,8 @@ impl OpenAIPreprocessor {
backend_instance_id: nvext.backend_instance_id,
prefill_worker_id: nvext.prefill_worker_id,
decode_worker_id: nvext.decode_worker_id,
dp_rank: None, // dp_rank is set later in the pipeline
dp_rank: nvext.dp_rank,
prefill_dp_rank: nvext.prefill_dp_rank,
expected_output_tokens: hints.and_then(|h| h.osl),
priority_jump: hints.and_then(|h| {
h.priority
......
......@@ -34,10 +34,14 @@ pub struct RoutingHints {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
/// Data parallel rank for the request
/// Data parallel rank for the decode worker
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>,
/// Data parallel rank for the prefill worker in disaggregated serving
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
/// Expected number of output tokens for this request.
/// Used as a hint for routing decisions to estimate resource requirements.
#[serde(default, skip_serializing_if = "Option::is_none")]
......
......@@ -11,12 +11,16 @@ pub use crate::protocols::common::timing::TimingInfo;
pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id";
pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id";
pub const HEADER_DP_RANK: &str = "x-dp-rank";
pub const HEADER_PREFILL_DP_RANK: &str = "x-prefill-dp-rank";
/// Apply routing overrides from HTTP headers to nvext.
///
/// Header mappings:
/// - `x-worker-instance-id` -> `backend_instance_id` and `decode_worker_id`
/// - `x-prefill-instance-id` -> `prefill_worker_id`
/// - `x-dp-rank` -> `dp_rank` (decode worker's DP rank)
/// - `x-prefill-dp-rank` -> `prefill_dp_rank`
///
/// Headers take priority over existing nvext values when present.
/// If no headers are present, returns the original nvext unchanged.
......@@ -31,7 +35,18 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
if worker_id.is_none() && prefill_id.is_none() {
let dp_rank = headers
.get(HEADER_DP_RANK)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u32>().ok());
let prefill_dp_rank = headers
.get(HEADER_PREFILL_DP_RANK)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u32>().ok());
if worker_id.is_none() && prefill_id.is_none() && dp_rank.is_none() && prefill_dp_rank.is_none()
{
return nvext;
}
......@@ -43,6 +58,12 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap)
if let Some(id) = prefill_id {
ext.prefill_worker_id = Some(id);
}
if let Some(rank) = dp_rank {
ext.dp_rank = Some(rank);
}
if let Some(rank) = prefill_dp_rank {
ext.prefill_dp_rank = Some(rank);
}
Some(ext)
}
......@@ -157,6 +178,19 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
/// Data parallel rank for the decode worker, set by the EPP via the
/// `x-dp-rank` header. When a worker hosts multiple DP engines,
/// this steers the request to the correct engine instance.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>,
/// Data parallel rank for the prefill worker in disaggregated serving,
/// set by the EPP via the `x-prefill-dp-rank` header.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_dp_rank: Option<u32>,
/// Agent-provided hints for request handling.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
......@@ -288,29 +322,22 @@ mod tests {
assert!(nv_ext.validate().is_ok());
}
// Test apply_header_routing_overrides - worker header present, prefill header absent
#[test]
fn test_apply_header_routing_overrides() {
use axum::http::HeaderMap;
// Only HEADER_WORKER_INSTANCE_ID is in the header
let mut headers = HeaderMap::new();
headers.insert(HEADER_WORKER_INSTANCE_ID, "123".parse().unwrap());
// Note: HEADER_PREFILL_INSTANCE_ID is NOT in the header
let nvext = NvExt::builder()
.backend_instance_id(999)
.decode_worker_id(888)
.prefill_worker_id(777)
.build()
.unwrap();
headers.insert(HEADER_PREFILL_INSTANCE_ID, "456".parse().unwrap());
headers.insert(HEADER_DP_RANK, "3".parse().unwrap());
headers.insert(HEADER_PREFILL_DP_RANK, "5".parse().unwrap());
let result = apply_header_routing_overrides(Some(nvext), &headers).unwrap();
let result = apply_header_routing_overrides(None, &headers).unwrap();
// Header should override backend_instance_id and decode_worker_id
assert_eq!(result.backend_instance_id, Some(123));
assert_eq!(result.decode_worker_id, Some(123));
// prefill_worker_id should remain from original nvext (not overwritten by header)
assert_eq!(result.prefill_worker_id, Some(777));
assert_eq!(result.prefill_worker_id, Some(456));
assert_eq!(result.dp_rank, Some(3));
assert_eq!(result.prefill_dp_rank, Some(5));
}
}
......@@ -2161,6 +2161,8 @@ def _test_disagg_direct_mode(
headers = {
"x-worker-instance-id": str(target_decode),
"x-prefill-instance-id": str(target_prefill),
"x-dp-rank": "0",
"x-prefill-dp-rank": "0",
}
async with aiohttp.ClientSession() as session:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment