".github/vscode:/vscode.git/clone" did not exist on "15b86408a89d5b998409e7fbe7850e937cc837da"
Unverified Commit 15b49818 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: support disag serving in GAIE [DEP-659] (#4756)


Signed-off-by: default avatarAnna Tchernych <atchernych@nvidia.com>
parent 7a3b15e6
......@@ -161,10 +161,10 @@ index 670d922..0cf04cb 100644
}
diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
new file mode 100644
index 0000000..cd9a0b5
index 0000000..1c8f979
--- /dev/null
+++ b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
@@ -0,0 +1,119 @@
@@ -0,0 +1,171 @@
+package dynamo_inject_workerid
+
+import (
......@@ -182,6 +182,7 @@ index 0000000..cd9a0b5
+ typeString = "dynamo-inject-workerid"
+ pluginName = "dynamo-inject-workerid"
+ WorkerIDHeader = "x-worker-instance-id"
+ PrefillWorkerIDHeader = "x-prefiller-host-port"
+ tokenDataAnnotationKey = "dynamo/token-data"
+)
+
......@@ -222,11 +223,18 @@ index 0000000..cd9a0b5
+ if req.Headers == nil {
+ req.Headers = map[string]string{}
+ }
+
+ // Handle worker instance ID
+ wid := strings.TrimSpace(req.Headers[WorkerIDHeader])
+ if wid == "" {
+ return
+ }
+ if wid != "" {
+ req.Headers[WorkerIDHeader] = wid
+ }
+
+ // Handle prefill worker ID
+ prefillWid := strings.TrimSpace(req.Headers[PrefillWorkerIDHeader])
+ if prefillWid != "" {
+ req.Headers[PrefillWorkerIDHeader] = prefillWid
+ }
+}
+
+func (p *InjectWorkerIDPreRequest) MutateRequestBody(
......@@ -248,14 +256,28 @@ index 0000000..cd9a0b5
+ return
+ }
+
+ prefillWid := strings.TrimSpace(req.Headers[PrefillWorkerIDHeader])
+
+ nvext, _ := body["nvext"].(map[string]any)
+ if nvext == nil {
+ nvext = map[string]any{}
+ body["nvext"] = nvext
+ }
+
+ if prefillWid != "" && prefillWid != wid {
+ // Disaggregated mode: use prefill_worker_id and decode_worker_id
+ if prefillWidUint, err := strconv.ParseUint(prefillWid, 10, 64); err == nil {
+ nvext["prefill_worker_id"] = prefillWidUint
+ }
+ if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil {
+ nvext["decode_worker_id"] = widUint
+ }
+ } else {
+ // Aggregated mode (empty prefill or prefill == decode): use backend_instance_id
+ if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil {
+ nvext["backend_instance_id"] = widUint
+ }
+ }
+
+ if tokens, ok := req.Annotations[tokenDataAnnotationKey]; ok {
+ switch v := tokens.(type) {
......@@ -283,6 +305,36 @@ index 0000000..cd9a0b5
+ }
+ }
+ }
+
+ // Remove query_instance_id from nvext.annotations if present
+ if annotations, ok := nvext["annotations"]; ok {
+ switch annList := annotations.(type) {
+ case []string:
+ filtered := make([]string, 0, len(annList))
+ for _, ann := range annList {
+ if ann != "query_instance_id" {
+ filtered = append(filtered, ann)
+ }
+ }
+ if len(filtered) == 0 {
+ delete(nvext, "annotations")
+ } else {
+ nvext["annotations"] = filtered
+ }
+ case []any:
+ filtered := make([]any, 0, len(annList))
+ for _, ann := range annList {
+ if str, ok := ann.(string); !ok || str != "query_instance_id" {
+ filtered = append(filtered, ann)
+ }
+ }
+ if len(filtered) == 0 {
+ delete(nvext, "annotations")
+ } else {
+ nvext["annotations"] = filtered
+ }
+ }
+ }
+}
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml
new file mode 100644
......@@ -313,10 +365,10 @@ index 0000000..b689c00
+ - pluginRef: picker
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
new file mode 100644
index 0000000..bc29c0a
index 0000000..75f30e9
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
@@ -0,0 +1,424 @@
@@ -0,0 +1,446 @@
+package dynamo_kv_scorer
+
+/*
......@@ -367,13 +419,15 @@ index 0000000..bc29c0a
+ double router_temperature,
+ bool use_kv_events,
+ bool router_replica_sync,
+ bool enforce_disagg,
+ WorkerSelectionPipeline **pipeline_out);
+
+dynamo_llm_result_t dynamo_destroy_worker_selection_pipeline(WorkerSelectionPipeline *pipeline);
+
+dynamo_llm_result_t dynamo_query_worker_selection_and_annotate(WorkerSelectionPipeline *pipeline,
+ const char *request_json_c_str,
+ int64_t *worker_instance_id_out,
+ int64_t *decode_worker_id_out,
+ int64_t *prefill_worker_id_out,
+ uint32_t **token_ids_out,
+ size_t *token_count_out,
+ char **annotated_request_json_out);
......@@ -404,7 +458,9 @@ index 0000000..bc29c0a
+ PluginName = "dynamo-kv-scorer"
+ KVAwareScorerType = "kv-aware-scorer"
+ StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id")
+ StateKeyPrefillWorkerID = schedtypes.StateKey("dynamo/prefill-worker-id")
+ WorkerIDHeader = "x-worker-instance-id"
+ PrefillWorkerIDHeader = "x-prefiller-host-port"
+ tokenDataAnnotationKey = "dynamo/token-data"
+)
+
......@@ -471,6 +527,7 @@ index 0000000..bc29c0a
+ ffiRouterTemperature float64
+ ffiKvBlockSize uint32
+ ffiWorkerID int64
+ ffiEnforceDisagg bool
+
+ runtimeInitialized bool
+
......@@ -484,6 +541,7 @@ index 0000000..bc29c0a
+ ffiComponent = getEnvOrDefault("DYNAMO_COMPONENT", "backend")
+ ffiModel = getEnvOrDefault("DYNAMO_MODEL", "Qwen/Qwen3-0.6B")
+ ffiWorkerID = getEnvInt64OrDefault("DYNAMO_WORKER_ID", 1)
+ ffiEnforceDisagg = getEnvBoolOrDefault("DYNAMO_ENFORCE_DISAGG", true) // TODO default to false
+
+ ffiOverlapScoreWeight = getEnvFloatOrDefault("DYNAMO_OVERLAP_SCORE_WEIGHT", -1.0)
+ ffiRouterTemperature = getEnvFloatOrDefault("DYNAMO_ROUTER_TEMPERATURE", -1.0)
......@@ -575,6 +633,7 @@ index 0000000..bc29c0a
+ C.double(ffiRouterTemperature),
+ C.bool(getEnvBoolOrDefault("DYNAMO_USE_KV_EVENTS", true)),
+ C.bool(getEnvBoolOrDefault("DYNAMO_ROUTER_REPLICA_SYNC", true)),
+ C.bool(ffiEnforceDisagg),
+ &pipeline,
+ )
+ if rc != C.DYNAMO_OK {
......@@ -595,13 +654,14 @@ index 0000000..bc29c0a
+) map[schedtypes.Pod]float64 {
+ logger := log.FromContext(ctx)
+
+ workerID, tokenData, err := k.callDynamoRouter(ctx, req)
+ workerID, prefillWorkerID, tokenData, err := k.callDynamoRouter(ctx, req)
+ if err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "Dynamo call failed; proceeding without worker id")
+ } else if workerID != "" {
+ logger.V(logutil.DEFAULT).Info(
+ "Dynamo router selected worker",
+ "workerID", workerID,
+ "prefillWorkerID", prefillWorkerID,
+ "tokenDataCount", len(tokenData),
+ "tokenData", tokenData,
+ )
......@@ -610,6 +670,13 @@ index 0000000..bc29c0a
+ req.Headers = map[string]string{}
+ }
+ req.Headers[WorkerIDHeader] = workerID
+
+ // Set prefill worker ID if present
+ if prefillWorkerID != "" {
+ cycle.Write(StateKeyPrefillWorkerID, stateString(prefillWorkerID))
+ req.Headers[PrefillWorkerIDHeader] = prefillWorkerID
+ }
+
+ if len(tokenData) > 0 {
+ if req.Annotations == nil {
+ req.Annotations = map[string]any{}
......@@ -632,15 +699,15 @@ index 0000000..bc29c0a
+func (k *KVAwareScorer) callDynamoRouter(
+ ctx context.Context,
+ req *schedtypes.LLMRequest,
+) (string, []int64, error) {
+) (workerID string, prefillWorkerID string, tokenData []int64, err error) {
+ logger := log.FromContext(ctx)
+
+ if err := initFFI(); err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "FFI init failed")
+ return "", nil, err
+ return "", "", nil, err
+ }
+ if !runtimeInitialized {
+ return "", nil, fmt.Errorf("dynamo runtime not initialized")
+ return "", "", nil, fmt.Errorf("dynamo runtime not initialized")
+ }
+
+ pipelineMutex.RLock()
......@@ -648,21 +715,22 @@ index 0000000..bc29c0a
+ pipelineMutex.RUnlock()
+
+ if currentPipeline == nil {
+ return "", nil, fmt.Errorf("dynamo worker selection pipeline not created")
+ return "", "", nil, fmt.Errorf("dynamo worker selection pipeline not created")
+ }
+
+ // Build OpenAI-compatible JSON request
+ requestBody := buildOpenAIRequest(req)
+ requestJSON, err := json.Marshal(requestBody)
+ if err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "Failed to marshal OpenAI request")
+ return "", nil, fmt.Errorf("marshal OpenAI request: %w", err)
+ requestJSON, jsonErr := json.Marshal(requestBody)
+ if jsonErr != nil {
+ logger.V(logutil.DEFAULT).Error(jsonErr, "Failed to marshal OpenAI request")
+ return "", "", nil, fmt.Errorf("marshal OpenAI request: %w", jsonErr)
+ }
+ cRequestJSON := C.CString(string(requestJSON))
+ defer C.free(unsafe.Pointer(cRequestJSON))
+
+ // Output variables
+ var cWorkerID C.int64_t
+ var cDecodeWorkerID C.int64_t
+ var cPrefillWorkerID C.int64_t
+ var cTokens *C.uint32_t
+ var cTokenCount C.size_t
+ var cAnnotatedJSON *C.char
......@@ -671,13 +739,14 @@ index 0000000..bc29c0a
+ rc := C.dynamo_query_worker_selection_and_annotate(
+ currentPipeline,
+ cRequestJSON,
+ &cWorkerID,
+ &cDecodeWorkerID,
+ &cPrefillWorkerID,
+ &cTokens,
+ &cTokenCount,
+ &cAnnotatedJSON,
+ )
+ if rc != C.DYNAMO_OK {
+ return "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed")
+ return "", "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed")
+ }
+
+ // Copy tokens into Go memory and free C memory
......@@ -692,11 +761,16 @@ index 0000000..bc29c0a
+ }
+ C.dynamo_free_worker_selection_result(cTokens, cTokenCount, cAnnotatedJSON)
+
+ workerID := fmt.Sprintf("%d", int64(cWorkerID))
+ workerIDStr := fmt.Sprintf("%d", int64(cDecodeWorkerID))
+ prefillWorkerIDStr := ""
+ // Rust returns -1 for prefill_worker_id when not in disaggregated mode
+ if int64(cPrefillWorkerID) >= 0 {
+ prefillWorkerIDStr = fmt.Sprintf("%d", int64(cPrefillWorkerID))
+ }
+ logger.V(logutil.DEFAULT).Info("Worker selection completed",
+ "workerID", workerID, "tokenCount", count)
+ "workerID", workerIDStr, "prefillWorkerID", prefillWorkerIDStr, "tokenCount", count)
+
+ return workerID, tokens64, nil
+ return workerIDStr, prefillWorkerIDStr, tokens64, nil
+}
+
+func buildOpenAIRequest(req *schedtypes.LLMRequest) map[string]any {
......
......@@ -393,6 +393,10 @@ pub struct WorkerSelectionPipeline {
///
/// # Errors
/// Returns `DynamoLlmResult::ERR` on failure and does not write to `pipeline_out`.
/// # Safety
/// See detailed safety docs above. Additional parameter:
/// - `enforce_disagg`: If true, requests fail when disaggregated serving is unavailable.
/// If false, falls back to aggregated serving.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
namespace_c_str: *const c_char,
......@@ -404,6 +408,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
router_temperature: f64,
use_kv_events: bool,
router_replica_sync: bool,
enforce_disagg: bool,
pipeline_out: *mut *mut WorkerSelectionPipeline,
) -> DynamoLlmResult {
if pipeline_out.is_null() {
......@@ -472,6 +477,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
router_mode,
(busy_threshold >= 0.0).then_some(busy_threshold),
kv_router_config,
enforce_disagg,
)
.await
};
......@@ -492,7 +498,8 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
}
/// Query worker selection on an existing pipeline and return:
/// - `worker_instance_id_out` (`i64`)
/// - `decode_worker_id_out` (`i64`): The decode worker ID (primary worker)
/// - `prefill_worker_id_out` (`i64`): The prefill worker ID (-1 if not in disaggregated mode)
/// - `token_ids_out` (heap-allocated `*mut u32`; caller must free via
/// `dynamo_free_worker_selection_result`)
/// - `token_count_out` (`usize`)
......@@ -513,10 +520,10 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
/// function returns `DynamoLlmResult::ERR`.
/// - Must remain valid for the duration of this call.
/// - Output pointers:
/// - `worker_instance_id_out`, `token_ids_out`, `token_count_out`,
/// - `decode_worker_id_out`, `prefill_worker_id_out`, `token_ids_out`, `token_count_out`,
/// and `annotated_request_json_out` must each be **non-null** and point to
/// writable memory for their respective types. On success, this function
/// writes to all four outputs exactly once.
/// writes to all five outputs exactly once.
/// - On **error**, outputs are left unmodified.
/// - Ownership & deallocation:
/// - On success, if there are zero tokens, `*token_ids_out` may be set to `NULL`
......@@ -540,11 +547,18 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
/// Returns `DynamoLlmResult::ERR` if any precondition fails (null/invalid pointers,
/// malformed UTF-8/JSON, pipeline errors, allocation failures, etc.). On error, no
/// output pointer is written.
///
/// # Output values
/// - `decode_worker_id_out`: The decode worker ID (primary worker in aggregated mode)
/// - `prefill_worker_id_out`: The prefill worker ID (only set in disaggregated mode, -1 if not present)
/// - `token_ids_out`, `token_count_out`: Token IDs and count
/// - `annotated_request_json_out`: The annotated request JSON
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
pipeline: *mut WorkerSelectionPipeline,
request_json_c_str: *const c_char,
worker_instance_id_out: *mut i64,
decode_worker_id_out: *mut i64,
prefill_worker_id_out: *mut i64,
token_ids_out: *mut *mut u32,
token_count_out: *mut usize,
annotated_request_json_out: *mut *mut c_char,
......@@ -553,7 +567,8 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
tracing::error!("Pipeline pointer is null");
return DynamoLlmResult::ERR;
}
if worker_instance_id_out.is_null()
if decode_worker_id_out.is_null()
|| prefill_worker_id_out.is_null()
|| token_ids_out.is_null()
|| token_count_out.is_null()
|| annotated_request_json_out.is_null()
......@@ -579,7 +594,7 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
let pl = unsafe { &*pipeline };
let fut = async { query_worker_selection_and_annotate(&pl.engine, request).await };
let (worker_id, tokens, annotated_req) = match pl.wk.runtime().secondary().block_on(fut) {
let (result, annotated_req) = match pl.wk.runtime().secondary().block_on(fut) {
Ok(v) => v,
Err(e) => {
tracing::error!(error = ?e, "query_worker_selection_and_annotate failed");
......@@ -587,10 +602,10 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
}
};
let tokens_ptr = if tokens.is_empty() {
let tokens_ptr = if result.tokens.is_empty() {
std::ptr::null_mut()
} else {
let len = tokens.len();
let len = result.tokens.len();
let layout = std::alloc::Layout::array::<u32>(len).unwrap();
let ptr = unsafe { std::alloc::alloc(layout) as *mut u32 };
if ptr.is_null() {
......@@ -598,7 +613,7 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
return DynamoLlmResult::ERR;
}
unsafe {
std::ptr::copy_nonoverlapping(tokens.as_ptr(), ptr, len);
std::ptr::copy_nonoverlapping(result.tokens.as_ptr(), ptr, len);
}
ptr
};
......@@ -606,11 +621,11 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
let annotated_json = match serde_json::to_string(&annotated_req) {
Ok(s) => s,
Err(e) => {
let layout = std::alloc::Layout::array::<u32>(tokens.len()).unwrap();
if !tokens_ptr.is_null() {
let layout = std::alloc::Layout::array::<u32>(result.tokens.len()).unwrap();
unsafe {
std::alloc::dealloc(tokens_ptr as *mut u8, layout);
}
if !tokens_ptr.is_null() {
tracing::error!(error = ?e, "serialize annotated request failed");
}
return DynamoLlmResult::ERR;
......@@ -621,7 +636,7 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
Err(e) => {
tracing::error!(error = ?e, "CString::new for annotated JSON failed");
if !tokens_ptr.is_null() {
let layout = std::alloc::Layout::array::<u32>(tokens.len()).unwrap();
let layout = std::alloc::Layout::array::<u32>(result.tokens.len()).unwrap();
unsafe {
std::alloc::dealloc(tokens_ptr as *mut u8, layout);
}
......@@ -630,9 +645,10 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
}
};
unsafe {
*worker_instance_id_out = worker_id;
*decode_worker_id_out = result.decode_worker_id.unwrap_or(0);
*prefill_worker_id_out = result.prefill_worker_id.unwrap_or(-1);
*token_ids_out = tokens_ptr;
*token_count_out = tokens.len();
*token_count_out = result.tokens.len();
*annotated_request_json_out = cjson.into_raw();
}
DynamoLlmResult::OK
......@@ -724,96 +740,77 @@ pub unsafe extern "C" fn dynamo_free_worker_selection_result(
DynamoLlmResult::OK
}
/// Result of worker selection extraction
///
/// Holds the routing targets chosen for a request together with the token IDs
/// produced while tokenizing it. `Default` yields `None` for both worker IDs
/// and an empty token vector, which is the state reported when the stream
/// carried no selection data.
///
/// The C entry point `dynamo_query_worker_selection_and_annotate` flattens the
/// optional IDs for FFI callers: a missing decode ID is written as `0` and a
/// missing prefill ID as `-1`.
#[derive(Debug, Clone, Default)]
pub struct WorkerSelectionResult {
/// Decode worker ID (primary worker for aggregated, decode-only for disaggregated)
pub decode_worker_id: Option<i64>,
/// Prefill worker ID (only present in disaggregated mode)
pub prefill_worker_id: Option<i64>,
/// Token IDs from tokenization
pub tokens: Vec<u32>,
}
/// Helper function to extract worker selection information from the annotation stream
///
/// The response format (from disaggregated_params in nvext):
/// - worker_id: {"prefill_worker_id": 123, "decode_worker_id": 456}
/// - token_ids: [1, 2, 3, ...]
pub async fn extract_worker_selection_from_stream(
mut stream: Pin<Box<dyn AsyncEngineStream<Annotated<NvCreateChatCompletionStreamResponse>>>>,
) -> anyhow::Result<(i64, Vec<u32>)> {
) -> anyhow::Result<WorkerSelectionResult> {
use dynamo_llm::protocols::openai::nvext::WorkerIdInfo;
use futures::StreamExt;
let mut worker_id: i64 = 0;
let mut tokens: Vec<u32> = Vec::new();
let mut result = WorkerSelectionResult::default();
while let Some(response) = stream.next().await {
let Some(event) = &response.event else {
tracing::error!("Response has no event field");
continue;
};
match event.as_str() {
"worker_instance_id" => {
tracing::debug!("Found worker_instance_id event");
let Some(first_comment) = response.comment.as_ref().and_then(|v| v.first()) else {
tracing::debug!("worker_instance_id event without comments");
continue;
};
// Try JSON string first (e.g. `"1732646935200805498"`), then plain integer.
if let Ok(id_string) = serde_json::from_str::<String>(first_comment) {
match id_string.parse::<i64>() {
Ok(parsed_id) => {
worker_id = parsed_id;
tracing::debug!("parsed worker_id from JSON string: {}", worker_id);
}
Err(_) => {
tracing::error!(
"failed to parse number from JSON string: '{}'",
id_string
// Check for data in nvext (worker_id and token_ids are direct fields)
// nvext is a serde_json::Value, so we access it as a JSON object
if let Some(data) = &response.data
&& let Some(nvext) = &data.nvext
{
// Extract worker_id
if let Some(worker_id_value) = nvext.get("worker_id")
&& let Ok(worker_info) =
serde_json::from_value::<WorkerIdInfo>(worker_id_value.clone())
{
result.decode_worker_id = worker_info.decode_worker_id.map(|id| id as i64);
result.prefill_worker_id = worker_info.prefill_worker_id.map(|id| id as i64);
tracing::debug!(
decode_worker_id = ?result.decode_worker_id,
prefill_worker_id = ?result.prefill_worker_id,
"Parsed worker_id from nvext"
);
}
}
continue;
}
match first_comment.parse::<i64>() {
Ok(parsed_id) => {
worker_id = parsed_id;
tracing::debug!("parsed worker_id directly: {}", worker_id);
}
Err(_) => {
tracing::error!("failed to parse worker_id from: '{}'", first_comment);
}
}
}
"token_data" => {
tracing::debug!("Found token_data event");
let Some(first_comment) = response.comment.as_ref().and_then(|v| v.first()) else {
tracing::debug!("token_data event without comments");
continue;
};
tracing::debug!("Token comment: '{}'", first_comment);
match serde_json::from_str::<Vec<u32>>(first_comment) {
Ok(parsed_tokens) => {
tokens = parsed_tokens;
tracing::debug!("Successfully parsed {} tokens", tokens.len());
}
Err(e) => {
tracing::error!("Failed to parse tokens from '{}': {}", first_comment, e);
}
}
}
other => {
tracing::debug!("Unknown event type: '{}'", other);
// Extract token_ids
if let Some(token_ids_value) = nvext.get("token_ids")
&& let Ok(parsed_tokens) =
serde_json::from_value::<Vec<u32>>(token_ids_value.clone())
{
result.tokens = parsed_tokens;
tracing::debug!(
"Successfully parsed {} tokens from nvext",
result.tokens.len()
);
}
}
}
tracing::info!(
"Final worker_id={}, tokens.len()={}",
worker_id,
tokens.len()
decode_worker_id = ?result.decode_worker_id,
prefill_worker_id = ?result.prefill_worker_id,
token_count = result.tokens.len(),
"Worker selection extraction complete"
);
Ok((worker_id, tokens))
Ok(result)
}
/// Utility function to add the "query_instance_id" annotation to an OpenAI request
///
/// This function modifies the request to include the annotation that signals the KV router
/// to return worker selection information (worker_instance_id and token_data) instead of
/// to return worker selection information (worker_id and token_data) instead of
/// performing actual inference.
///
/// # Parameters
......@@ -824,28 +821,73 @@ pub async fn extract_worker_selection_from_stream(
pub fn add_query_instance_id(
request: &mut NvCreateChatCompletionRequest,
) -> &mut NvCreateChatCompletionRequest {
add_annotation_unique(request, "query_instance_id")
// Send empty value - router treats empty as aggregated / aggregated worker selection
set_kv_annotation(request, "query_instance_id".to_string(), "")
}
/// Utility function to add worker_instance_id annotation to an OpenAI request
pub fn add_worker_instance_id_annotation(
/// Set worker IDs directly on the NvExt fields for GAIE Stage 2
///
/// For disaggregated mode: sets `prefill_worker_id` and `decode_worker_id`
/// For aggregated mode: sets `backend_instance_id` (when both IDs are the same)
pub fn set_worker_ids_for_stage2(
request: &mut NvCreateChatCompletionRequest,
worker_id: i64,
decode_worker_id: Option<i64>,
prefill_worker_id: Option<i64>,
) -> &mut NvCreateChatCompletionRequest {
set_kv_annotation(
request,
"worker_instance_id".to_string(),
worker_id.to_string(),
)
let nvext = request.nvext.get_or_insert_with(|| {
NvExt::builder()
.build()
.expect("NvExt builder should not fail")
});
// Check if this is aggregated mode (same worker for both)
let is_aggregated = prefill_worker_id == decode_worker_id;
if is_aggregated {
// Aggregated: use backend_instance_id for direct routing
if let Some(id) = decode_worker_id {
nvext.backend_instance_id = Some(id as u64);
tracing::debug!(
backend_instance_id = id,
"GAIE Stage 2 Aggregated: Setting backend_instance_id"
);
}
} else {
// Disaggregated: use separate prefill and decode worker IDs
if let Some(id) = prefill_worker_id {
nvext.prefill_worker_id = Some(id as u64);
}
if let Some(id) = decode_worker_id {
nvext.decode_worker_id = Some(id as u64);
}
tracing::debug!(
prefill_worker_id = ?prefill_worker_id,
decode_worker_id = ?decode_worker_id,
"GAIE Stage 2 Disaggregated: Setting prefill and decode worker IDs"
);
}
request
}
/// Utility function to add token_data annotation to an OpenAI request
pub fn add_token_data_annotation<'a>(
/// Set token_data directly on the NvExt field for GAIE Stage 2
pub fn set_token_data_for_stage2<'a>(
request: &'a mut NvCreateChatCompletionRequest,
tokens: &[u32],
) -> &'a mut NvCreateChatCompletionRequest {
let tokens_json = serde_json::to_string(tokens).unwrap_or_default();
set_kv_annotation(request, "token_data".to_string(), tokens_json)
let nvext = request.nvext.get_or_insert_with(|| {
NvExt::builder()
.build()
.expect("NvExt builder should not fail")
});
nvext.token_data = Some(tokens.to_vec());
tracing::debug!(
token_count = tokens.len(),
"GAIE Stage 2: Setting token_data"
);
request
}
/// Ensure `nvext` exists and return a mutable slice of annotations.
......@@ -858,19 +900,6 @@ fn ensure_annotations(request: &mut NvCreateChatCompletionRequest) -> &mut Vec<S
nvext.annotations.get_or_insert_with(Vec::new)
}
/// Append a bare annotation to the request unless an identical entry is already present.
///
/// The annotation list is created on demand via `ensure_annotations`. The
/// request is handed back so calls can be chained.
fn add_annotation_unique(
    request: &mut NvCreateChatCompletionRequest,
    annotation: impl Into<String>,
) -> &mut NvCreateChatCompletionRequest {
    let wanted: String = annotation.into();
    let existing = ensure_annotations(request);
    // Exact string match only: "key" and "key:value" count as different entries.
    let already_present = existing.contains(&wanted);
    if !already_present {
        existing.push(wanted);
    }
    request
}
/// Set a `key:value` annotation.
fn set_kv_annotation(
request: &mut NvCreateChatCompletionRequest,
......@@ -885,38 +914,153 @@ fn set_kv_annotation(
request
}
/// Wrapper function that queries worker selection and annotates the original request
/// Wrapper function that queries worker selection and prepares the request for GAIE Stage 2
///
/// This function performs the complete flow:
/// 1. Clones the original request and adds "query_instance_id" annotation
/// This function performs the complete GAIE Stage 1 flow:
/// 1. Clones the original request and adds "query_instance_id:" (empty) annotation
/// 2. Calls engine.generate() with the modified request
/// 3. Extracts worker_instance_id and tokens from the response stream
/// 4. Adds worker_instance_id and token_data annotations to the original request
/// 5. Returns (worker_id, tokens, annotated_original_request)
/// 3. Extracts worker_id info and tokens from the response stream
/// 4. Sets the appropriate NvExt fields on the original request for Stage 2:
/// - Disaggregated: prefill_worker_id, decode_worker_id, token_data
/// - Aggregated: backend_instance_id, token_data
/// 5. Returns WorkerSelectionResult and the modified request ready for Stage 2
///
/// # Parameters
/// - `engine`: The worker selection pipeline engine
/// - `original_request`: The original OpenAI request to process
///
/// # Returns
/// A tuple containing (worker_instance_id, tokens, modified_original_request)
/// where the modified_original_request has worker_instance_id and token_data annotations added
/// A tuple containing (WorkerSelectionResult, modified_original_request)
/// where the modified_original_request is ready for GAIE Stage 2 execution
pub async fn query_worker_selection_and_annotate(
engine: &ServiceEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>,
mut original_request: NvCreateChatCompletionRequest,
) -> anyhow::Result<(i64, Vec<u32>, NvCreateChatCompletionRequest)> {
) -> anyhow::Result<(WorkerSelectionResult, NvCreateChatCompletionRequest)> {
// GAIE Stage 1: Query for worker selection
let mut query_request = original_request.clone();
add_query_instance_id(&mut query_request);
let single_in = SingleIn::new(query_request);
let response_stream = engine.generate(single_in).await?;
let (worker_id, tokens) = extract_worker_selection_from_stream(response_stream).await?;
add_worker_instance_id_annotation(&mut original_request, worker_id);
add_token_data_annotation(&mut original_request, &tokens);
let result = extract_worker_selection_from_stream(response_stream).await?;
Ok((worker_id, tokens, original_request))
// Prepare request for GAIE Stage 2: Set NvExt fields directly
set_worker_ids_for_stage2(
&mut original_request,
result.decode_worker_id,
result.prefill_worker_id,
);
set_token_data_for_stage2(&mut original_request, &result.tokens);
Ok((result, original_request))
}
/// Spawn a background task to watch for prefill models and activate prefill routers.
/// This is a lightweight watcher that only handles prefill model discovery.
///
/// # Parameters
/// - `drt`: distributed runtime used to open the discovery stream and to resolve
///   the namespace/component/endpoint of a discovered model.
/// - `model_manager`: manager on which `activate_prefill_router` is invoked.
/// - `target_namespace`: only model instances registered under exactly this
///   namespace are considered; everything else is ignored.
///
/// # Behavior
/// The task is detached: its `JoinHandle` is dropped, so it runs until the
/// discovery stream ends. If the stream cannot be opened the task logs the
/// error and exits. Errors yielded by the stream are logged and the loop
/// keeps going. Every skipped activation (card that fails to deserialize,
/// namespace/component that cannot be resolved, activation error) is logged
/// rather than dropped silently, so a missing prefill router can be diagnosed.
fn spawn_prefill_watcher(
    drt: DistributedRuntime,
    model_manager: Arc<ModelManager>,
    target_namespace: String,
) {
    use dynamo_llm::model_card::ModelDeploymentCard;
    use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery};
    use dynamo_runtime::protocols::EndpointId;
    use futures::StreamExt;

    tokio::spawn(async move {
        let discovery = drt.discovery();
        let mut stream = match discovery
            .list_and_watch(DiscoveryQuery::AllModels, None)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::error!(error = %e, "Failed to start prefill discovery stream");
                return;
            }
        };

        while let Some(result) = stream.next().await {
            let event = match result {
                Ok(e) => e,
                Err(e) => {
                    tracing::error!(error = %e, "Error in prefill discovery stream");
                    continue;
                }
            };

            match event {
                DiscoveryEvent::Added(instance) => {
                    let (endpoint_id, card) = match &instance {
                        DiscoveryInstance::Model {
                            namespace,
                            component,
                            endpoint,
                            ..
                        } => {
                            // Filter by namespace
                            if namespace != &target_namespace {
                                continue;
                            }
                            let eid = EndpointId {
                                namespace: namespace.clone(),
                                component: component.clone(),
                                name: endpoint.clone(),
                            };
                            match instance.deserialize_model::<ModelDeploymentCard>() {
                                Ok(card) => (eid, card),
                                Err(e) => {
                                    // Previously skipped without a trace; keep skipping,
                                    // but leave evidence for debugging.
                                    tracing::debug!(
                                        error = ?e,
                                        "Skipping discovered model: failed to deserialize model card"
                                    );
                                    continue;
                                }
                            }
                        }
                        _ => continue,
                    };

                    // Only handle prefill models
                    if !card.model_type.supports_prefill() {
                        continue;
                    }

                    tracing::info!(
                        model_name = card.name(),
                        "Prefill model discovered, activating prefill router"
                    );

                    // Get the endpoint and activate the prefill router
                    if let Ok(ns) = drt.namespace(&endpoint_id.namespace)
                        && let Ok(comp) = ns.component(&endpoint_id.component)
                    {
                        let endpoint = comp.endpoint(&endpoint_id.name);
                        if let Err(e) = model_manager.activate_prefill_router(card.name(), endpoint)
                        {
                            tracing::warn!(
                                model_name = card.name(),
                                error = %e,
                                "Failed to activate prefill router"
                            );
                        } else {
                            tracing::info!(
                                model_name = card.name(),
                                "Prefill router activated successfully"
                            );
                        }
                    } else {
                        // Without this branch the "activating" log above would be the
                        // last message, wrongly suggesting the router came up.
                        tracing::warn!(
                            model_name = card.name(),
                            namespace = %endpoint_id.namespace,
                            component = %endpoint_id.component,
                            "Could not resolve namespace/component for prefill model; prefill router not activated"
                        );
                    }
                }
                DiscoveryEvent::Removed(instance_id) => {
                    // Log removal for observability
                    // Note: The PrefillRouter remains active - worker availability
                    // is handled dynamically by the underlying Client's instance tracking
                    tracing::debug!(
                        instance_id = instance_id,
                        "Prefill worker instance removed from discovery"
                    );
                }
            }
        }
    });
}
/// Create a worker selection pipeline for OpenAI Chat Completion requests
......@@ -931,6 +1075,7 @@ pub async fn query_worker_selection_and_annotate(
/// - `router_mode`: How to route requests (KV, RoundRobin, etc.)
/// - `busy_threshold`: Optional threshold for busy worker detection
/// - `kv_router_config`: Optional KV router configuration (only used when router_mode is KV)
/// - `enforce_disagg`: If true, fail requests when disaggregated serving is unavailable
///
/// # Returns
/// A configured worker selection pipeline ready to use
......@@ -941,12 +1086,15 @@ pub async fn create_worker_selection_pipeline_chat(
router_mode: RouterMode,
busy_threshold: Option<f64>,
kv_router_config: Option<KvRouterConfig>,
enforce_disagg: bool,
) -> anyhow::Result<
ServiceEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>,
> {
use dynamo_llm::kv_router::PrefillRouter;
let runtime = Runtime::from_settings()?;
let dst_config = DistributedConfig::from_settings();
let drt_owned = DistributedRuntime::new(runtime, dst_config).await?;
......@@ -966,10 +1114,9 @@ pub async fn create_worker_selection_pipeline_chat(
let router_config = dynamo_llm::entrypoint::RouterConfig {
router_mode,
kv_router_config: kv_router_config.unwrap_or_default(),
// C bindings only support active_decode_blocks_threshold for now (via busy_threshold param)
active_decode_blocks_threshold: busy_threshold,
active_prefill_tokens_threshold: None,
enforce_disagg: false,
enforce_disagg,
};
let watcher = ModelWatcher::new(
component.drt().clone(),
......@@ -999,6 +1146,34 @@ pub async fn create_worker_selection_pipeline_chat(
None
};
// Create prefill chooser for dynamic disaggregation support
// This registers the model and returns a receiver that will be activated
// when a prefill worker is discovered
let prefill_chooser = model_manager
.register_prefill_router(model_name.to_string())
.map(|rx| {
// Create prefill-specific config with track_active_blocks disabled
let mut prefill_config = kv_router_config.unwrap_or_default();
prefill_config.router_track_active_blocks = false;
PrefillRouter::new(
rx,
model_manager.clone(),
router_mode,
card.kv_cache_block_size,
Some(prefill_config),
enforce_disagg,
)
});
// Start background watcher for prefill model discovery
// This will activate the prefill router when prefill workers join
spawn_prefill_watcher(
component.drt().clone(),
model_manager.clone(),
namespace.to_string(),
);
// Download model config files from HuggingFace for EPP
// The backend's card has NATS URLs which aren't accessible from EPP
tracing::debug!(
......@@ -1034,7 +1209,6 @@ pub async fn create_worker_selection_pipeline_chat(
// Create worker monitor if busy_threshold is set
// Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this
// C bindings only support active_decode_blocks_threshold for now (active_prefill_tokens_threshold defaults to 1000000 tokens = effectively disabled)
let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(client.clone(), t, 1000000));
let engine = build_routed_pipeline::<
......@@ -1047,8 +1221,8 @@ pub async fn create_worker_selection_pipeline_chat(
worker_monitor,
chooser,
hf_tokenizer,
None, // prefill_chooser
false, // enforce_disagg
prefill_chooser,
enforce_disagg,
)
.await?;
......
......@@ -97,6 +97,46 @@ pub fn router_endpoint_id(namespace: String) -> EndpointId {
}
}
/// Worker role requested by a typed `query_instance_id` annotation.
///
/// The annotation is written as `query_instance_id:<type>`, and `<type>` selects which worker
/// pool the router should choose from:
/// - "prefill" → select a prefill worker (disaggregated serving)
/// - "decode" → select a decode worker (disaggregated serving)
///
/// With `rename_all = "lowercase"` the serde representation of each variant is the same
/// lowercase name listed above.
///
/// Note: Empty value ("query_instance_id:") is not represented by this enum; it is handled by
/// PrefillRouter for disagg orchestration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum QueryInstanceType {
/// Query for a prefill worker (disaggregated serving)
Prefill,
/// Query for a decode worker (disaggregated serving)
Decode,
}
impl std::fmt::Display for QueryInstanceType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
QueryInstanceType::Prefill => write!(f, "prefill"),
QueryInstanceType::Decode => write!(f, "decode"),
}
}
}
impl std::str::FromStr for QueryInstanceType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"prefill" => Ok(QueryInstanceType::Prefill),
"decode" => Ok(QueryInstanceType::Decode),
_ => Err(format!(
"Invalid QueryInstanceType: '{s}'. Expected 'prefill' or 'decode'"
)),
}
}
}
/// Creates a DiscoveryQuery for the KV router in the given namespace.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
DiscoveryQuery::Endpoint {
......@@ -731,13 +771,34 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Extract context ID for request tracking
let context_id = request.context().id().to_string();
// Check if this is a query_instance_id request first
let query_instance_id = request.has_annotation("query_instance_id");
// Check if this is a query_instance_id request and parse its type
// Format: "query_instance_id:type" where type is "prefill", "decode", or "" (empty for aggregated)
// Empty value ("query_instance_id:") means GAIE Aggregated mode - return same worker as both prefill and decode
let query_instance_annotation = request.get_annotation_value("query_instance_id");
let is_gaie_agg_query = query_instance_annotation
.as_ref()
.is_some_and(|s| s.is_empty());
let query_instance_type: Option<QueryInstanceType> =
if let Some(type_str) = &query_instance_annotation {
match type_str.parse::<QueryInstanceType>() {
Ok(t) => Some(t),
Err(_) if type_str.is_empty() => {
// Empty value is valid for aggregated mode, not a warning
None
}
Err(e) => {
tracing::warn!("Invalid query_instance_id type '{type_str}': {e}");
None
}
}
} else {
None
};
let (instance_id, dp_rank, overlap_amount) = if let Some(id) = request.backend_instance_id {
// If instance_id is set, use it and compute actual overlap
let dp_rank = request.dp_rank.unwrap_or(0);
if query_instance_id {
if query_instance_type.is_some() {
tracing::debug!(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
);
......@@ -761,33 +822,80 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
(id, dp_rank, overlap_blocks)
} else {
// Otherwise, find the best match
// Don't update states if this is a query-only request (any query_instance_id annotation)
let should_update_states = query_instance_annotation.is_none();
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
Some(&context_id),
&request.token_ids,
request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id
should_update_states,
)
.await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount)
};
// if request has the annotation "query_instance_id",
// then the request will not be routed to the worker,
// and instead the worker_instance_id will be returned.
// If request has a query_instance_id annotation, return worker selection info
// without routing to the actual worker. Returns LLMEngineOutput with disaggregated_params
// containing worker_id info, same structure as normal execution for uniform extraction.
let stream_context = request.context().clone();
if query_instance_id {
let instance_id_str = instance_id.to_string();
let response = Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
// Return the tokens in nvext.token_data format
let response_tokens = Annotated::from_annotation("token_data", &request.token_ids)?;
// Handle query-only requests (GAIE Stage 1)
if query_instance_type.is_some() || is_gaie_agg_query {
let worker_id_info = if is_gaie_agg_query {
// GAIE Aggregated mode: same worker serves both prefill and decode
tracing::trace!(
query_type = "aggregated",
worker_id = instance_id,
"Returning aggregated worker selection (same worker for prefill and decode)"
);
WorkerIdInfo {
prefill_worker_id: Some(instance_id),
decode_worker_id: Some(instance_id),
}
} else {
match query_instance_type.unwrap() {
QueryInstanceType::Prefill => {
tracing::trace!(
query_type = "prefill",
prefill_worker_id = instance_id,
"Returning prefill worker selection"
);
WorkerIdInfo {
prefill_worker_id: Some(instance_id),
decode_worker_id: None,
}
}
QueryInstanceType::Decode => {
// Get prefill_worker_id from annotation (set by caller after prefill selection)
let prefill_worker_id = request
.get_annotation_value("prefill_worker_id")
.and_then(|s| s.parse::<u64>().ok());
tracing::trace!(
"Tokens requested in the response through the query_instance_id annotation: {:?}",
response_tokens
query_type = "decode",
prefill_worker_id = ?prefill_worker_id,
decode_worker_id = instance_id,
"Returning decode worker selection"
);
let stream = stream::iter(vec![response, response_tokens]);
WorkerIdInfo {
prefill_worker_id,
decode_worker_id: Some(instance_id),
}
}
}
};
// Return as LLMEngineOutput with disaggregated_params (same structure as normal execution)
let output = LLMEngineOutput {
disaggregated_params: Some(json!({
"worker_id": worker_id_info,
"token_ids": request.token_ids
})),
..Default::default()
};
let response = Annotated::from_data(output);
let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
let (mut backend_input, context) = request.into_parts();
......
......@@ -20,9 +20,10 @@ use dynamo_runtime::{
use crate::{
discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
kv_router::{KvPushRouter, KvRouterConfig, QueryInstanceType, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::openai::nvext::WorkerIdInfo,
};
/// Errors that can occur during prefill routing
......@@ -67,6 +68,11 @@ impl InnerPrefillRouter {
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
///
/// Supports regular Dynamo and GAIE integrated mode via query_instance_id state machine:
/// - GAIE Stage 1: query_instance_id transitions "" -> "prefill" -> "decode", returns only worker IDs
/// - GAIE Stage 2: target_prefill_worker_id/target_decode_worker_id are set, full execution with specified workers
/// - Non-GAIE: like GAIE Stage 2 but the worker ids have to be determined.
pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>,
cancel_token: CancellationToken,
......@@ -196,10 +202,13 @@ impl PrefillRouter {
rand::rng().random()
}
/// Query best worker upfront, build bootstrap_info, and spawn prefill in background
/// Build bootstrap_info for disaggregated serving
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker.
async fn build_bootstrap_info(
&self,
req: &PreprocessedRequest,
preselected_worker: Option<u64>,
) -> Option<(u64, u32, BootstrapInfo)> {
let prefill_router = self.prefill_router.get()?;
......@@ -209,14 +218,24 @@ impl PrefillRouter {
InnerPrefillRouter::SimpleRouter(_) => return None,
};
// Query best worker without routing
let (worker_id, dp_rank) = match kv_router
// Use pre-selected worker (GAIE Stage 2) or query for best worker
let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
let dp_rank = req.dp_rank.unwrap_or(0);
tracing::debug!(
worker_id = id,
dp_rank = dp_rank,
"Using pre-selected prefill worker for bootstrap"
);
(id, dp_rank)
} else {
match kv_router
.chooser
.find_best_match(None, &req.token_ids, None, false)
.await
{
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
Err(_) => return None,
}
};
// Look up bootstrap endpoint from discovery
......@@ -343,6 +362,56 @@ impl PrefillRouter {
}
}
/// GAIE-specific request preparation helpers used by the prefill/decode orchestration.
impl PrefillRouter {
    /// Adjusts the prefill request for the GAIE flow that is in progress.
    ///
    /// - Stage 1 (`is_gaie_stage1 == true`): every annotation starting with
    ///   `query_instance_id` is dropped and replaced by a single `query_instance_id:prefill`
    ///   annotation.
    /// - Otherwise, if `target_prefill_worker_id` is set (GAIE Stage 2), that worker is copied
    ///   into `backend_instance_id` so the prefill is pinned to it.
    /// - If neither applies the request is left untouched.
    fn prepare_prefill_for_gaie(prefill_req: &mut PreprocessedRequest, is_gaie_stage1: bool) {
        if is_gaie_stage1 {
            // Replace whatever query_instance_id annotation the caller sent with the typed one.
            let annotations = &mut prefill_req.annotations;
            annotations.retain(|existing| !existing.starts_with("query_instance_id"));
            annotations.push(format!("query_instance_id:{}", QueryInstanceType::Prefill));
            return;
        }

        // Not Stage 1: only act when a prefill worker was pre-selected.
        let Some(pinned_worker) = prefill_req.target_prefill_worker_id else {
            return;
        };
        tracing::debug!(
            target_prefill_worker_id = pinned_worker,
            "GAIE Stage 2: Routing prefill to pre-selected worker"
        );
        prefill_req.backend_instance_id = Some(pinned_worker);
    }

    /// Turns the decode request into a GAIE Stage 1 decode query.
    ///
    /// Reads `worker_id` from the prefill result's `disaggregated_params`, decodes it as a
    /// `WorkerIdInfo` and takes its `prefill_worker_id`. When that id is available, any
    /// annotation starting with `query_instance_id` is replaced by `query_instance_id:decode`
    /// and a `prefill_worker_id:<id>` annotation is appended. When the id cannot be obtained
    /// (key missing, value not deserializable, or id unset) the request is not modified.
    fn prepare_decode_for_gaie_stage1(
        decode_req: &mut PreprocessedRequest,
        prefill_result: &PrefillResult,
    ) {
        let Some(raw_worker_info) = prefill_result.disaggregated_params.get("worker_id") else {
            return;
        };
        let Ok(worker_info) = serde_json::from_value::<WorkerIdInfo>(raw_worker_info.clone())
        else {
            return;
        };
        let Some(selected_prefill_worker) = worker_info.prefill_worker_id else {
            return;
        };

        let annotations = &mut decode_req.annotations;
        annotations.retain(|existing| !existing.starts_with("query_instance_id"));
        annotations.push(format!("query_instance_id:{}", QueryInstanceType::Decode));
        annotations.push(format!("prefill_worker_id:{selected_prefill_worker}"));
    }
}
impl Drop for PrefillRouter {
fn drop(&mut self) {
tracing::debug!("Dropping PrefillRouter, cancelling background activation task");
......@@ -369,6 +438,12 @@ impl
let request_id = context.id().to_string();
let engine_ctx = context.context();
// GAIE Stage 1: the presence of the empty query_instance_id signals query-only mode
// State machine: "" -> "prefill" -> "decode" (disagg) OR "" -> aggregated worker (agg fallback)
let is_gaie_stage1 = req
.get_annotation_value("query_instance_id")
.is_some_and(|s| s.is_empty());
// Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens;
......@@ -376,9 +451,16 @@ impl
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
// Try build_bootstrap_info optimization
let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) =
self.build_bootstrap_info(&prefill_req).await
// Prepare prefill request for GAIE flows (Stage 1 or Stage 2)
Self::prepare_prefill_for_gaie(&mut prefill_req, is_gaie_stage1);
// Try build_bootstrap_info optimization (skip for GAIE Stage 1 which needs query-only flow)
// For GAIE Stage 2, use target_prefill_worker_id if provided
let preselected_worker = prefill_req.target_prefill_worker_id;
let prefill_result = if !is_gaie_stage1 {
if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
{
let bootstrap_room = bootstrap_info.bootstrap_room;
......@@ -408,6 +490,15 @@ impl
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
}
} else {
// GAIE Stage 1: Use original path (no bootstrap optimization)
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
......@@ -429,8 +520,13 @@ impl
let mut decode_req = req;
// Update request with prefill result if available (only in original path)
if let Some(prefill_result) = maybe_prefill_result {
// Update request with prefill result
if is_gaie_stage1 {
if let Some(ref prefill_result) = maybe_prefill_result {
Self::prepare_decode_for_gaie_stage1(&mut decode_req, prefill_result);
}
} else if let Some(prefill_result) = maybe_prefill_result {
// Normal or GAIE Stage 2: Set prefill_result for decode
decode_req.prefill_result = Some(prefill_result);
}
......@@ -449,6 +545,15 @@ impl
..existing_override.unwrap_or_default()
});
// GAIE Stage 2: Route to pre-selected decode worker if specified
if let Some(decode_worker_id) = decode_req.target_decode_worker_id {
decode_req.backend_instance_id = Some(decode_worker_id);
tracing::debug!(
decode_worker_id = decode_worker_id,
"GAIE Stage 2: Routing decode to pre-selected worker"
);
}
// Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await
......
......@@ -238,10 +238,13 @@ impl OpenAIPreprocessor {
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id and extra_fields from nvext if present
// Extract backend_instance_id, extra_fields, and worker IDs from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
builder.extra_fields(nvext.extra_fields.clone());
// GAIE Stage 2: Extract targeted worker IDs for disaggregated serving
builder.target_prefill_worker_id(nvext.prefill_worker_id);
builder.target_decode_worker_id(nvext.decode_worker_id);
}
Ok(builder)
......
......@@ -74,7 +74,7 @@ pub struct BackendOutput {
///
/// This is the minimal raw output from the LLM engine. The Backend may then apply multiple
/// levels of post-processing before the BackendOutput is returned
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct LLMEngineOutput {
// new token_ids
pub token_ids: Vec<TokenIdType>,
......
......@@ -118,12 +118,34 @@ pub struct PreprocessedRequest {
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_fields: Option<Vec<String>>,
/// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the prefill request will be routed to this specific worker.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub target_prefill_worker_id: Option<u64>,
/// Targeted decode worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the decode request will be routed to this specific worker.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub target_decode_worker_id: Option<u64>,
}
impl PreprocessedRequest {
    /// Returns `true` when `annotation` appears verbatim in the request's annotation list.
    ///
    /// The match is on the whole annotation string, so a bare key such as `"foo"` does not
    /// match a `"foo:bar"` entry; use [`Self::get_annotation_value`] for `key:value` entries.
    pub fn has_annotation(&self, annotation: &str) -> bool {
        // Compare in place instead of allocating a temporary String for `contains`.
        self.annotations
            .iter()
            .any(|candidate| candidate.as_str() == annotation)
    }

    /// Get the value of the first annotation written as "key:value".
    ///
    /// Returns `None` when no annotation starts with `key:` (this includes a bare `key`
    /// annotation without a colon). An annotation that is exactly `key:` yields
    /// `Some(String::new())`, so callers can tell an empty value apart from a missing one.
    pub fn get_annotation_value(&self, key: &str) -> Option<String> {
        // Strip `key` and then the ':' separator without building a prefix string per call.
        self.annotations.iter().find_map(|candidate| {
            candidate
                .strip_prefix(key)?
                .strip_prefix(':')
                .map(str::to_owned)
        })
    }
}
impl PreprocessedRequest {
......
......@@ -400,13 +400,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
tracker.record_first_token();
}
// Extract worker_id from disaggregated_params
// Extract worker_id and token_ids from disaggregated_params
let worker_id_info = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
let token_ids = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| {
......@@ -417,11 +423,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None
};
// Inject nvext if we have worker_id or timing
if worker_id_info.is_some() || timing_info.is_some() {
// Inject nvext if we have worker_id, token_ids, or timing
if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(),
timing: timing_info,
token_ids: token_ids.clone(),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
......@@ -433,6 +440,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
info.decode_worker_id
);
}
if let Some(ref tokens) = token_ids {
tracing::debug!(
"Injected token_ids into chat completion nvext: {} tokens",
tokens.len()
);
}
}
}
......
......@@ -295,13 +295,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
tracker.record_first_token();
}
// Extract worker_id from disaggregated_params
// Extract worker_id and token_ids from disaggregated_params
let worker_id_info = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("worker_id"))
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
let token_ids = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
self.timing_tracker.as_ref().map(|tracker| {
......@@ -312,11 +318,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
None
};
// Inject nvext if we have worker_id or timing
if worker_id_info.is_some() || timing_info.is_some() {
// Inject nvext if we have worker_id, token_ids, or timing
if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(),
timing: timing_info,
token_ids: token_ids.clone(),
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
......@@ -328,6 +335,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
info.decode_worker_id
);
}
if let Some(ref tokens) = token_ids {
tracing::debug!(
"Injected token_ids into completions nvext: {} tokens",
tokens.len()
);
}
}
}
......
......@@ -35,6 +35,11 @@ pub struct NvExtResponse {
/// Populated when client requests `extra_fields: ["timing"]`
#[serde(skip_serializing_if = "Option::is_none")]
pub timing: Option<TimingInfo>,
/// Token IDs for GAIE Stage 1 query-only mode
/// Contains the tokenized prompt for reuse in Stage 2
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids: Option<Vec<u32>>,
}
/// NVIDIA LLM extensions to the OpenAI API
......@@ -87,6 +92,18 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>,
/// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the request will be routed to this specific prefill worker.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
/// Targeted decode worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the request will be routed to this specific decode worker.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
}
impl Default for NvExt {
......@@ -133,6 +150,8 @@ mod tests {
assert_eq!(nv_ext.token_data, None);
assert_eq!(nv_ext.max_thinking_tokens, None);
assert_eq!(nv_ext.extra_fields, None);
assert_eq!(nv_ext.prefill_worker_id, None);
assert_eq!(nv_ext.decode_worker_id, None);
}
// Test valid builder configurations
......@@ -157,4 +176,18 @@ mod tests {
// Validate the built struct
assert!(nv_ext.validate().is_ok());
}
// GAIE Stage 2: the builder must carry both targeted worker IDs through to the built NvExt,
// and the resulting struct must still pass validation.
#[test]
fn test_nv_ext_disagg_worker_ids() {
    let (expected_prefill, expected_decode) = (100, 200);

    let built = NvExt::builder()
        .prefill_worker_id(expected_prefill)
        .decode_worker_id(expected_decode)
        .build()
        .unwrap();

    assert_eq!(built.prefill_worker_id, Some(expected_prefill));
    assert_eq!(built.decode_worker_id, Some(expected_decode));
    assert!(built.validate().is_ok());
}
}
......@@ -1026,9 +1026,10 @@ def _test_router_query_instance_id(
asyncio.run(send_request_with_retry(url, test_payload))
# Test payload with query_instance_id annotation
# Format: "query_instance_id:" (colon with empty value) for GAIE aggregated mode
annotated_payload = {
**test_payload,
"nvext": {"annotations": ["query_instance_id"]},
"nvext": {"annotations": ["query_instance_id:"]},
}
async def test_annotation_response():
......@@ -1053,100 +1054,80 @@ def _test_router_query_instance_id(
f"Full SSE response ({len(full_response)} bytes):\n{full_response}"
)
# Parse and validate the response structure
events = []
# Parse the SSE response to extract the first chunk with nvext data
# New format: nvext contains worker_id and token_ids
sse_parts = full_response.split("\n\n")
worker_id_info = None
token_list = None
for part in sse_parts:
part = part.strip()
if not part:
if not part or not part.startswith("data:"):
continue
if part.startswith("event:"):
lines = part.split("\n")
event_line = next(
(line for line in lines if line.startswith("event:")),
None,
)
data_line = next(
(
line
for line in lines
if line.startswith("data:") or line.startswith(":")
),
None,
)
if event_line and data_line:
event_type = event_line.split(":", 1)[1].strip()
if data_line.startswith("data:"):
data_value = data_line.split(":", 1)[1].strip()
else:
data_value = data_line.split(":", 1)[1].strip()
events.append((event_type, data_value))
elif part.startswith("data:"):
data_value = part.split(":", 1)[1].strip()
data_str = part.split("data:", 1)[1].strip()
if data_str == "[DONE]":
continue
logger.info(f"Parsed events: {events}")
try:
chunk = json.loads(data_str)
logger.info(f"Parsed chunk: {json.dumps(chunk, indent=2)}")
# Validate worker_instance_id event
worker_event = next(
(e for e in events if e[0] == "worker_instance_id"), None
# Extract nvext data containing worker_id and token_ids
nvext = chunk.get("nvext", {})
if nvext:
if "worker_id" in nvext:
worker_id_info = nvext["worker_id"]
logger.info(
f"Found worker_id info: {worker_id_info}"
)
assert (
worker_event is not None
), f"Missing worker_instance_id event in: {events}"
# Validate token_data event
token_event = next(
(e for e in events if e[0] == "token_data"), None
if "token_ids" in nvext:
token_list = nvext["token_ids"]
logger.info(
f"Found token_ids: {len(token_list)} tokens"
)
except json.JSONDecodeError:
continue
# Validate worker_id info
assert (
token_event is not None
), f"Missing token_data event in: {events}"
worker_id_info is not None
), f"Missing worker_id in nvext. Response: {full_response}"
token_data_str = token_event[1].strip('"')
try:
token_list = json.loads(token_data_str)
except json.JSONDecodeError as e:
raise AssertionError(
f"token_data is not valid JSON: {token_data_str}, error: {e}"
)
# For aggregated mode, both prefill and decode should be the same
prefill_worker_id = worker_id_info.get("prefill_worker_id")
decode_worker_id = worker_id_info.get("decode_worker_id")
assert (
prefill_worker_id is not None
), f"Missing prefill_worker_id in worker_id: {worker_id_info}"
assert (
decode_worker_id is not None
), f"Missing decode_worker_id in worker_id: {worker_id_info}"
assert (
prefill_worker_id == decode_worker_id
), f"For aggregated mode, prefill and decode worker should be same: {worker_id_info}"
# Validate token_ids
assert (
token_list is not None
), f"Missing token_ids in nvext. Response: {full_response}"
assert isinstance(
token_list, list
), f"token_data should be a list, got: {type(token_list)}"
), f"token_ids should be a list, got: {type(token_list)}"
assert (
len(token_list) > 0
), f"token_data should not be empty: {token_list}"
), f"token_ids should not be empty: {token_list}"
assert all(
isinstance(token, int) for token in token_list
), f"All tokens should be integers: {token_list}"
logger.info(
f"Valid token_data with {len(token_list)} tokens: {token_list[:10]}{'...' if len(token_list) > 10 else ''}"
)
# Validate that no actual generation happened (should only be metadata)
# This proves the early return worked correctly
generation_indicators = [
"choices",
"content",
"delta",
"finish_reason",
]
for indicator in generation_indicators:
assert (
indicator not in full_response.lower()
), f"Found generation indicator '{indicator}' - request should not have been routed to worker"
logger.info(
"No generation content found - early return worked correctly"
f"Valid token_ids with {len(token_list)} tokens: {token_list[:10]}{'...' if len(token_list) > 10 else ''}"
)
return {
"worker_instance_id": worker_event[1].strip('"'),
"prefill_worker_id": prefill_worker_id,
"decode_worker_id": decode_worker_id,
"token_count": len(token_list),
"tokens": token_list,
}
......@@ -1154,7 +1135,8 @@ def _test_router_query_instance_id(
result = asyncio.run(test_annotation_response())
logger.info("Successfully validated query_instance_id annotation response:")
logger.info(f"Worker ID: {result['worker_instance_id']}")
logger.info(f"Prefill Worker ID: {result['prefill_worker_id']}")
logger.info(f"Decode Worker ID: {result['decode_worker_id']}")
logger.info(f"Token count: {result['token_count']}")
finally:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment