Unverified Commit 15b49818 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: support disag serving in GAIE [DEP-659] (#4756)


Signed-off-by: default avatarAnna Tchernych <atchernych@nvidia.com>
parent 7a3b15e6
...@@ -161,10 +161,10 @@ index 670d922..0cf04cb 100644 ...@@ -161,10 +161,10 @@ index 670d922..0cf04cb 100644
} }
diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
new file mode 100644 new file mode 100644
index 0000000..cd9a0b5 index 0000000..1c8f979
--- /dev/null --- /dev/null
+++ b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go +++ b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
@@ -0,0 +1,119 @@ @@ -0,0 +1,171 @@
+package dynamo_inject_workerid +package dynamo_inject_workerid
+ +
+import ( +import (
...@@ -182,6 +182,7 @@ index 0000000..cd9a0b5 ...@@ -182,6 +182,7 @@ index 0000000..cd9a0b5
+ typeString = "dynamo-inject-workerid" + typeString = "dynamo-inject-workerid"
+ pluginName = "dynamo-inject-workerid" + pluginName = "dynamo-inject-workerid"
+ WorkerIDHeader = "x-worker-instance-id" + WorkerIDHeader = "x-worker-instance-id"
+ PrefillWorkerIDHeader = "x-prefiller-host-port"
+ tokenDataAnnotationKey = "dynamo/token-data" + tokenDataAnnotationKey = "dynamo/token-data"
+) +)
+ +
...@@ -222,11 +223,18 @@ index 0000000..cd9a0b5 ...@@ -222,11 +223,18 @@ index 0000000..cd9a0b5
+ if req.Headers == nil { + if req.Headers == nil {
+ req.Headers = map[string]string{} + req.Headers = map[string]string{}
+ } + }
+
+ // Handle worker instance ID
+ wid := strings.TrimSpace(req.Headers[WorkerIDHeader]) + wid := strings.TrimSpace(req.Headers[WorkerIDHeader])
+ if wid == "" { + if wid != "" {
+ return
+ }
+ req.Headers[WorkerIDHeader] = wid + req.Headers[WorkerIDHeader] = wid
+ }
+
+ // Handle prefill worker ID
+ prefillWid := strings.TrimSpace(req.Headers[PrefillWorkerIDHeader])
+ if prefillWid != "" {
+ req.Headers[PrefillWorkerIDHeader] = prefillWid
+ }
+} +}
+ +
+func (p *InjectWorkerIDPreRequest) MutateRequestBody( +func (p *InjectWorkerIDPreRequest) MutateRequestBody(
...@@ -248,14 +256,28 @@ index 0000000..cd9a0b5 ...@@ -248,14 +256,28 @@ index 0000000..cd9a0b5
+ return + return
+ } + }
+ +
+ prefillWid := strings.TrimSpace(req.Headers[PrefillWorkerIDHeader])
+
+ nvext, _ := body["nvext"].(map[string]any) + nvext, _ := body["nvext"].(map[string]any)
+ if nvext == nil { + if nvext == nil {
+ nvext = map[string]any{} + nvext = map[string]any{}
+ body["nvext"] = nvext + body["nvext"] = nvext
+ } + }
+
+ if prefillWid != "" && prefillWid != wid {
+ // Disaggregated mode: use prefill_worker_id and decode_worker_id
+ if prefillWidUint, err := strconv.ParseUint(prefillWid, 10, 64); err == nil {
+ nvext["prefill_worker_id"] = prefillWidUint
+ }
+ if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil {
+ nvext["decode_worker_id"] = widUint
+ }
+ } else {
+ // Aggregated mode (empty prefill or prefill == decode): use backend_instance_id
+ if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil { + if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil {
+ nvext["backend_instance_id"] = widUint + nvext["backend_instance_id"] = widUint
+ } + }
+ }
+ +
+ if tokens, ok := req.Annotations[tokenDataAnnotationKey]; ok { + if tokens, ok := req.Annotations[tokenDataAnnotationKey]; ok {
+ switch v := tokens.(type) { + switch v := tokens.(type) {
...@@ -283,6 +305,36 @@ index 0000000..cd9a0b5 ...@@ -283,6 +305,36 @@ index 0000000..cd9a0b5
+ } + }
+ } + }
+ } + }
+
+ // Remove query_instance_id from nvext.annotations if present
+ if annotations, ok := nvext["annotations"]; ok {
+ switch annList := annotations.(type) {
+ case []string:
+ filtered := make([]string, 0, len(annList))
+ for _, ann := range annList {
+ if ann != "query_instance_id" {
+ filtered = append(filtered, ann)
+ }
+ }
+ if len(filtered) == 0 {
+ delete(nvext, "annotations")
+ } else {
+ nvext["annotations"] = filtered
+ }
+ case []any:
+ filtered := make([]any, 0, len(annList))
+ for _, ann := range annList {
+ if str, ok := ann.(string); !ok || str != "query_instance_id" {
+ filtered = append(filtered, ann)
+ }
+ }
+ if len(filtered) == 0 {
+ delete(nvext, "annotations")
+ } else {
+ nvext["annotations"] = filtered
+ }
+ }
+ }
+} +}
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml
new file mode 100644 new file mode 100644
...@@ -313,10 +365,10 @@ index 0000000..b689c00 ...@@ -313,10 +365,10 @@ index 0000000..b689c00
+ - pluginRef: picker + - pluginRef: picker
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
new file mode 100644 new file mode 100644
index 0000000..bc29c0a index 0000000..75f30e9
--- /dev/null --- /dev/null
+++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go +++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
@@ -0,0 +1,424 @@ @@ -0,0 +1,446 @@
+package dynamo_kv_scorer +package dynamo_kv_scorer
+ +
+/* +/*
...@@ -367,13 +419,15 @@ index 0000000..bc29c0a ...@@ -367,13 +419,15 @@ index 0000000..bc29c0a
+ double router_temperature, + double router_temperature,
+ bool use_kv_events, + bool use_kv_events,
+ bool router_replica_sync, + bool router_replica_sync,
+ bool enforce_disagg,
+ WorkerSelectionPipeline **pipeline_out); + WorkerSelectionPipeline **pipeline_out);
+ +
+dynamo_llm_result_t dynamo_destroy_worker_selection_pipeline(WorkerSelectionPipeline *pipeline); +dynamo_llm_result_t dynamo_destroy_worker_selection_pipeline(WorkerSelectionPipeline *pipeline);
+ +
+dynamo_llm_result_t dynamo_query_worker_selection_and_annotate(WorkerSelectionPipeline *pipeline, +dynamo_llm_result_t dynamo_query_worker_selection_and_annotate(WorkerSelectionPipeline *pipeline,
+ const char *request_json_c_str, + const char *request_json_c_str,
+ int64_t *worker_instance_id_out, + int64_t *decode_worker_id_out,
+ int64_t *prefill_worker_id_out,
+ uint32_t **token_ids_out, + uint32_t **token_ids_out,
+ size_t *token_count_out, + size_t *token_count_out,
+ char **annotated_request_json_out); + char **annotated_request_json_out);
...@@ -404,7 +458,9 @@ index 0000000..bc29c0a ...@@ -404,7 +458,9 @@ index 0000000..bc29c0a
+ PluginName = "dynamo-kv-scorer" + PluginName = "dynamo-kv-scorer"
+ KVAwareScorerType = "kv-aware-scorer" + KVAwareScorerType = "kv-aware-scorer"
+ StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id") + StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id")
+ StateKeyPrefillWorkerID = schedtypes.StateKey("dynamo/prefill-worker-id")
+ WorkerIDHeader = "x-worker-instance-id" + WorkerIDHeader = "x-worker-instance-id"
+ PrefillWorkerIDHeader = "x-prefiller-host-port"
+ tokenDataAnnotationKey = "dynamo/token-data" + tokenDataAnnotationKey = "dynamo/token-data"
+) +)
+ +
...@@ -471,6 +527,7 @@ index 0000000..bc29c0a ...@@ -471,6 +527,7 @@ index 0000000..bc29c0a
+ ffiRouterTemperature float64 + ffiRouterTemperature float64
+ ffiKvBlockSize uint32 + ffiKvBlockSize uint32
+ ffiWorkerID int64 + ffiWorkerID int64
+ ffiEnforceDisagg bool
+ +
+ runtimeInitialized bool + runtimeInitialized bool
+ +
...@@ -484,6 +541,7 @@ index 0000000..bc29c0a ...@@ -484,6 +541,7 @@ index 0000000..bc29c0a
+ ffiComponent = getEnvOrDefault("DYNAMO_COMPONENT", "backend") + ffiComponent = getEnvOrDefault("DYNAMO_COMPONENT", "backend")
+ ffiModel = getEnvOrDefault("DYNAMO_MODEL", "Qwen/Qwen3-0.6B") + ffiModel = getEnvOrDefault("DYNAMO_MODEL", "Qwen/Qwen3-0.6B")
+ ffiWorkerID = getEnvInt64OrDefault("DYNAMO_WORKER_ID", 1) + ffiWorkerID = getEnvInt64OrDefault("DYNAMO_WORKER_ID", 1)
+ ffiEnforceDisagg = getEnvBoolOrDefault("DYNAMO_ENFORCE_DISAGG", true) // TODO default to false
+ +
+ ffiOverlapScoreWeight = getEnvFloatOrDefault("DYNAMO_OVERLAP_SCORE_WEIGHT", -1.0) + ffiOverlapScoreWeight = getEnvFloatOrDefault("DYNAMO_OVERLAP_SCORE_WEIGHT", -1.0)
+ ffiRouterTemperature = getEnvFloatOrDefault("DYNAMO_ROUTER_TEMPERATURE", -1.0) + ffiRouterTemperature = getEnvFloatOrDefault("DYNAMO_ROUTER_TEMPERATURE", -1.0)
...@@ -575,6 +633,7 @@ index 0000000..bc29c0a ...@@ -575,6 +633,7 @@ index 0000000..bc29c0a
+ C.double(ffiRouterTemperature), + C.double(ffiRouterTemperature),
+ C.bool(getEnvBoolOrDefault("DYNAMO_USE_KV_EVENTS", true)), + C.bool(getEnvBoolOrDefault("DYNAMO_USE_KV_EVENTS", true)),
+ C.bool(getEnvBoolOrDefault("DYNAMO_ROUTER_REPLICA_SYNC", true)), + C.bool(getEnvBoolOrDefault("DYNAMO_ROUTER_REPLICA_SYNC", true)),
+ C.bool(ffiEnforceDisagg),
+ &pipeline, + &pipeline,
+ ) + )
+ if rc != C.DYNAMO_OK { + if rc != C.DYNAMO_OK {
...@@ -595,13 +654,14 @@ index 0000000..bc29c0a ...@@ -595,13 +654,14 @@ index 0000000..bc29c0a
+) map[schedtypes.Pod]float64 { +) map[schedtypes.Pod]float64 {
+ logger := log.FromContext(ctx) + logger := log.FromContext(ctx)
+ +
+ workerID, tokenData, err := k.callDynamoRouter(ctx, req) + workerID, prefillWorkerID, tokenData, err := k.callDynamoRouter(ctx, req)
+ if err != nil { + if err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "Dynamo call failed; proceeding without worker id") + logger.V(logutil.DEFAULT).Error(err, "Dynamo call failed; proceeding without worker id")
+ } else if workerID != "" { + } else if workerID != "" {
+ logger.V(logutil.DEFAULT).Info( + logger.V(logutil.DEFAULT).Info(
+ "Dynamo router selected worker", + "Dynamo router selected worker",
+ "workerID", workerID, + "workerID", workerID,
+ "prefillWorkerID", prefillWorkerID,
+ "tokenDataCount", len(tokenData), + "tokenDataCount", len(tokenData),
+ "tokenData", tokenData, + "tokenData", tokenData,
+ ) + )
...@@ -610,6 +670,13 @@ index 0000000..bc29c0a ...@@ -610,6 +670,13 @@ index 0000000..bc29c0a
+ req.Headers = map[string]string{} + req.Headers = map[string]string{}
+ } + }
+ req.Headers[WorkerIDHeader] = workerID + req.Headers[WorkerIDHeader] = workerID
+
+ // Set prefill worker ID if present
+ if prefillWorkerID != "" {
+ cycle.Write(StateKeyPrefillWorkerID, stateString(prefillWorkerID))
+ req.Headers[PrefillWorkerIDHeader] = prefillWorkerID
+ }
+
+ if len(tokenData) > 0 { + if len(tokenData) > 0 {
+ if req.Annotations == nil { + if req.Annotations == nil {
+ req.Annotations = map[string]any{} + req.Annotations = map[string]any{}
...@@ -632,15 +699,15 @@ index 0000000..bc29c0a ...@@ -632,15 +699,15 @@ index 0000000..bc29c0a
+func (k *KVAwareScorer) callDynamoRouter( +func (k *KVAwareScorer) callDynamoRouter(
+ ctx context.Context, + ctx context.Context,
+ req *schedtypes.LLMRequest, + req *schedtypes.LLMRequest,
+) (string, []int64, error) { +) (workerID string, prefillWorkerID string, tokenData []int64, err error) {
+ logger := log.FromContext(ctx) + logger := log.FromContext(ctx)
+ +
+ if err := initFFI(); err != nil { + if err := initFFI(); err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "FFI init failed") + logger.V(logutil.DEFAULT).Error(err, "FFI init failed")
+ return "", nil, err + return "", "", nil, err
+ } + }
+ if !runtimeInitialized { + if !runtimeInitialized {
+ return "", nil, fmt.Errorf("dynamo runtime not initialized") + return "", "", nil, fmt.Errorf("dynamo runtime not initialized")
+ } + }
+ +
+ pipelineMutex.RLock() + pipelineMutex.RLock()
...@@ -648,21 +715,22 @@ index 0000000..bc29c0a ...@@ -648,21 +715,22 @@ index 0000000..bc29c0a
+ pipelineMutex.RUnlock() + pipelineMutex.RUnlock()
+ +
+ if currentPipeline == nil { + if currentPipeline == nil {
+ return "", nil, fmt.Errorf("dynamo worker selection pipeline not created") + return "", "", nil, fmt.Errorf("dynamo worker selection pipeline not created")
+ } + }
+ +
+ // Build OpenAI-compatible JSON request + // Build OpenAI-compatible JSON request
+ requestBody := buildOpenAIRequest(req) + requestBody := buildOpenAIRequest(req)
+ requestJSON, err := json.Marshal(requestBody) + requestJSON, jsonErr := json.Marshal(requestBody)
+ if err != nil { + if jsonErr != nil {
+ logger.V(logutil.DEFAULT).Error(err, "Failed to marshal OpenAI request") + logger.V(logutil.DEFAULT).Error(jsonErr, "Failed to marshal OpenAI request")
+ return "", nil, fmt.Errorf("marshal OpenAI request: %w", err) + return "", "", nil, fmt.Errorf("marshal OpenAI request: %w", jsonErr)
+ } + }
+ cRequestJSON := C.CString(string(requestJSON)) + cRequestJSON := C.CString(string(requestJSON))
+ defer C.free(unsafe.Pointer(cRequestJSON)) + defer C.free(unsafe.Pointer(cRequestJSON))
+ +
+ // Output variables + // Output variables
+ var cWorkerID C.int64_t + var cDecodeWorkerID C.int64_t
+ var cPrefillWorkerID C.int64_t
+ var cTokens *C.uint32_t + var cTokens *C.uint32_t
+ var cTokenCount C.size_t + var cTokenCount C.size_t
+ var cAnnotatedJSON *C.char + var cAnnotatedJSON *C.char
...@@ -671,13 +739,14 @@ index 0000000..bc29c0a ...@@ -671,13 +739,14 @@ index 0000000..bc29c0a
+ rc := C.dynamo_query_worker_selection_and_annotate( + rc := C.dynamo_query_worker_selection_and_annotate(
+ currentPipeline, + currentPipeline,
+ cRequestJSON, + cRequestJSON,
+ &cWorkerID, + &cDecodeWorkerID,
+ &cPrefillWorkerID,
+ &cTokens, + &cTokens,
+ &cTokenCount, + &cTokenCount,
+ &cAnnotatedJSON, + &cAnnotatedJSON,
+ ) + )
+ if rc != C.DYNAMO_OK { + if rc != C.DYNAMO_OK {
+ return "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed") + return "", "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed")
+ } + }
+ +
+ // Copy tokens into Go memory and free C memory + // Copy tokens into Go memory and free C memory
...@@ -692,11 +761,16 @@ index 0000000..bc29c0a ...@@ -692,11 +761,16 @@ index 0000000..bc29c0a
+ } + }
+ C.dynamo_free_worker_selection_result(cTokens, cTokenCount, cAnnotatedJSON) + C.dynamo_free_worker_selection_result(cTokens, cTokenCount, cAnnotatedJSON)
+ +
+ workerID := fmt.Sprintf("%d", int64(cWorkerID)) + workerIDStr := fmt.Sprintf("%d", int64(cDecodeWorkerID))
+ prefillWorkerIDStr := ""
+ // Rust returns -1 for prefill_worker_id when not in disaggregated mode
+ if int64(cPrefillWorkerID) >= 0 {
+ prefillWorkerIDStr = fmt.Sprintf("%d", int64(cPrefillWorkerID))
+ }
+ logger.V(logutil.DEFAULT).Info("Worker selection completed", + logger.V(logutil.DEFAULT).Info("Worker selection completed",
+ "workerID", workerID, "tokenCount", count) + "workerID", workerIDStr, "prefillWorkerID", prefillWorkerIDStr, "tokenCount", count)
+ +
+ return workerID, tokens64, nil + return workerIDStr, prefillWorkerIDStr, tokens64, nil
+} +}
+ +
+func buildOpenAIRequest(req *schedtypes.LLMRequest) map[string]any { +func buildOpenAIRequest(req *schedtypes.LLMRequest) map[string]any {
......
This diff is collapsed.
...@@ -97,6 +97,46 @@ pub fn router_endpoint_id(namespace: String) -> EndpointId { ...@@ -97,6 +97,46 @@ pub fn router_endpoint_id(namespace: String) -> EndpointId {
} }
} }
/// Specifies the type of worker being queried when using the `query_instance_id` annotation.
/// This tells the router which worker pool to select from and what type of operation is intended.
///
/// Query instance types for worker selection
/// - "prefill" → select a prefill worker (disaggregated serving)
/// - "decode" → select a decode worker (disaggregated serving)
///
/// Note: Empty value ("query_instance_id:") is handled by PrefillRouter for disagg orchestration
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum QueryInstanceType {
/// Query for a prefill worker (disaggregated serving)
Prefill,
/// Query for a decode worker (disaggregated serving)
Decode,
}
impl std::fmt::Display for QueryInstanceType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
QueryInstanceType::Prefill => write!(f, "prefill"),
QueryInstanceType::Decode => write!(f, "decode"),
}
}
}
impl std::str::FromStr for QueryInstanceType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"prefill" => Ok(QueryInstanceType::Prefill),
"decode" => Ok(QueryInstanceType::Decode),
_ => Err(format!(
"Invalid QueryInstanceType: '{s}'. Expected 'prefill' or 'decode'"
)),
}
}
}
/// Creates a DiscoveryQuery for the KV router in the given namespace. /// Creates a DiscoveryQuery for the KV router in the given namespace.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery { pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
DiscoveryQuery::Endpoint { DiscoveryQuery::Endpoint {
...@@ -731,13 +771,34 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -731,13 +771,34 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Extract context ID for request tracking // Extract context ID for request tracking
let context_id = request.context().id().to_string(); let context_id = request.context().id().to_string();
// Check if this is a query_instance_id request first // Check if this is a query_instance_id request and parse its type
let query_instance_id = request.has_annotation("query_instance_id"); // Format: "query_instance_id:type" where type is "prefill", "decode", or "" (empty for aggregated)
// Empty value ("query_instance_id:") means GAIE Aggregated mode - return same worker as both prefill and decode
let query_instance_annotation = request.get_annotation_value("query_instance_id");
let is_gaie_agg_query = query_instance_annotation
.as_ref()
.is_some_and(|s| s.is_empty());
let query_instance_type: Option<QueryInstanceType> =
if let Some(type_str) = &query_instance_annotation {
match type_str.parse::<QueryInstanceType>() {
Ok(t) => Some(t),
Err(_) if type_str.is_empty() => {
// Empty value is valid for aggregated mode, not a warning
None
}
Err(e) => {
tracing::warn!("Invalid query_instance_id type '{type_str}': {e}");
None
}
}
} else {
None
};
let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id { let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
// If instance_id is set, use it and compute actual overlap // If instance_id is set, use it and compute actual overlap
let dp_rank = request.dp_rank.unwrap_or(0); let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_id { if query_instance_type.is_some() {
tracing::debug!( tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation" "backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
); );
...@@ -761,33 +822,80 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -761,33 +822,80 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
(id, dp_rank, overlap_blocks) (id, dp_rank, overlap_blocks)
} else { } else {
// Otherwise, find the best match // Otherwise, find the best match
// Don't update states if this is a query-only request (any query_instance_id annotation)
let should_update_states = query_instance_annotation.is_none();
let (best_worker, overlap_amount) = self let (best_worker, overlap_amount) = self
.chooser .chooser
.find_best_match( .find_best_match(
Some(&context_id), Some(&context_id),
&request.token_ids, &request.token_ids,
request.router_config_override.as_ref(), request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id should_update_states,
) )
.await?; .await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount) (best_worker.worker_id, best_worker.dp_rank, overlap_amount)
}; };
// if request has the annotation "query_instance_id", // If request has a query_instance_id annotation, return worker selection info
// then the request will not be routed to the worker, // without routing to the actual worker. Returns LLMEngineOutput with disaggregated_params
// and instead the worker_instance_id will be returned. // containing worker_id info, same structure as normal execution for uniform extraction.
let stream_context = request.context().clone(); let stream_context = request.context().clone();
if query_instance_id {
let instance_id_str = instance_id.to_string();
let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
// Return the tokens in nvext.token_data format // Handle query-only requests (GAIE Stage 1)
let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?; if query_instance_type.is_some() || is_gaie_agg_query {
let worker_id_info = if is_gaie_agg_query {
// GAIE Aggregated mode: same worker serves both prefill and decode
tracing::trace!(
query_type = "aggregated",
worker_id = instance_id,
"Returning aggregated worker selection (same worker for prefill and decode)"
);
WorkerIdInfo {
prefill_worker_id: Some(instance_id),
decode_worker_id: Some(instance_id),
}
} else {
match query_instance_type.unwrap() {
QueryInstanceType::Prefill => {
tracing::trace!(
query_type = "prefill",
prefill_worker_id = instance_id,
"Returning prefill worker selection"
);
WorkerIdInfo {
prefill_worker_id: Some(instance_id),
decode_worker_id: None,
}
}
QueryInstanceType::Decode => {
// Get prefill_worker_id from annotation (set by caller after prefill selection)
let prefill_worker_id = request
.get_annotation_value("prefill_worker_id")
.and_then(|s| s.parse::<u64>().ok());
tracing::trace!( tracing::trace!(
"Tokens requested in the response through the query_instance_id annotation: {:?}", query_type = "decode",
response_tokens prefill_worker_id = ?prefill_worker_id,
decode_worker_id = instance_id,
"Returning decode worker selection"
); );
let stream = stream::iter(vec![response, response_tokens]); WorkerIdInfo {
prefill_worker_id,
decode_worker_id: Some(instance_id),
}
}
}
};
// Return as LLMEngineOutput with disaggregated_params (same structure as normal execution)
let output = LLMEngineOutput {
disaggregated_params: Some(json!({
"worker_id": worker_id_info,
"token_ids": request.token_ids
})),
..Default::default()
};
let response = Annotated::from_data(output);
let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context)); return Ok(ResponseStream::new(Box::pin(stream), stream_context));
} }
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
......
...@@ -20,9 +20,10 @@ use dynamo_runtime::{ ...@@ -20,9 +20,10 @@ use dynamo_runtime::{
use crate::{ use crate::{
discovery::ModelManager, discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride}, kv_router::{KvPushRouter, KvRouterConfig, QueryInstanceType, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult}, protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::openai::nvext::WorkerIdInfo,
}; };
/// Errors that can occur during prefill routing /// Errors that can occur during prefill routing
...@@ -67,6 +68,11 @@ impl InnerPrefillRouter { ...@@ -67,6 +68,11 @@ impl InnerPrefillRouter {
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router. /// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params /// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request. /// from the prefill response and injecting them into the decode request.
///
/// Supports regular Dynamo and GAIE integrated mode via query_instance_id state machine:
/// - GAIE Stage 1: query_instance_id transitions "" -> "prefill" -> "decode", returns only worker IDs
/// - GAIE Stage 2: target_prefill_worker_id/target_decode_worker_id are set, full execution with specified workers
/// - Non-GAIE: like GAIE Stage 2 but the worker ids have to be determined.
pub struct PrefillRouter { pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>, prefill_router: OnceLock<InnerPrefillRouter>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
...@@ -196,10 +202,13 @@ impl PrefillRouter { ...@@ -196,10 +202,13 @@ impl PrefillRouter {
rand::rng().random() rand::rng().random()
} }
/// Query best worker upfront, build bootstrap_info, and spawn prefill in background /// Build bootstrap_info for disaggregated serving
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker.
async fn build_bootstrap_info( async fn build_bootstrap_info(
&self, &self,
req: &PreprocessedRequest, req: &PreprocessedRequest,
preselected_worker: Option<u64>,
) -> Option<(u64, u32, BootstrapInfo)> { ) -> Option<(u64, u32, BootstrapInfo)> {
let prefill_router = self.prefill_router.get()?; let prefill_router = self.prefill_router.get()?;
...@@ -209,14 +218,24 @@ impl PrefillRouter { ...@@ -209,14 +218,24 @@ impl PrefillRouter {
InnerPrefillRouter::SimpleRouter(_) => return None, InnerPrefillRouter::SimpleRouter(_) => return None,
}; };
// Query best worker without routing // Use pre-selected worker (GAIE Stage 2) or query for best worker
let (worker_id, dp_rank) = match kv_router let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
let dp_rank = req.dp_rank.unwrap_or(0);
tracing::debug!(
worker_id = id,
dp_rank = dp_rank,
"Using pre-selected prefill worker for bootstrap"
);
(id, dp_rank)
} else {
match kv_router
.chooser .chooser
.find_best_match(None, &req.token_ids, None, false) .find_best_match(None, &req.token_ids, None, false)
.await .await
{ {
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank), Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
Err(_) => return None, Err(_) => return None,
}
}; };
// Look up bootstrap endpoint from discovery // Look up bootstrap endpoint from discovery
...@@ -343,6 +362,56 @@ impl PrefillRouter { ...@@ -343,6 +362,56 @@ impl PrefillRouter {
} }
} }
/// GAIE helper functions for preparing prefill requests
impl PrefillRouter {
/// Prepare prefill request for GAIE flows
/// - Stage 1: Sets query_instance_id:prefill annotation
/// - Stage 2: Sets backend_instance_id to target prefill worker
fn prepare_prefill_for_gaie(prefill_req: &mut PreprocessedRequest, is_gaie_stage1: bool) {
if is_gaie_stage1 {
// GAIE Stage 1: Set query_instance_id to "prefill" for prefill worker selection
prefill_req
.annotations
.retain(|a| !a.starts_with("query_instance_id"));
prefill_req
.annotations
.push(format!("query_instance_id:{}", QueryInstanceType::Prefill));
} else if let Some(prefill_worker_id) = prefill_req.target_prefill_worker_id {
// GAIE Stage 2: Route to pre-selected prefill worker from the stage 1
tracing::debug!(
target_prefill_worker_id = prefill_worker_id,
"GAIE Stage 2: Routing prefill to pre-selected worker"
);
prefill_req.backend_instance_id = Some(prefill_worker_id);
}
}
/// Prepare decode request for GAIE Stage 1
/// Extracts prefill_worker_id from prefill result and sets decode annotations
fn prepare_decode_for_gaie_stage1(
decode_req: &mut PreprocessedRequest,
prefill_result: &PrefillResult,
) {
let prefill_worker_id = prefill_result
.disaggregated_params
.get("worker_id")
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
.and_then(|info| info.prefill_worker_id);
if let Some(worker_id) = prefill_worker_id {
decode_req
.annotations
.retain(|a| !a.starts_with("query_instance_id"));
decode_req
.annotations
.push(format!("query_instance_id:{}", QueryInstanceType::Decode));
decode_req
.annotations
.push(format!("prefill_worker_id:{worker_id}"));
}
}
}
impl Drop for PrefillRouter { impl Drop for PrefillRouter {
fn drop(&mut self) { fn drop(&mut self) {
tracing::debug!("Dropping PrefillRouter, cancelling background activation task"); tracing::debug!("Dropping PrefillRouter, cancelling background activation task");
...@@ -369,6 +438,12 @@ impl ...@@ -369,6 +438,12 @@ impl
let request_id = context.id().to_string(); let request_id = context.id().to_string();
let engine_ctx = context.context(); let engine_ctx = context.context();
// GAIE Stage 1: the presence of the empty query_instance_id signals query-only mode
// State machine: "" -> "prefill" -> "decode" (disagg) OR "" -> aggregated worker (agg fallback)
let is_gaie_stage1 = req
.get_annotation_value("query_instance_id")
.is_some_and(|s| s.is_empty());
// Save original max_tokens for decode // Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens; let original_max_tokens = req.stop_conditions.max_tokens;
...@@ -376,9 +451,16 @@ impl ...@@ -376,9 +451,16 @@ impl
let mut prefill_req = req.clone(); let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1); prefill_req.stop_conditions.max_tokens = Some(1);
// Try build_bootstrap_info optimization // Prepare prefill request for GAIE flows (Stage 1 or Stage 2)
let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) = Self::prepare_prefill_for_gaie(&mut prefill_req, is_gaie_stage1);
self.build_bootstrap_info(&prefill_req).await
// Try build_bootstrap_info optimization (skip for GAIE Stage 1 which needs query-only flow)
// For GAIE Stage 2, use target_prefill_worker_id if provided
let preselected_worker = prefill_req.target_prefill_worker_id;
let prefill_result = if !is_gaie_stage1 {
if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
{ {
let bootstrap_room = bootstrap_info.bootstrap_room; let bootstrap_room = bootstrap_info.bootstrap_room;
...@@ -408,6 +490,15 @@ impl ...@@ -408,6 +490,15 @@ impl
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
}
} else {
// GAIE Stage 1: Use original path (no bootstrap optimization)
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context) self.call_prefill(prefill_context)
.await .await
.map(|(result, worker_id)| (Some(result), worker_id, None)) .map(|(result, worker_id)| (Some(result), worker_id, None))
...@@ -429,8 +520,13 @@ impl ...@@ -429,8 +520,13 @@ impl
let mut decode_req = req; let mut decode_req = req;
// Update request with prefill result if available (only in original path) // Update request with prefill result
if let Some(prefill_result) = maybe_prefill_result { if is_gaie_stage1 {
if let Some(ref prefill_result) = maybe_prefill_result {
Self::prepare_decode_for_gaie_stage1(&mut decode_req, prefill_result);
}
} else if let Some(prefill_result) = maybe_prefill_result {
// Normal or GAIE Stage 2: Set prefill_result for decode
decode_req.prefill_result = Some(prefill_result); decode_req.prefill_result = Some(prefill_result);
} }
...@@ -449,6 +545,15 @@ impl ...@@ -449,6 +545,15 @@ impl
..existing_override.unwrap_or_default() ..existing_override.unwrap_or_default()
}); });
// GAIE Stage 2: Route to pre-selected decode worker if specified
if let Some(decode_worker_id) = decode_req.target_decode_worker_id {
decode_req.backend_instance_id = Some(decode_worker_id);
tracing::debug!(
decode_worker_id = decode_worker_id,
"GAIE Stage 2: Routing decode to pre-selected worker"
);
}
// Map the modified request through with preserved context // Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req); let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await next.generate(decode_request).await
......
...@@ -238,10 +238,13 @@ impl OpenAIPreprocessor { ...@@ -238,10 +238,13 @@ impl OpenAIPreprocessor {
builder.annotations(request.annotations().unwrap_or_default()); builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone())); builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None); builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id and extra_fields from nvext if present // Extract backend_instance_id, extra_fields, and worker IDs from nvext if present
if let Some(nvext) = request.nvext() { if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id); builder.backend_instance_id(nvext.backend_instance_id);
builder.extra_fields(nvext.extra_fields.clone()); builder.extra_fields(nvext.extra_fields.clone());
// GAIE Stage 2: Extract targeted worker IDs for disaggregated serving
builder.target_prefill_worker_id(nvext.prefill_worker_id);
builder.target_decode_worker_id(nvext.decode_worker_id);
} }
Ok(builder) Ok(builder)
......
...@@ -74,7 +74,7 @@ pub struct BackendOutput { ...@@ -74,7 +74,7 @@ pub struct BackendOutput {
/// ///
/// This is the minimal raw output from the LLM engine. The Backend may then apply multiple /// This is the minimal raw output from the LLM engine. The Backend may then apply multiple
/// levels of post-processing before the BackendOutput is returns /// levels of post-processing before the BackendOutput is returns
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct LLMEngineOutput { pub struct LLMEngineOutput {
// new token_ids // new token_ids
pub token_ids: Vec<TokenIdType>, pub token_ids: Vec<TokenIdType>,
......
...@@ -118,12 +118,34 @@ pub struct PreprocessedRequest { ...@@ -118,12 +118,34 @@ pub struct PreprocessedRequest {
#[builder(default)] #[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_fields: Option<Vec<String>>, pub extra_fields: Option<Vec<String>>,
/// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the prefill request will be routed to this specific worker.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub target_prefill_worker_id: Option<u64>,
/// Targeted decode worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the decode request will be routed to this specific worker.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub target_decode_worker_id: Option<u64>,
} }
impl PreprocessedRequest { impl PreprocessedRequest {
pub fn has_annotation(&self, annotation: &str) -> bool { pub fn has_annotation(&self, annotation: &str) -> bool {
self.annotations.contains(&annotation.to_string()) self.annotations.contains(&annotation.to_string())
} }
/// Get the value of an annotation in the format "key:value"
/// Returns None if the annotation is not found or has no value
pub fn get_annotation_value(&self, key: &str) -> Option<String> {
let prefix = format!("{}:", key);
self.annotations
.iter()
.find(|a| a.starts_with(&prefix))
.map(|a| a[prefix.len()..].to_string())
}
} }
impl PreprocessedRequest { impl PreprocessedRequest {
......
...@@ -400,13 +400,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -400,13 +400,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
tracker.record_first_token(); tracker.record_first_token();
} }
// Extract worker_id from disaggregated_params // Extract worker_id and token_ids from disaggregated_params
let worker_id_info = delta let worker_id_info = delta
.disaggregated_params .disaggregated_params
.as_ref() .as_ref()
.and_then(|params| params.get("worker_id")) .and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok()); .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
let token_ids = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
// Get timing info if this is the final response (has finish_reason) // Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() { let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| { self.timing_tracker.as_ref().map(|tracker| {
...@@ -417,11 +423,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -417,11 +423,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None None
}; };
// Inject nvext if we have worker_id or timing // Inject nvext if we have worker_id, token_ids, or timing
if worker_id_info.is_some() || timing_info.is_some() { if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
let nvext_response = NvExtResponse { let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(), worker_id: worker_id_info.clone(),
timing: timing_info, timing: timing_info,
token_ids: token_ids.clone(),
}; };
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) { if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
...@@ -433,6 +440,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -433,6 +440,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
info.decode_worker_id info.decode_worker_id
); );
} }
if let Some(ref tokens) = token_ids {
tracing::debug!(
"Injected token_ids into chat completion nvext: {} tokens",
tokens.len()
);
}
} }
} }
......
...@@ -295,13 +295,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -295,13 +295,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
tracker.record_first_token(); tracker.record_first_token();
} }
// Extract worker_id from disaggregated_params // Extract worker_id and token_ids from disaggregated_params
let worker_id_info = delta let worker_id_info = delta
.disaggregated_params .disaggregated_params
.as_ref() .as_ref()
.and_then(|params| params.get("worker_id")) .and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok()); .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
let token_ids = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
// Get timing info if this is the final response (has finish_reason) // Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() { let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| { self.timing_tracker.as_ref().map(|tracker| {
...@@ -312,11 +318,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -312,11 +318,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
None None
}; };
// Inject nvext if we have worker_id or timing // Inject nvext if we have worker_id, token_ids, or timing
if worker_id_info.is_some() || timing_info.is_some() { if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
let nvext_response = NvExtResponse { let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(), worker_id: worker_id_info.clone(),
timing: timing_info, timing: timing_info,
token_ids: token_ids.clone(),
}; };
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) { if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
...@@ -328,6 +335,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -328,6 +335,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
info.decode_worker_id info.decode_worker_id
); );
} }
if let Some(ref tokens) = token_ids {
tracing::debug!(
"Injected token_ids into completions nvext: {} tokens",
tokens.len()
);
}
} }
} }
......
...@@ -35,6 +35,11 @@ pub struct NvExtResponse { ...@@ -35,6 +35,11 @@ pub struct NvExtResponse {
/// Populated when client requests `extra_fields: ["timing"]` /// Populated when client requests `extra_fields: ["timing"]`
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub timing: Option<TimingInfo>, pub timing: Option<TimingInfo>,
/// Token IDs for GAIE Stage 1 query-only mode
/// Contains the tokenized prompt for reuse in Stage 2
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids: Option<Vec<u32>>,
} }
/// NVIDIA LLM extensions to the OpenAI API /// NVIDIA LLM extensions to the OpenAI API
...@@ -87,6 +92,18 @@ pub struct NvExt { ...@@ -87,6 +92,18 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>, pub extra_fields: Option<Vec<String>>,
/// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the request will be routed to this specific prefill worker.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
/// Targeted decode worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the request will be routed to this specific decode worker.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
} }
impl Default for NvExt { impl Default for NvExt {
...@@ -133,6 +150,8 @@ mod tests { ...@@ -133,6 +150,8 @@ mod tests {
assert_eq!(nv_ext.token_data, None); assert_eq!(nv_ext.token_data, None);
assert_eq!(nv_ext.max_thinking_tokens, None); assert_eq!(nv_ext.max_thinking_tokens, None);
assert_eq!(nv_ext.extra_fields, None); assert_eq!(nv_ext.extra_fields, None);
assert_eq!(nv_ext.prefill_worker_id, None);
assert_eq!(nv_ext.decode_worker_id, None);
} }
// Test valid builder configurations // Test valid builder configurations
...@@ -157,4 +176,18 @@ mod tests { ...@@ -157,4 +176,18 @@ mod tests {
// Validate the built struct // Validate the built struct
assert!(nv_ext.validate().is_ok()); assert!(nv_ext.validate().is_ok());
} }
// Test GAIE Stage 2 disaggregated worker IDs
#[test]
fn test_nv_ext_disagg_worker_ids() {
let nv_ext = NvExt::builder()
.prefill_worker_id(100)
.decode_worker_id(200)
.build()
.unwrap();
assert_eq!(nv_ext.prefill_worker_id, Some(100));
assert_eq!(nv_ext.decode_worker_id, Some(200));
assert!(nv_ext.validate().is_ok());
}
} }
...@@ -1026,9 +1026,10 @@ def _test_router_query_instance_id( ...@@ -1026,9 +1026,10 @@ def _test_router_query_instance_id(
asyncio.run(send_request_with_retry(url, test_payload)) asyncio.run(send_request_with_retry(url, test_payload))
# Test payload with query_instance_id annotation # Test payload with query_instance_id annotation
# Format: "query_instance_id:" (colon with empty value) for GAIE aggregated mode
annotated_payload = { annotated_payload = {
**test_payload, **test_payload,
"nvext": {"annotations": ["query_instance_id"]}, "nvext": {"annotations": ["query_instance_id:"]},
} }
async def test_annotation_response(): async def test_annotation_response():
...@@ -1053,100 +1054,80 @@ def _test_router_query_instance_id( ...@@ -1053,100 +1054,80 @@ def _test_router_query_instance_id(
f"Full SSE response ({len(full_response)} bytes):\n{full_response}" f"Full SSE response ({len(full_response)} bytes):\n{full_response}"
) )
# Parse and validate the response structure # Parse the SSE response to extract the first chunk with nvext data
events = [] # New format: nvext contains worker_id and token_ids
sse_parts = full_response.split("\n\n") sse_parts = full_response.split("\n\n")
worker_id_info = None
token_list = None
for part in sse_parts: for part in sse_parts:
part = part.strip() part = part.strip()
if not part: if not part or not part.startswith("data:"):
continue continue
if part.startswith("event:"): data_str = part.split("data:", 1)[1].strip()
lines = part.split("\n") if data_str == "[DONE]":
event_line = next( continue
(line for line in lines if line.startswith("event:")),
None,
)
data_line = next(
(
line
for line in lines
if line.startswith("data:") or line.startswith(":")
),
None,
)
if event_line and data_line:
event_type = event_line.split(":", 1)[1].strip()
if data_line.startswith("data:"):
data_value = data_line.split(":", 1)[1].strip()
else:
data_value = data_line.split(":", 1)[1].strip()
events.append((event_type, data_value))
elif part.startswith("data:"):
data_value = part.split(":", 1)[1].strip()
logger.info(f"Parsed events: {events}") try:
chunk = json.loads(data_str)
logger.info(f"Parsed chunk: {json.dumps(chunk, indent=2)}")
# Validate worker_instance_id event # Extract nvext data containing worker_id and token_ids
worker_event = next( nvext = chunk.get("nvext", {})
(e for e in events if e[0] == "worker_instance_id"), None if nvext:
if "worker_id" in nvext:
worker_id_info = nvext["worker_id"]
logger.info(
f"Found worker_id info: {worker_id_info}"
) )
assert ( if "token_ids" in nvext:
worker_event is not None token_list = nvext["token_ids"]
), f"Missing worker_instance_id event in: {events}" logger.info(
f"Found token_ids: {len(token_list)} tokens"
# Validate token_data event
token_event = next(
(e for e in events if e[0] == "token_data"), None
) )
except json.JSONDecodeError:
continue
# Validate worker_id info
assert ( assert (
token_event is not None worker_id_info is not None
), f"Missing token_data event in: {events}" ), f"Missing worker_id in nvext. Response: {full_response}"
token_data_str = token_event[1].strip('"') # For aggregated mode, both prefill and decode should be the same
try: prefill_worker_id = worker_id_info.get("prefill_worker_id")
token_list = json.loads(token_data_str) decode_worker_id = worker_id_info.get("decode_worker_id")
except json.JSONDecodeError as e: assert (
raise AssertionError( prefill_worker_id is not None
f"token_data is not valid JSON: {token_data_str}, error: {e}" ), f"Missing prefill_worker_id in worker_id: {worker_id_info}"
) assert (
decode_worker_id is not None
), f"Missing decode_worker_id in worker_id: {worker_id_info}"
assert (
prefill_worker_id == decode_worker_id
), f"For aggregated mode, prefill and decode worker should be same: {worker_id_info}"
# Validate token_ids
assert (
token_list is not None
), f"Missing token_ids in nvext. Response: {full_response}"
assert isinstance( assert isinstance(
token_list, list token_list, list
), f"token_data should be a list, got: {type(token_list)}" ), f"token_ids should be a list, got: {type(token_list)}"
assert ( assert (
len(token_list) > 0 len(token_list) > 0
), f"token_data should not be empty: {token_list}" ), f"token_ids should not be empty: {token_list}"
assert all( assert all(
isinstance(token, int) for token in token_list isinstance(token, int) for token in token_list
), f"All tokens should be integers: {token_list}" ), f"All tokens should be integers: {token_list}"
logger.info( logger.info(
f"Valid token_data with {len(token_list)} tokens: {token_list[:10]}{'...' if len(token_list) > 10 else ''}" f"Valid token_ids with {len(token_list)} tokens: {token_list[:10]}{'...' if len(token_list) > 10 else ''}"
)
# Validate that no actual generation happened (should only be metadata)
# This proves the early return worked correctly
generation_indicators = [
"choices",
"content",
"delta",
"finish_reason",
]
for indicator in generation_indicators:
assert (
indicator not in full_response.lower()
), f"Found generation indicator '{indicator}' - request should not have been routed to worker"
logger.info(
"No generation content found - early return worked correctly"
) )
return { return {
"worker_instance_id": worker_event[1].strip('"'), "prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
"token_count": len(token_list), "token_count": len(token_list),
"tokens": token_list, "tokens": token_list,
} }
...@@ -1154,7 +1135,8 @@ def _test_router_query_instance_id( ...@@ -1154,7 +1135,8 @@ def _test_router_query_instance_id(
result = asyncio.run(test_annotation_response()) result = asyncio.run(test_annotation_response())
logger.info("Successfully validated query_instance_id annotation response:") logger.info("Successfully validated query_instance_id annotation response:")
logger.info(f"Worker ID: {result['worker_instance_id']}") logger.info(f"Prefill Worker ID: {result['prefill_worker_id']}")
logger.info(f"Decode Worker ID: {result['decode_worker_id']}")
logger.info(f"Token count: {result['token_count']}") logger.info(f"Token count: {result['token_count']}")
finally: finally:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment