Unverified Commit 15b49818 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: support disag serving in GAIE [DEP-659] (#4756)


Signed-off-by: Anna Tchernych <atchernych@nvidia.com>
parent 7a3b15e6
......@@ -161,10 +161,10 @@ index 670d922..0cf04cb 100644
}
diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
new file mode 100644
index 0000000..cd9a0b5
index 0000000..1c8f979
--- /dev/null
+++ b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
@@ -0,0 +1,119 @@
@@ -0,0 +1,171 @@
+package dynamo_inject_workerid
+
+import (
......@@ -182,6 +182,7 @@ index 0000000..cd9a0b5
+ typeString = "dynamo-inject-workerid"
+ pluginName = "dynamo-inject-workerid"
+ WorkerIDHeader = "x-worker-instance-id"
+ PrefillWorkerIDHeader = "x-prefiller-host-port"
+ tokenDataAnnotationKey = "dynamo/token-data"
+)
+
......@@ -222,11 +223,18 @@ index 0000000..cd9a0b5
+ if req.Headers == nil {
+ req.Headers = map[string]string{}
+ }
+
+ // Handle worker instance ID
+ wid := strings.TrimSpace(req.Headers[WorkerIDHeader])
+ if wid == "" {
+ return
+ }
+ if wid != "" {
+ req.Headers[WorkerIDHeader] = wid
+ }
+
+ // Handle prefill worker ID
+ prefillWid := strings.TrimSpace(req.Headers[PrefillWorkerIDHeader])
+ if prefillWid != "" {
+ req.Headers[PrefillWorkerIDHeader] = prefillWid
+ }
+}
+
+func (p *InjectWorkerIDPreRequest) MutateRequestBody(
......@@ -248,14 +256,28 @@ index 0000000..cd9a0b5
+ return
+ }
+
+ prefillWid := strings.TrimSpace(req.Headers[PrefillWorkerIDHeader])
+
+ nvext, _ := body["nvext"].(map[string]any)
+ if nvext == nil {
+ nvext = map[string]any{}
+ body["nvext"] = nvext
+ }
+
+ if prefillWid != "" && prefillWid != wid {
+ // Disaggregated mode: use prefill_worker_id and decode_worker_id
+ if prefillWidUint, err := strconv.ParseUint(prefillWid, 10, 64); err == nil {
+ nvext["prefill_worker_id"] = prefillWidUint
+ }
+ if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil {
+ nvext["decode_worker_id"] = widUint
+ }
+ } else {
+ // Aggregated mode (empty prefill or prefill == decode): use backend_instance_id
+ if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil {
+ nvext["backend_instance_id"] = widUint
+ }
+ }
+
+ if tokens, ok := req.Annotations[tokenDataAnnotationKey]; ok {
+ switch v := tokens.(type) {
......@@ -283,6 +305,36 @@ index 0000000..cd9a0b5
+ }
+ }
+ }
+
+ // Remove query_instance_id from nvext.annotations if present
+ if annotations, ok := nvext["annotations"]; ok {
+ switch annList := annotations.(type) {
+ case []string:
+ filtered := make([]string, 0, len(annList))
+ for _, ann := range annList {
+ if ann != "query_instance_id" {
+ filtered = append(filtered, ann)
+ }
+ }
+ if len(filtered) == 0 {
+ delete(nvext, "annotations")
+ } else {
+ nvext["annotations"] = filtered
+ }
+ case []any:
+ filtered := make([]any, 0, len(annList))
+ for _, ann := range annList {
+ if str, ok := ann.(string); !ok || str != "query_instance_id" {
+ filtered = append(filtered, ann)
+ }
+ }
+ if len(filtered) == 0 {
+ delete(nvext, "annotations")
+ } else {
+ nvext["annotations"] = filtered
+ }
+ }
+ }
+}
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml
new file mode 100644
......@@ -313,10 +365,10 @@ index 0000000..b689c00
+ - pluginRef: picker
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
new file mode 100644
index 0000000..bc29c0a
index 0000000..75f30e9
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
@@ -0,0 +1,424 @@
@@ -0,0 +1,446 @@
+package dynamo_kv_scorer
+
+/*
......@@ -367,13 +419,15 @@ index 0000000..bc29c0a
+ double router_temperature,
+ bool use_kv_events,
+ bool router_replica_sync,
+ bool enforce_disagg,
+ WorkerSelectionPipeline **pipeline_out);
+
+dynamo_llm_result_t dynamo_destroy_worker_selection_pipeline(WorkerSelectionPipeline *pipeline);
+
+dynamo_llm_result_t dynamo_query_worker_selection_and_annotate(WorkerSelectionPipeline *pipeline,
+ const char *request_json_c_str,
+ int64_t *worker_instance_id_out,
+ int64_t *decode_worker_id_out,
+ int64_t *prefill_worker_id_out,
+ uint32_t **token_ids_out,
+ size_t *token_count_out,
+ char **annotated_request_json_out);
......@@ -404,7 +458,9 @@ index 0000000..bc29c0a
+ PluginName = "dynamo-kv-scorer"
+ KVAwareScorerType = "kv-aware-scorer"
+ StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id")
+ StateKeyPrefillWorkerID = schedtypes.StateKey("dynamo/prefill-worker-id")
+ WorkerIDHeader = "x-worker-instance-id"
+ PrefillWorkerIDHeader = "x-prefiller-host-port"
+ tokenDataAnnotationKey = "dynamo/token-data"
+)
+
......@@ -471,6 +527,7 @@ index 0000000..bc29c0a
+ ffiRouterTemperature float64
+ ffiKvBlockSize uint32
+ ffiWorkerID int64
+ ffiEnforceDisagg bool
+
+ runtimeInitialized bool
+
......@@ -484,6 +541,7 @@ index 0000000..bc29c0a
+ ffiComponent = getEnvOrDefault("DYNAMO_COMPONENT", "backend")
+ ffiModel = getEnvOrDefault("DYNAMO_MODEL", "Qwen/Qwen3-0.6B")
+ ffiWorkerID = getEnvInt64OrDefault("DYNAMO_WORKER_ID", 1)
+ ffiEnforceDisagg = getEnvBoolOrDefault("DYNAMO_ENFORCE_DISAGG", true) // TODO default to false
+
+ ffiOverlapScoreWeight = getEnvFloatOrDefault("DYNAMO_OVERLAP_SCORE_WEIGHT", -1.0)
+ ffiRouterTemperature = getEnvFloatOrDefault("DYNAMO_ROUTER_TEMPERATURE", -1.0)
......@@ -575,6 +633,7 @@ index 0000000..bc29c0a
+ C.double(ffiRouterTemperature),
+ C.bool(getEnvBoolOrDefault("DYNAMO_USE_KV_EVENTS", true)),
+ C.bool(getEnvBoolOrDefault("DYNAMO_ROUTER_REPLICA_SYNC", true)),
+ C.bool(ffiEnforceDisagg),
+ &pipeline,
+ )
+ if rc != C.DYNAMO_OK {
......@@ -595,13 +654,14 @@ index 0000000..bc29c0a
+) map[schedtypes.Pod]float64 {
+ logger := log.FromContext(ctx)
+
+ workerID, tokenData, err := k.callDynamoRouter(ctx, req)
+ workerID, prefillWorkerID, tokenData, err := k.callDynamoRouter(ctx, req)
+ if err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "Dynamo call failed; proceeding without worker id")
+ } else if workerID != "" {
+ logger.V(logutil.DEFAULT).Info(
+ "Dynamo router selected worker",
+ "workerID", workerID,
+ "prefillWorkerID", prefillWorkerID,
+ "tokenDataCount", len(tokenData),
+ "tokenData", tokenData,
+ )
......@@ -610,6 +670,13 @@ index 0000000..bc29c0a
+ req.Headers = map[string]string{}
+ }
+ req.Headers[WorkerIDHeader] = workerID
+
+ // Set prefill worker ID if present
+ if prefillWorkerID != "" {
+ cycle.Write(StateKeyPrefillWorkerID, stateString(prefillWorkerID))
+ req.Headers[PrefillWorkerIDHeader] = prefillWorkerID
+ }
+
+ if len(tokenData) > 0 {
+ if req.Annotations == nil {
+ req.Annotations = map[string]any{}
......@@ -632,15 +699,15 @@ index 0000000..bc29c0a
+func (k *KVAwareScorer) callDynamoRouter(
+ ctx context.Context,
+ req *schedtypes.LLMRequest,
+) (string, []int64, error) {
+) (workerID string, prefillWorkerID string, tokenData []int64, err error) {
+ logger := log.FromContext(ctx)
+
+ if err := initFFI(); err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "FFI init failed")
+ return "", nil, err
+ return "", "", nil, err
+ }
+ if !runtimeInitialized {
+ return "", nil, fmt.Errorf("dynamo runtime not initialized")
+ return "", "", nil, fmt.Errorf("dynamo runtime not initialized")
+ }
+
+ pipelineMutex.RLock()
......@@ -648,21 +715,22 @@ index 0000000..bc29c0a
+ pipelineMutex.RUnlock()
+
+ if currentPipeline == nil {
+ return "", nil, fmt.Errorf("dynamo worker selection pipeline not created")
+ return "", "", nil, fmt.Errorf("dynamo worker selection pipeline not created")
+ }
+
+ // Build OpenAI-compatible JSON request
+ requestBody := buildOpenAIRequest(req)
+ requestJSON, err := json.Marshal(requestBody)
+ if err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "Failed to marshal OpenAI request")
+ return "", nil, fmt.Errorf("marshal OpenAI request: %w", err)
+ requestJSON, jsonErr := json.Marshal(requestBody)
+ if jsonErr != nil {
+ logger.V(logutil.DEFAULT).Error(jsonErr, "Failed to marshal OpenAI request")
+ return "", "", nil, fmt.Errorf("marshal OpenAI request: %w", jsonErr)
+ }
+ cRequestJSON := C.CString(string(requestJSON))
+ defer C.free(unsafe.Pointer(cRequestJSON))
+
+ // Output variables
+ var cWorkerID C.int64_t
+ var cDecodeWorkerID C.int64_t
+ var cPrefillWorkerID C.int64_t
+ var cTokens *C.uint32_t
+ var cTokenCount C.size_t
+ var cAnnotatedJSON *C.char
......@@ -671,13 +739,14 @@ index 0000000..bc29c0a
+ rc := C.dynamo_query_worker_selection_and_annotate(
+ currentPipeline,
+ cRequestJSON,
+ &cWorkerID,
+ &cDecodeWorkerID,
+ &cPrefillWorkerID,
+ &cTokens,
+ &cTokenCount,
+ &cAnnotatedJSON,
+ )
+ if rc != C.DYNAMO_OK {
+ return "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed")
+ return "", "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed")
+ }
+
+ // Copy tokens into Go memory and free C memory
......@@ -692,11 +761,16 @@ index 0000000..bc29c0a
+ }
+ C.dynamo_free_worker_selection_result(cTokens, cTokenCount, cAnnotatedJSON)
+
+ workerID := fmt.Sprintf("%d", int64(cWorkerID))
+ workerIDStr := fmt.Sprintf("%d", int64(cDecodeWorkerID))
+ prefillWorkerIDStr := ""
+ // Rust returns -1 for prefill_worker_id when not in disaggregated mode
+ if int64(cPrefillWorkerID) >= 0 {
+ prefillWorkerIDStr = fmt.Sprintf("%d", int64(cPrefillWorkerID))
+ }
+ logger.V(logutil.DEFAULT).Info("Worker selection completed",
+ "workerID", workerID, "tokenCount", count)
+ "workerID", workerIDStr, "prefillWorkerID", prefillWorkerIDStr, "tokenCount", count)
+
+ return workerID, tokens64, nil
+ return workerIDStr, prefillWorkerIDStr, tokens64, nil
+}
+
+func buildOpenAIRequest(req *schedtypes.LLMRequest) map[string]any {
......
This diff is collapsed.
......@@ -97,6 +97,46 @@ pub fn router_endpoint_id(namespace: String) -> EndpointId {
}
}
/// Specifies the type of worker being queried when using the `query_instance_id` annotation.
/// This tells the router which worker pool to select from and what type of operation is intended.
///
/// Recognized values (the text after the colon in "query_instance_id:<type>"):
/// - "prefill" → select a prefill worker (disaggregated serving)
/// - "decode" → select a decode worker (disaggregated serving)
///
/// The serde representation of each variant is its lowercase name.
///
/// Note: Empty value ("query_instance_id:") is handled by PrefillRouter for disagg orchestration
/// and has no corresponding variant here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum QueryInstanceType {
/// Query for a prefill worker (disaggregated serving)
Prefill,
/// Query for a decode worker (disaggregated serving)
Decode,
}
impl std::fmt::Display for QueryInstanceType {
    /// Writes the lowercase name of the variant: `prefill` or `decode`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            QueryInstanceType::Prefill => "prefill",
            QueryInstanceType::Decode => "decode",
        };
        f.write_str(label)
    }
}
impl std::str::FromStr for QueryInstanceType {
    type Err = String;

    /// Parses a worker-pool name, ignoring letter case.
    ///
    /// Accepts "prefill" and "decode"; any other input (including the empty
    /// string) yields an `Err` carrying a message that echoes the input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        if normalized == "prefill" {
            return Ok(QueryInstanceType::Prefill);
        }
        if normalized == "decode" {
            return Ok(QueryInstanceType::Decode);
        }
        Err(format!(
            "Invalid QueryInstanceType: '{s}'. Expected 'prefill' or 'decode'"
        ))
    }
}
/// Creates a DiscoveryQuery for the KV router in the given namespace.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
DiscoveryQuery::Endpoint {
......@@ -731,13 +771,34 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Extract context ID for request tracking
let context_id = request.context().id().to_string();
// Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id");
// Check if this is a query_instance_id request and parse its type
// Format: "query_instance_id:type" where type is "prefill", "decode", or "" (empty for aggregated)
// Empty value ("query_instance_id:") means GAIE Aggregated mode - return same worker as both prefill and decode
let query_instance_annotation = request.get_annotation_value("query_instance_id");
let is_gaie_agg_query = query_instance_annotation
.as_ref()
.is_some_and(|s| s.is_empty());
let query_instance_type: Option<QueryInstanceType> =
if let Some(type_str) = &query_instance_annotation {
match type_str.parse::<QueryInstanceType>() {
Ok(t) => Some(t),
Err(_) if type_str.is_empty() => {
// Empty value is valid for aggregated mode, not a warning
None
}
Err(e) => {
tracing::warn!("Invalid query_instance_id type '{type_str}': {e}");
None
}
}
} else {
None
};
let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
// If instance_id is set, use it and compute actual overlap
let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_id {
if query_instance_type.is_some() {
tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
);
......@@ -761,33 +822,80 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
(id, dp_rank, overlap_blocks)
} else {
// Otherwise, find the best match
// Don't update states if this is a query-only request (any query_instance_id annotation)
let should_update_states = query_instance_annotation.is_none();
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
Some(&context_id),
&request.token_ids,
request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id
should_update_states,
)
.await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount)
};
// if request has the annotation "query_instance_id",
// then the request will not be routed to the worker,
// and instead the worker_instance_id will be returned.
// If request has a query_instance_id annotation, return worker selection info
// without routing to the actual worker. Returns LLMEngineOutput with disaggregated_params
// containing worker_id info, same structure as normal execution for uniform extraction.
let stream_context = request.context().clone();
if query_instance_id {
let instance_id_str = instance_id.to_string();
let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
// Return the tokens in nvext.token_data format
let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
// Handle query-only requests (GAIE Stage 1)
if query_instance_type.is_some() || is_gaie_agg_query {
let worker_id_info = if is_gaie_agg_query {
// GAIE Aggregated mode: same worker serves both prefill and decode
tracing::trace!(
query_type = "aggregated",
worker_id = instance_id,
"Returning aggregated worker selection (same worker for prefill and decode)"
);
WorkerIdInfo {
prefill_worker_id: Some(instance_id),
decode_worker_id: Some(instance_id),
}
} else {
match query_instance_type.unwrap() {
QueryInstanceType::Prefill => {
tracing::trace!(
query_type = "prefill",
prefill_worker_id = instance_id,
"Returning prefill worker selection"
);
WorkerIdInfo {
prefill_worker_id: Some(instance_id),
decode_worker_id: None,
}
}
QueryInstanceType::Decode => {
// Get prefill_worker_id from annotation (set by caller after prefill selection)
let prefill_worker_id = request
.get_annotation_value("prefill_worker_id")
.and_then(|s| s.parse::<u64>().ok());
tracing::trace!(
"Tokens requested in the response through the query_instance_id annotation: {:?}",
response_tokens
query_type = "decode",
prefill_worker_id = ?prefill_worker_id,
decode_worker_id = instance_id,
"Returning decode worker selection"
);
let stream = stream::iter(vec![response, response_tokens]);
WorkerIdInfo {
prefill_worker_id,
decode_worker_id: Some(instance_id),
}
}
}
};
// Return as LLMEngineOutput with disaggregated_params (same structure as normal execution)
let output = LLMEngineOutput {
disaggregated_params: Some(json!({
"worker_id": worker_id_info,
"token_ids": request.token_ids
})),
..Default::default()
};
let response = Annotated::from_data(output);
let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
let (mut backend_input, context) = request.into_parts();
......
......@@ -20,9 +20,10 @@ use dynamo_runtime::{
use crate::{
discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
kv_router::{KvPushRouter, KvRouterConfig, QueryInstanceType, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::openai::nvext::WorkerIdInfo,
};
/// Errors that can occur during prefill routing
......@@ -67,6 +68,11 @@ impl InnerPrefillRouter {
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
///
/// Supports regular Dynamo and GAIE integrated mode via query_instance_id state machine:
/// - GAIE Stage 1: query_instance_id transitions "" -> "prefill" -> "decode", returns only worker IDs
/// - GAIE Stage 2: target_prefill_worker_id/target_decode_worker_id are set, full execution with specified workers
/// - Non-GAIE: like GAIE Stage 2 but the worker ids have to be determined.
pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>,
cancel_token: CancellationToken,
......@@ -196,10 +202,13 @@ impl PrefillRouter {
rand::rng().random()
}
/// Query best worker upfront, build bootstrap_info, and spawn prefill in background
/// Build bootstrap_info for disaggregated serving
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker.
async fn build_bootstrap_info(
&self,
req: &PreprocessedRequest,
preselected_worker: Option<u64>,
) -> Option<(u64, u32, BootstrapInfo)> {
let prefill_router = self.prefill_router.get()?;
......@@ -209,14 +218,24 @@ impl PrefillRouter {
InnerPrefillRouter::SimpleRouter(_) => return None,
};
// Query best worker without routing
let (worker_id, dp_rank) = match kv_router
// Use pre-selected worker (GAIE Stage 2) or query for best worker
let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
let dp_rank = req.dp_rank.unwrap_or(0);
tracing::debug!(
worker_id = id,
dp_rank = dp_rank,
"Using pre-selected prefill worker for bootstrap"
);
(id, dp_rank)
} else {
match kv_router
.chooser
.find_best_match(None, &req.token_ids, None, false)
.await
{
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
Err(_) => return None,
}
};
// Look up bootstrap endpoint from discovery
......@@ -343,6 +362,56 @@ impl PrefillRouter {
}
}
/// GAIE helper functions for preparing prefill and decode requests
impl PrefillRouter {
    /// Prepare the prefill request for GAIE flows.
    ///
    /// - Stage 1 (`is_gaie_stage1` is true): every annotation starting with
    ///   `query_instance_id` is removed and a single `query_instance_id:prefill`
    ///   annotation is appended.
    /// - Stage 2 (`target_prefill_worker_id` is set): the target is copied into
    ///   `backend_instance_id` so the request is pinned to that worker.
    /// - Otherwise the request is left unchanged.
    fn prepare_prefill_for_gaie(prefill_req: &mut PreprocessedRequest, is_gaie_stage1: bool) {
        if is_gaie_stage1 {
            // GAIE Stage 1: Set query_instance_id to "prefill" for prefill worker selection
            prefill_req
                .annotations
                .retain(|a| !a.starts_with("query_instance_id"));
            prefill_req
                .annotations
                .push(format!("query_instance_id:{}", QueryInstanceType::Prefill));
        } else if let Some(prefill_worker_id) = prefill_req.target_prefill_worker_id {
            // GAIE Stage 2: Route to pre-selected prefill worker from the stage 1
            tracing::debug!(
                target_prefill_worker_id = prefill_worker_id,
                "GAIE Stage 2: Routing prefill to pre-selected worker"
            );
            prefill_req.backend_instance_id = Some(prefill_worker_id);
        }
    }

    /// Prepare the decode request for GAIE Stage 1.
    ///
    /// Reads `worker_id.prefill_worker_id` from the prefill result's
    /// `disaggregated_params`. When present, the decode request's
    /// `query_instance_id*` annotations are replaced by `query_instance_id:decode`
    /// and a `prefill_worker_id:<id>` annotation is appended. When the prefill
    /// result carries no parsable prefill worker id, the request is left untouched.
    fn prepare_decode_for_gaie_stage1(
        decode_req: &mut PreprocessedRequest,
        prefill_result: &PrefillResult,
    ) {
        let prefill_worker_id = prefill_result
            .disaggregated_params
            .get("worker_id")
            .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
            .and_then(|info| info.prefill_worker_id);

        if let Some(worker_id) = prefill_worker_id {
            // Annotation lookups return the first match, so drop any pre-existing
            // `prefill_worker_id:` entry as well; otherwise a stale or
            // caller-supplied value would shadow the worker selected just now.
            decode_req.annotations.retain(|a| {
                !a.starts_with("query_instance_id") && !a.starts_with("prefill_worker_id:")
            });
            decode_req
                .annotations
                .push(format!("query_instance_id:{}", QueryInstanceType::Decode));
            decode_req
                .annotations
                .push(format!("prefill_worker_id:{worker_id}"));
        }
    }
}
impl Drop for PrefillRouter {
fn drop(&mut self) {
tracing::debug!("Dropping PrefillRouter, cancelling background activation task");
......@@ -369,6 +438,12 @@ impl
let request_id = context.id().to_string();
let engine_ctx = context.context();
// GAIE Stage 1: the presence of the empty query_instance_id signals query-only mode
// State machine: "" -> "prefill" -> "decode" (disagg) OR "" -> aggregated worker (agg fallback)
let is_gaie_stage1 = req
.get_annotation_value("query_instance_id")
.is_some_and(|s| s.is_empty());
// Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens;
......@@ -376,9 +451,16 @@ impl
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
// Try build_bootstrap_info optimization
let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) =
self.build_bootstrap_info(&prefill_req).await
// Prepare prefill request for GAIE flows (Stage 1 or Stage 2)
Self::prepare_prefill_for_gaie(&mut prefill_req, is_gaie_stage1);
// Try build_bootstrap_info optimization (skip for GAIE Stage 1 which needs query-only flow)
// For GAIE Stage 2, use target_prefill_worker_id if provided
let preselected_worker = prefill_req.target_prefill_worker_id;
let prefill_result = if !is_gaie_stage1 {
if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
{
let bootstrap_room = bootstrap_info.bootstrap_room;
......@@ -408,6 +490,15 @@ impl
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
}
} else {
// GAIE Stage 1: Use original path (no bootstrap optimization)
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
......@@ -429,8 +520,13 @@ impl
let mut decode_req = req;
// Update request with prefill result if available (only in original path)
if let Some(prefill_result) = maybe_prefill_result {
// Update request with prefill result
if is_gaie_stage1 {
if let Some(ref prefill_result) = maybe_prefill_result {
Self::prepare_decode_for_gaie_stage1(&mut decode_req, prefill_result);
}
} else if let Some(prefill_result) = maybe_prefill_result {
// Normal or GAIE Stage 2: Set prefill_result for decode
decode_req.prefill_result = Some(prefill_result);
}
......@@ -449,6 +545,15 @@ impl
..existing_override.unwrap_or_default()
});
// GAIE Stage 2: Route to pre-selected decode worker if specified
if let Some(decode_worker_id) = decode_req.target_decode_worker_id {
decode_req.backend_instance_id = Some(decode_worker_id);
tracing::debug!(
decode_worker_id = decode_worker_id,
"GAIE Stage 2: Routing decode to pre-selected worker"
);
}
// Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await
......
......@@ -238,10 +238,13 @@ impl OpenAIPreprocessor {
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id and extra_fields from nvext if present
// Extract backend_instance_id, extra_fields, and worker IDs from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
builder.extra_fields(nvext.extra_fields.clone());
// GAIE Stage 2: Extract targeted worker IDs for disaggregated serving
builder.target_prefill_worker_id(nvext.prefill_worker_id);
builder.target_decode_worker_id(nvext.decode_worker_id);
}
Ok(builder)
......
......@@ -74,7 +74,7 @@ pub struct BackendOutput {
///
/// This is the minimal raw output from the LLM engine. The Backend may then apply multiple
/// levels of post-processing before the BackendOutput is returned
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct LLMEngineOutput {
// new token_ids
pub token_ids: Vec<TokenIdType>,
......
......@@ -118,12 +118,34 @@ pub struct PreprocessedRequest {
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_fields: Option<Vec<String>>,
/// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the prefill request will be routed to this specific worker.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub target_prefill_worker_id: Option<u64>,
/// Targeted decode worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the decode request will be routed to this specific worker.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub target_decode_worker_id: Option<u64>,
}
impl PreprocessedRequest {
    /// Returns `true` when `annotation` appears verbatim (exact string match)
    /// among this request's annotations.
    pub fn has_annotation(&self, annotation: &str) -> bool {
        // Compare as `&str` instead of allocating a temporary `String` per lookup.
        self.annotations.iter().any(|a| a.as_str() == annotation)
    }

    /// Get the value of the first annotation written in the format "key:value".
    ///
    /// Returns `None` when no annotation starts with `key:`. An annotation with an
    /// empty value ("key:") yields `Some("")`, not `None`; callers use this to
    /// tell "present but empty" apart from "absent".
    pub fn get_annotation_value(&self, key: &str) -> Option<String> {
        self.annotations.iter().find_map(|a| {
            // Strip `key` and then the ':' separator; both must match for a hit.
            a.strip_prefix(key)
                .and_then(|rest| rest.strip_prefix(':'))
                .map(str::to_string)
        })
    }
}
impl PreprocessedRequest {
......
......@@ -400,13 +400,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
tracker.record_first_token();
}
// Extract worker_id from disaggregated_params
// Extract worker_id and token_ids from disaggregated_params
let worker_id_info = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
let token_ids = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| {
......@@ -417,11 +423,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None
};
// Inject nvext if we have worker_id or timing
if worker_id_info.is_some() || timing_info.is_some() {
// Inject nvext if we have worker_id, token_ids, or timing
if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(),
timing: timing_info,
token_ids: token_ids.clone(),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
......@@ -433,6 +440,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
info.decode_worker_id
);
}
if let Some(ref tokens) = token_ids {
tracing::debug!(
"Injected token_ids into chat completion nvext: {} tokens",
tokens.len()
);
}
}
}
......
......@@ -295,13 +295,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
tracker.record_first_token();
}
// Extract worker_id from disaggregated_params
// Extract worker_id and token_ids from disaggregated_params
let worker_id_info = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
let token_ids = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| {
......@@ -312,11 +318,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
None
};
// Inject nvext if we have worker_id or timing
if worker_id_info.is_some() || timing_info.is_some() {
// Inject nvext if we have worker_id, token_ids, or timing
if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(),
timing: timing_info,
token_ids: token_ids.clone(),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
......@@ -328,6 +335,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
info.decode_worker_id
);
}
if let Some(ref tokens) = token_ids {
tracing::debug!(
"Injected token_ids into completions nvext: {} tokens",
tokens.len()
);
}
}
}
......
......@@ -35,6 +35,11 @@ pub struct NvExtResponse {
/// Populated when client requests `extra_fields: ["timing"]`
#[serde(skip_serializing_if = "Option::is_none")]
pub timing: Option<TimingInfo>,
/// Token IDs for GAIE Stage 1 query-only mode
/// Contains the tokenized prompt for reuse in Stage 2
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids: Option<Vec<u32>>,
}
/// NVIDIA LLM extensions to the OpenAI API
......@@ -87,6 +92,18 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>,
/// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the request will be routed to this specific prefill worker.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
/// Targeted decode worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the request will be routed to this specific decode worker.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
}
impl Default for NvExt {
......@@ -133,6 +150,8 @@ mod tests {
assert_eq!(nv_ext.token_data, None);
assert_eq!(nv_ext.max_thinking_tokens, None);
assert_eq!(nv_ext.extra_fields, None);
assert_eq!(nv_ext.prefill_worker_id, None);
assert_eq!(nv_ext.decode_worker_id, None);
}
// Test valid builder configurations
......@@ -157,4 +176,18 @@ mod tests {
// Validate the built struct
assert!(nv_ext.validate().is_ok());
}
// Test GAIE Stage 2 disaggregated worker IDs
#[test]
fn test_nv_ext_disagg_worker_ids() {
    // Distinct IDs so a swapped prefill/decode assignment would be caught.
    let prefill_id: u64 = 100;
    let decode_id: u64 = 200;

    let built = NvExt::builder()
        .prefill_worker_id(prefill_id)
        .decode_worker_id(decode_id)
        .build()
        .unwrap();

    assert_eq!(built.prefill_worker_id, Some(prefill_id));
    assert_eq!(built.decode_worker_id, Some(decode_id));
    assert!(built.validate().is_ok());
}
}
......@@ -1026,9 +1026,10 @@ def _test_router_query_instance_id(
asyncio.run(send_request_with_retry(url, test_payload))
# Test payload with query_instance_id annotation
# Format: "query_instance_id:" (colon with empty value) for GAIE aggregated mode
annotated_payload = {
**test_payload,
"nvext": {"annotations": ["query_instance_id"]},
"nvext": {"annotations": ["query_instance_id:"]},
}
async def test_annotation_response():
......@@ -1053,100 +1054,80 @@ def _test_router_query_instance_id(
f"Full SSE response ({len(full_response)} bytes):\n{full_response}"
)
# Parse and validate the response structure
events = []
# Parse the SSE response to extract the first chunk with nvext data
# New format: nvext contains worker_id and token_ids
sse_parts = full_response.split("\n\n")
worker_id_info = None
token_list = None
for part in sse_parts:
part = part.strip()
if not part:
if not part or not part.startswith("data:"):
continue
if part.startswith("event:"):
lines = part.split("\n")
event_line = next(
(line for line in lines if line.startswith("event:")),
None,
)
data_line = next(
(
line
for line in lines
if line.startswith("data:") or line.startswith(":")
),
None,
)
if event_line and data_line:
event_type = event_line.split(":", 1)[1].strip()
if data_line.startswith("data:"):
data_value = data_line.split(":", 1)[1].strip()
else:
data_value = data_line.split(":", 1)[1].strip()
events.append((event_type, data_value))
elif part.startswith("data:"):
data_value = part.split(":", 1)[1].strip()
data_str = part.split("data:", 1)[1].strip()
if data_str == "[DONE]":
continue
logger.info(f"Parsed events: {events}")
try:
chunk = json.loads(data_str)
logger.info(f"Parsed chunk: {json.dumps(chunk, indent=2)}")
# Validate worker_instance_id event
worker_event = next(
(e for e in events if e[0] == "worker_instance_id"), None
# Extract nvext data containing worker_id and token_ids
nvext = chunk.get("nvext", {})
if nvext:
if "worker_id" in nvext:
worker_id_info = nvext["worker_id"]
logger.info(
f"Found worker_id info: {worker_id_info}"
)
assert (
worker_event is not None
), f"Missing worker_instance_id event in: {events}"
# Validate token_data event
token_event = next(
(e for e in events if e[0] == "token_data"), None
if "token_ids" in nvext:
token_list = nvext["token_ids"]
logger.info(
f"Found token_ids: {len(token_list)} tokens"
)
except json.JSONDecodeError:
continue
# Validate worker_id info
assert (
token_event is not None
), f"Missing token_data event in: {events}"
worker_id_info is not None
), f"Missing worker_id in nvext. Response: {full_response}"
token_data_str = token_event[1].strip('"')
try:
token_list = json.loads(token_data_str)
except json.JSONDecodeError as e:
raise AssertionError(
f"token_data is not valid JSON: {token_data_str}, error: {e}"
)
# For aggregated mode, both prefill and decode should be the same
prefill_worker_id = worker_id_info.get("prefill_worker_id")
decode_worker_id = worker_id_info.get("decode_worker_id")
assert (
prefill_worker_id is not None
), f"Missing prefill_worker_id in worker_id: {worker_id_info}"
assert (
decode_worker_id is not None
), f"Missing decode_worker_id in worker_id: {worker_id_info}"
assert (
prefill_worker_id == decode_worker_id
), f"For aggregated mode, prefill and decode worker should be same: {worker_id_info}"
# Validate token_ids
assert (
token_list is not None
), f"Missing token_ids in nvext. Response: {full_response}"
assert isinstance(
token_list, list
), f"token_data should be a list, got: {type(token_list)}"
), f"token_ids should be a list, got: {type(token_list)}"
assert (
len(token_list) > 0
), f"token_data should not be empty: {token_list}"
), f"token_ids should not be empty: {token_list}"
assert all(
isinstance(token, int) for token in token_list
), f"All tokens should be integers: {token_list}"
logger.info(
f"Valid token_data with {len(token_list)} tokens: {token_list[:10]}{'...' if len(token_list) > 10 else ''}"
)
# Validate that no actual generation happened (should only be metadata)
# This proves the early return worked correctly
generation_indicators = [
"choices",
"content",
"delta",
"finish_reason",
]
for indicator in generation_indicators:
assert (
indicator not in full_response.lower()
), f"Found generation indicator '{indicator}' - request should not have been routed to worker"
logger.info(
"No generation content found - early return worked correctly"
f"Valid token_ids with {len(token_list)} tokens: {token_list[:10]}{'...' if len(token_list) > 10 else ''}"
)
return {
"worker_instance_id": worker_event[1].strip('"'),
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
"token_count": len(token_list),
"tokens": token_list,
}
......@@ -1154,7 +1135,8 @@ def _test_router_query_instance_id(
result = asyncio.run(test_annotation_response())
logger.info("Successfully validated query_instance_id annotation response:")
logger.info(f"Worker ID: {result['worker_instance_id']}")
logger.info(f"Prefill Worker ID: {result['prefill_worker_id']}")
logger.info(f"Decode Worker ID: {result['decode_worker_id']}")
logger.info(f"Token count: {result['token_count']}")
finally:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment