Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
15b49818
Unverified
Commit
15b49818
authored
Dec 18, 2025
by
atchernych
Committed by
GitHub
Dec 18, 2025
Browse files
feat: support disag serving in GAIE [DEP-659] (#4756)
Signed-off-by:
Anna Tchernych
<
atchernych@nvidia.com
>
parent
7a3b15e6
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
817 additions
and
290 deletions
+817
-290
deploy/inference-gateway/epp-patches/v0.8.0/gaie.patch
deploy/inference-gateway/epp-patches/v0.8.0/gaie.patch
+99
-25
lib/bindings/c/src/lib.rs
lib/bindings/c/src/lib.rs
+304
-130
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+126
-18
lib/llm/src/kv_router/prefill_router.rs
lib/llm/src/kv_router/prefill_router.rs
+142
-37
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+4
-1
lib/llm/src/protocols/common/llm_backend.rs
lib/llm/src/protocols/common/llm_backend.rs
+1
-1
lib/llm/src/protocols/common/preprocessor.rs
lib/llm/src/protocols/common/preprocessor.rs
+22
-0
lib/llm/src/protocols/openai/chat_completions/delta.rs
lib/llm/src/protocols/openai/chat_completions/delta.rs
+16
-3
lib/llm/src/protocols/openai/completions/delta.rs
lib/llm/src/protocols/openai/completions/delta.rs
+16
-3
lib/llm/src/protocols/openai/nvext.rs
lib/llm/src/protocols/openai/nvext.rs
+33
-0
tests/router/common.py
tests/router/common.py
+54
-72
No files found.
deploy/inference-gateway/epp-patches/v0.8.0/gaie.patch
View file @
15b49818
...
...
@@ -161,10 +161,10 @@ index 670d922..0cf04cb 100644
}
diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
new file mode 100644
index 0000000..cd9a0b5
index 0000000..1c8f979
--- /dev/null
+++ b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
@@ -0,0 +1,119 @@
@@ -0,0 +1,171 @@
+package dynamo_inject_workerid
+
+import (
...
...
@@ -182,6 +182,7 @@ index 0000000..cd9a0b5
+ typeString = "dynamo-inject-workerid"
+ pluginName = "dynamo-inject-workerid"
+ WorkerIDHeader = "x-worker-instance-id"
+ PrefillWorkerIDHeader = "x-prefiller-host-port"
+ tokenDataAnnotationKey = "dynamo/token-data"
+)
+
...
...
@@ -222,11 +223,18 @@ index 0000000..cd9a0b5
+ if req.Headers == nil {
+ req.Headers = map[string]string{}
+ }
+
+ // Handle worker instance ID
+ wid := strings.TrimSpace(req.Headers[WorkerIDHeader])
+ if wid == "" {
+ return
+ }
+ if wid != "" {
+ req.Headers[WorkerIDHeader] = wid
+ }
+
+ // Handle prefill worker ID
+ prefillWid := strings.TrimSpace(req.Headers[PrefillWorkerIDHeader])
+ if prefillWid != "" {
+ req.Headers[PrefillWorkerIDHeader] = prefillWid
+ }
+}
+
+func (p *InjectWorkerIDPreRequest) MutateRequestBody(
...
...
@@ -248,14 +256,28 @@ index 0000000..cd9a0b5
+ return
+ }
+
+ prefillWid := strings.TrimSpace(req.Headers[PrefillWorkerIDHeader])
+
+ nvext, _ := body["nvext"].(map[string]any)
+ if nvext == nil {
+ nvext = map[string]any{}
+ body["nvext"] = nvext
+ }
+
+ if prefillWid != "" && prefillWid != wid {
+ // Disaggregated mode: use prefill_worker_id and decode_worker_id
+ if prefillWidUint, err := strconv.ParseUint(prefillWid, 10, 64); err == nil {
+ nvext["prefill_worker_id"] = prefillWidUint
+ }
+ if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil {
+ nvext["decode_worker_id"] = widUint
+ }
+ } else {
+ // Aggregated mode (empty prefill or prefill == decode): use backend_instance_id
+ if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil {
+ nvext["backend_instance_id"] = widUint
+ }
+ }
+
+ if tokens, ok := req.Annotations[tokenDataAnnotationKey]; ok {
+ switch v := tokens.(type) {
...
...
@@ -283,6 +305,36 @@ index 0000000..cd9a0b5
+ }
+ }
+ }
+
+ // Remove query_instance_id from nvext.annotations if present
+ if annotations, ok := nvext["annotations"]; ok {
+ switch annList := annotations.(type) {
+ case []string:
+ filtered := make([]string, 0, len(annList))
+ for _, ann := range annList {
+ if ann != "query_instance_id" {
+ filtered = append(filtered, ann)
+ }
+ }
+ if len(filtered) == 0 {
+ delete(nvext, "annotations")
+ } else {
+ nvext["annotations"] = filtered
+ }
+ case []any:
+ filtered := make([]any, 0, len(annList))
+ for _, ann := range annList {
+ if str, ok := ann.(string); !ok || str != "query_instance_id" {
+ filtered = append(filtered, ann)
+ }
+ }
+ if len(filtered) == 0 {
+ delete(nvext, "annotations")
+ } else {
+ nvext["annotations"] = filtered
+ }
+ }
+ }
+}
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml
new file mode 100644
...
...
@@ -313,10 +365,10 @@ index 0000000..b689c00
+ - pluginRef: picker
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
new file mode 100644
index 0000000..bc29c0a
index 0000000..75f30e9
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
@@ -0,0 +1,424 @@
@@ -0,0 +1,446 @@
+package dynamo_kv_scorer
+
+/*
...
...
@@ -367,13 +419,15 @@ index 0000000..bc29c0a
+ double router_temperature,
+ bool use_kv_events,
+ bool router_replica_sync,
+ bool enforce_disagg,
+ WorkerSelectionPipeline **pipeline_out);
+
+dynamo_llm_result_t dynamo_destroy_worker_selection_pipeline(WorkerSelectionPipeline *pipeline);
+
+dynamo_llm_result_t dynamo_query_worker_selection_and_annotate(WorkerSelectionPipeline *pipeline,
+ const char *request_json_c_str,
+ int64_t *worker_instance_id_out,
+ int64_t *decode_worker_id_out,
+ int64_t *prefill_worker_id_out,
+ uint32_t **token_ids_out,
+ size_t *token_count_out,
+ char **annotated_request_json_out);
...
...
@@ -404,7 +458,9 @@ index 0000000..bc29c0a
+ PluginName = "dynamo-kv-scorer"
+ KVAwareScorerType = "kv-aware-scorer"
+ StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id")
+ StateKeyPrefillWorkerID = schedtypes.StateKey("dynamo/prefill-worker-id")
+ WorkerIDHeader = "x-worker-instance-id"
+ PrefillWorkerIDHeader = "x-prefiller-host-port"
+ tokenDataAnnotationKey = "dynamo/token-data"
+)
+
...
...
@@ -471,6 +527,7 @@ index 0000000..bc29c0a
+ ffiRouterTemperature float64
+ ffiKvBlockSize uint32
+ ffiWorkerID int64
+ ffiEnforceDisagg bool
+
+ runtimeInitialized bool
+
...
...
@@ -484,6 +541,7 @@ index 0000000..bc29c0a
+ ffiComponent = getEnvOrDefault("DYNAMO_COMPONENT", "backend")
+ ffiModel = getEnvOrDefault("DYNAMO_MODEL", "Qwen/Qwen3-0.6B")
+ ffiWorkerID = getEnvInt64OrDefault("DYNAMO_WORKER_ID", 1)
+ ffiEnforceDisagg = getEnvBoolOrDefault("DYNAMO_ENFORCE_DISAGG", true) // TODO default to false
+
+ ffiOverlapScoreWeight = getEnvFloatOrDefault("DYNAMO_OVERLAP_SCORE_WEIGHT", -1.0)
+ ffiRouterTemperature = getEnvFloatOrDefault("DYNAMO_ROUTER_TEMPERATURE", -1.0)
...
...
@@ -575,6 +633,7 @@ index 0000000..bc29c0a
+ C.double(ffiRouterTemperature),
+ C.bool(getEnvBoolOrDefault("DYNAMO_USE_KV_EVENTS", true)),
+ C.bool(getEnvBoolOrDefault("DYNAMO_ROUTER_REPLICA_SYNC", true)),
+ C.bool(ffiEnforceDisagg),
+ &pipeline,
+ )
+ if rc != C.DYNAMO_OK {
...
...
@@ -595,13 +654,14 @@ index 0000000..bc29c0a
+) map[schedtypes.Pod]float64 {
+ logger := log.FromContext(ctx)
+
+ workerID, tokenData, err := k.callDynamoRouter(ctx, req)
+ workerID,
prefillWorkerID,
tokenData, err := k.callDynamoRouter(ctx, req)
+ if err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "Dynamo call failed; proceeding without worker id")
+ } else if workerID != "" {
+ logger.V(logutil.DEFAULT).Info(
+ "Dynamo router selected worker",
+ "workerID", workerID,
+ "prefillWorkerID", prefillWorkerID,
+ "tokenDataCount", len(tokenData),
+ "tokenData", tokenData,
+ )
...
...
@@ -610,6 +670,13 @@ index 0000000..bc29c0a
+ req.Headers = map[string]string{}
+ }
+ req.Headers[WorkerIDHeader] = workerID
+
+ // Set prefill worker ID if present
+ if prefillWorkerID != "" {
+ cycle.Write(StateKeyPrefillWorkerID, stateString(prefillWorkerID))
+ req.Headers[PrefillWorkerIDHeader] = prefillWorkerID
+ }
+
+ if len(tokenData) > 0 {
+ if req.Annotations == nil {
+ req.Annotations = map[string]any{}
...
...
@@ -632,15 +699,15 @@ index 0000000..bc29c0a
+func (k *KVAwareScorer) callDynamoRouter(
+ ctx context.Context,
+ req *schedtypes.LLMRequest,
+) (
string,
[]int64, error) {
+) (
workerID string, prefillWorkerID string, tokenData
[]int64,
err
error) {
+ logger := log.FromContext(ctx)
+
+ if err := initFFI(); err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "FFI init failed")
+ return "", nil, err
+ return "",
"",
nil, err
+ }
+ if !runtimeInitialized {
+ return "", nil, fmt.Errorf("dynamo runtime not initialized")
+ return "",
"",
nil, fmt.Errorf("dynamo runtime not initialized")
+ }
+
+ pipelineMutex.RLock()
...
...
@@ -648,21 +715,22 @@ index 0000000..bc29c0a
+ pipelineMutex.RUnlock()
+
+ if currentPipeline == nil {
+ return "", nil, fmt.Errorf("dynamo worker selection pipeline not created")
+ return "",
"",
nil, fmt.Errorf("dynamo worker selection pipeline not created")
+ }
+
+ // Build OpenAI-compatible JSON request
+ requestBody := buildOpenAIRequest(req)
+ requestJSON,
e
rr := json.Marshal(requestBody)
+ if
e
rr != nil {
+ logger.V(logutil.DEFAULT).Error(
e
rr, "Failed to marshal OpenAI request")
+ return "", nil, fmt.Errorf("marshal OpenAI request: %w",
e
rr)
+ requestJSON,
jsonE
rr := json.Marshal(requestBody)
+ if
jsonE
rr != nil {
+ logger.V(logutil.DEFAULT).Error(
jsonE
rr, "Failed to marshal OpenAI request")
+ return "",
"",
nil, fmt.Errorf("marshal OpenAI request: %w",
jsonE
rr)
+ }
+ cRequestJSON := C.CString(string(requestJSON))
+ defer C.free(unsafe.Pointer(cRequestJSON))
+
+ // Output variables
+ var cWorkerID C.int64_t
+ var cDecodeWorkerID C.int64_t
+ var cPrefillWorkerID C.int64_t
+ var cTokens *C.uint32_t
+ var cTokenCount C.size_t
+ var cAnnotatedJSON *C.char
...
...
@@ -671,13 +739,14 @@ index 0000000..bc29c0a
+ rc := C.dynamo_query_worker_selection_and_annotate(
+ currentPipeline,
+ cRequestJSON,
+ &cWorkerID,
+ &cDecodeWorkerID,
+ &cPrefillWorkerID,
+ &cTokens,
+ &cTokenCount,
+ &cAnnotatedJSON,
+ )
+ if rc != C.DYNAMO_OK {
+ return "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed")
+ return "",
"",
nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed")
+ }
+
+ // Copy tokens into Go memory and free C memory
...
...
@@ -692,11 +761,16 @@ index 0000000..bc29c0a
+ }
+ C.dynamo_free_worker_selection_result(cTokens, cTokenCount, cAnnotatedJSON)
+
+ workerID := fmt.Sprintf("%d", int64(cWorkerID))
+ workerIDStr := fmt.Sprintf("%d", int64(cDecodeWorkerID))
+ prefillWorkerIDStr := ""
+ // Rust returns -1 for prefill_worker_id when not in disaggregated mode
+ if int64(cPrefillWorkerID) >= 0 {
+ prefillWorkerIDStr = fmt.Sprintf("%d", int64(cPrefillWorkerID))
+ }
+ logger.V(logutil.DEFAULT).Info("Worker selection completed",
+ "workerID", workerID, "tokenCount", count)
+ "workerID", workerID
Str, "prefillWorkerID", prefillWorkerIDStr
, "tokenCount", count)
+
+ return workerID, tokens64, nil
+ return workerID
Str, prefillWorkerIDStr
, tokens64, nil
+}
+
+func buildOpenAIRequest(req *schedtypes.LLMRequest) map[string]any {
...
...
lib/bindings/c/src/lib.rs
View file @
15b49818
...
...
@@ -393,6 +393,10 @@ pub struct WorkerSelectionPipeline {
///
/// # Errors
/// Returns `DynamoLlmResult::ERR` on failure and does not write to `pipeline_out`.
/// # Safety
/// See detailed safety docs above. Additional parameter:
/// - `enforce_disagg`: If true, requests fail when disaggregated serving is unavailable.
/// If false, falls back to aggregated serving.
#[unsafe(no_mangle)]
pub
unsafe
extern
"C"
fn
dynamo_create_worker_selection_pipeline
(
namespace_c_str
:
*
const
c_char
,
...
...
@@ -404,6 +408,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
router_temperature
:
f64
,
use_kv_events
:
bool
,
router_replica_sync
:
bool
,
enforce_disagg
:
bool
,
pipeline_out
:
*
mut
*
mut
WorkerSelectionPipeline
,
)
->
DynamoLlmResult
{
if
pipeline_out
.is_null
()
{
...
...
@@ -472,6 +477,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
router_mode
,
(
busy_threshold
>=
0.0
)
.then_some
(
busy_threshold
),
kv_router_config
,
enforce_disagg
,
)
.await
};
...
...
@@ -492,7 +498,8 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
}
/// Query worker selection on an existing pipeline and return:
/// - `worker_instance_id_out` (`i64`)
/// - `decode_worker_id_out` (`i64`): The decode worker ID (primary worker)
/// - `prefill_worker_id_out` (`i64`): The prefill worker ID (-1 if not in disaggregated mode)
/// - `token_ids_out` (heap-allocated `*mut u32`; caller must free via
/// `dynamo_free_worker_selection_result`)
/// - `token_count_out` (`usize`)
...
...
@@ -513,10 +520,10 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
/// function returns `DynamoLlmResult::ERR`.
/// - Must remain valid for the duration of this call.
/// - Output pointers:
/// - `worker_instance_id_out`, `token_ids_out`, `token_count_out`,
/// - `decode_worker_id_out`, `prefill_worker_id_out`, `token_ids_out`, `token_count_out`,
/// and `annotated_request_json_out` must each be **non-null** and point to
/// writable memory for their respective types. On success, this function
/// writes to all four outputs exactly once.
/// writes to all five outputs exactly once.
/// - On **error**, outputs are left unmodified.
/// - Ownership & deallocation:
/// - On success, if there are zero tokens, `*token_ids_out` may be set to `NULL`
...
...
@@ -540,11 +547,18 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
/// Returns `DynamoLlmResult::ERR` if any precondition fails (null/invalid pointers,
/// malformed UTF-8/JSON, pipeline errors, allocation failures, etc.). On error, no
/// output pointer is written.
///
/// # Output values
/// - `decode_worker_id_out`: The decode worker ID (primary worker in aggregated mode)
/// - `prefill_worker_id_out`: The prefill worker ID (only set in disaggregated mode, -1 if not present)
/// - `token_ids_out`, `token_count_out`: Token IDs and count
/// - `annotated_request_json_out`: The annotated request JSON
#[unsafe(no_mangle)]
pub
unsafe
extern
"C"
fn
dynamo_query_worker_selection_and_annotate
(
pipeline
:
*
mut
WorkerSelectionPipeline
,
request_json_c_str
:
*
const
c_char
,
worker_instance_id_out
:
*
mut
i64
,
decode_worker_id_out
:
*
mut
i64
,
prefill_worker_id_out
:
*
mut
i64
,
token_ids_out
:
*
mut
*
mut
u32
,
token_count_out
:
*
mut
usize
,
annotated_request_json_out
:
*
mut
*
mut
c_char
,
...
...
@@ -553,7 +567,8 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
tracing
::
error!
(
"Pipeline pointer is null"
);
return
DynamoLlmResult
::
ERR
;
}
if
worker_instance_id_out
.is_null
()
if
decode_worker_id_out
.is_null
()
||
prefill_worker_id_out
.is_null
()
||
token_ids_out
.is_null
()
||
token_count_out
.is_null
()
||
annotated_request_json_out
.is_null
()
...
...
@@ -579,7 +594,7 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
let
pl
=
unsafe
{
&*
pipeline
};
let
fut
=
async
{
query_worker_selection_and_annotate
(
&
pl
.engine
,
request
)
.await
};
let
(
worker_id
,
tokens
,
annotated_req
)
=
match
pl
.wk
.runtime
()
.secondary
()
.block_on
(
fut
)
{
let
(
result
,
annotated_req
)
=
match
pl
.wk
.runtime
()
.secondary
()
.block_on
(
fut
)
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
{
tracing
::
error!
(
error
=
?
e
,
"query_worker_selection_and_annotate failed"
);
...
...
@@ -587,10 +602,10 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
}
};
let
tokens_ptr
=
if
tokens
.is_empty
()
{
let
tokens_ptr
=
if
result
.
tokens
.is_empty
()
{
std
::
ptr
::
null_mut
()
}
else
{
let
len
=
tokens
.len
();
let
len
=
result
.
tokens
.len
();
let
layout
=
std
::
alloc
::
Layout
::
array
::
<
u32
>
(
len
)
.unwrap
();
let
ptr
=
unsafe
{
std
::
alloc
::
alloc
(
layout
)
as
*
mut
u32
};
if
ptr
.is_null
()
{
...
...
@@ -598,7 +613,7 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
return
DynamoLlmResult
::
ERR
;
}
unsafe
{
std
::
ptr
::
copy_nonoverlapping
(
tokens
.as_ptr
(),
ptr
,
len
);
std
::
ptr
::
copy_nonoverlapping
(
result
.
tokens
.as_ptr
(),
ptr
,
len
);
}
ptr
};
...
...
@@ -606,11 +621,11 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
let
annotated_json
=
match
serde_json
::
to_string
(
&
annotated_req
)
{
Ok
(
s
)
=>
s
,
Err
(
e
)
=>
{
let
layout
=
std
::
alloc
::
Layout
::
array
::
<
u32
>
(
tokens
.len
())
.unwrap
();
if
!
tokens_ptr
.is_null
()
{
let
layout
=
std
::
alloc
::
Layout
::
array
::
<
u32
>
(
result
.tokens
.len
())
.unwrap
();
unsafe
{
std
::
alloc
::
dealloc
(
tokens_ptr
as
*
mut
u8
,
layout
);
}
if
!
tokens_ptr
.is_null
()
{
tracing
::
error!
(
error
=
?
e
,
"serialize annotated request failed"
);
}
return
DynamoLlmResult
::
ERR
;
...
...
@@ -621,7 +636,7 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
Err
(
e
)
=>
{
tracing
::
error!
(
error
=
?
e
,
"CString::new for annotated JSON failed"
);
if
!
tokens_ptr
.is_null
()
{
let
layout
=
std
::
alloc
::
Layout
::
array
::
<
u32
>
(
tokens
.len
())
.unwrap
();
let
layout
=
std
::
alloc
::
Layout
::
array
::
<
u32
>
(
result
.
tokens
.len
())
.unwrap
();
unsafe
{
std
::
alloc
::
dealloc
(
tokens_ptr
as
*
mut
u8
,
layout
);
}
...
...
@@ -630,9 +645,10 @@ pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate(
}
};
unsafe
{
*
worker_instance_id_out
=
worker_id
;
*
decode_worker_id_out
=
result
.decode_worker_id
.unwrap_or
(
0
);
*
prefill_worker_id_out
=
result
.prefill_worker_id
.unwrap_or
(
-
1
);
*
token_ids_out
=
tokens_ptr
;
*
token_count_out
=
tokens
.len
();
*
token_count_out
=
result
.
tokens
.len
();
*
annotated_request_json_out
=
cjson
.into_raw
();
}
DynamoLlmResult
::
OK
...
...
@@ -724,96 +740,77 @@ pub unsafe extern "C" fn dynamo_free_worker_selection_result(
DynamoLlmResult
::
OK
}
/// Result of worker selection extraction
#[derive(Debug,
Clone,
Default)]
pub
struct
WorkerSelectionResult
{
/// Decode worker ID (primary worker for aggregated, decode-only for disaggregated)
pub
decode_worker_id
:
Option
<
i64
>
,
/// Prefill worker ID (only present in disaggregated mode)
pub
prefill_worker_id
:
Option
<
i64
>
,
/// Token IDs from tokenization
pub
tokens
:
Vec
<
u32
>
,
}
/// Helper function to extract worker selection information from the annotation stream
///
/// The response format (from disaggregated_params in nvext):
/// - worker_id: {"prefill_worker_id": 123, "decode_worker_id": 456}
/// - token_ids: [1, 2, 3, ...]
pub
async
fn
extract_worker_selection_from_stream
(
mut
stream
:
Pin
<
Box
<
dyn
AsyncEngineStream
<
Annotated
<
NvCreateChatCompletionStreamResponse
>>>>
,
)
->
anyhow
::
Result
<
(
i64
,
Vec
<
u32
>
)
>
{
)
->
anyhow
::
Result
<
WorkerSelectionResult
>
{
use
dynamo_llm
::
protocols
::
openai
::
nvext
::
WorkerIdInfo
;
use
futures
::
StreamExt
;
let
mut
worker_id
:
i64
=
0
;
let
mut
tokens
:
Vec
<
u32
>
=
Vec
::
new
();
let
mut
result
=
WorkerSelectionResult
::
default
();
while
let
Some
(
response
)
=
stream
.next
()
.await
{
let
Some
(
event
)
=
&
response
.event
else
{
tracing
::
error!
(
"Response has no event field"
);
continue
;
};
match
event
.as_str
()
{
"worker_instance_id"
=>
{
tracing
::
debug!
(
"Found worker_instance_id event"
);
let
Some
(
first_comment
)
=
response
.comment
.as_ref
()
.and_then
(|
v
|
v
.first
())
else
{
tracing
::
debug!
(
"worker_instance_id event without comments"
);
continue
;
};
// Try JSON string first (e.g. `"1732646935200805498"`), then plain integer.
if
let
Ok
(
id_string
)
=
serde_json
::
from_str
::
<
String
>
(
first_comment
)
{
match
id_string
.parse
::
<
i64
>
()
{
Ok
(
parsed_id
)
=>
{
worker_id
=
parsed_id
;
tracing
::
debug!
(
"parsed worker_id from JSON string: {}"
,
worker_id
);
}
Err
(
_
)
=>
{
tracing
::
error!
(
"failed to parse number from JSON string: '{}'"
,
id_string
// Check for data in nvext (worker_id and token_ids are direct fields)
// nvext is a serde_json::Value, so we access it as a JSON object
if
let
Some
(
data
)
=
&
response
.data
&&
let
Some
(
nvext
)
=
&
data
.nvext
{
// Extract worker_id
if
let
Some
(
worker_id_value
)
=
nvext
.get
(
"worker_id"
)
&&
let
Ok
(
worker_info
)
=
serde_json
::
from_value
::
<
WorkerIdInfo
>
(
worker_id_value
.clone
())
{
result
.decode_worker_id
=
worker_info
.decode_worker_id
.map
(|
id
|
id
as
i64
);
result
.prefill_worker_id
=
worker_info
.prefill_worker_id
.map
(|
id
|
id
as
i64
);
tracing
::
debug!
(
decode_worker_id
=
?
result
.decode_worker_id
,
prefill_worker_id
=
?
result
.prefill_worker_id
,
"Parsed worker_id from nvext"
);
}
}
continue
;
}
match
first_comment
.parse
::
<
i64
>
()
{
Ok
(
parsed_id
)
=>
{
worker_id
=
parsed_id
;
tracing
::
debug!
(
"parsed worker_id directly: {}"
,
worker_id
);
}
Err
(
_
)
=>
{
tracing
::
error!
(
"failed to parse worker_id from: '{}'"
,
first_comment
);
}
}
}
"token_data"
=>
{
tracing
::
debug!
(
"Found token_data event"
);
let
Some
(
first_comment
)
=
response
.comment
.as_ref
()
.and_then
(|
v
|
v
.first
())
else
{
tracing
::
debug!
(
"token_data event without comments"
);
continue
;
};
tracing
::
debug!
(
"Token comment: '{}'"
,
first_comment
);
match
serde_json
::
from_str
::
<
Vec
<
u32
>>
(
first_comment
)
{
Ok
(
parsed_tokens
)
=>
{
tokens
=
parsed_tokens
;
tracing
::
debug!
(
"Successfully parsed {} tokens"
,
tokens
.len
());
}
Err
(
e
)
=>
{
tracing
::
error!
(
"Failed to parse tokens from '{}': {}"
,
first_comment
,
e
);
}
}
}
other
=>
{
tracing
::
debug!
(
"Unknown event type: '{}'"
,
other
);
// Extract token_ids
if
let
Some
(
token_ids_value
)
=
nvext
.get
(
"token_ids"
)
&&
let
Ok
(
parsed_tokens
)
=
serde_json
::
from_value
::
<
Vec
<
u32
>>
(
token_ids_value
.clone
())
{
result
.tokens
=
parsed_tokens
;
tracing
::
debug!
(
"Successfully parsed {} tokens from nvext"
,
result
.tokens
.len
()
);
}
}
}
tracing
::
info!
(
"Final worker_id={}, tokens.len()={}"
,
worker_id
,
tokens
.len
()
decode_worker_id
=
?
result
.decode_worker_id
,
prefill_worker_id
=
?
result
.prefill_worker_id
,
token_count
=
result
.tokens
.len
(),
"Worker selection extraction complete"
);
Ok
(
(
worker_id
,
tokens
)
)
Ok
(
result
)
}
/// Utility function to add the "query_instance_id" annotation to an OpenAI request
///
/// This function modifies the request to include the annotation that signals the KV router
/// to return worker selection information (worker_
instance_
id and token_data) instead of
/// to return worker selection information (worker_
f
id and token_data) instead of
/// performing actual inference.
///
/// # Parameters
...
...
@@ -824,28 +821,73 @@ pub async fn extract_worker_selection_from_stream(
pub
fn
add_query_instance_id
(
request
:
&
mut
NvCreateChatCompletionRequest
,
)
->
&
mut
NvCreateChatCompletionRequest
{
add_annotation_unique
(
request
,
"query_instance_id"
)
// Send empty value - router treats empty as aggregated / aggregated worker selection
set_kv_annotation
(
request
,
"query_instance_id"
.to_string
(),
""
)
}
/// Utility function to add worker_instance_id annotation to an OpenAI request
pub
fn
add_worker_instance_id_annotation
(
/// Set worker IDs directly on the NvExt fields for GAIE Stage 2
///
/// For disaggregated mode: sets `prefill_worker_id` and `decode_worker_id`
/// For aggregated mode: sets `backend_instance_id` (when both IDs are the same)
pub
fn
set_worker_ids_for_stage2
(
request
:
&
mut
NvCreateChatCompletionRequest
,
worker_id
:
i64
,
decode_worker_id
:
Option
<
i64
>
,
prefill_worker_id
:
Option
<
i64
>
,
)
->
&
mut
NvCreateChatCompletionRequest
{
set_kv_annotation
(
request
,
"worker_instance_id"
.to_string
(),
worker_id
.to_string
(),
)
let
nvext
=
request
.nvext
.get_or_insert_with
(||
{
NvExt
::
builder
()
.build
()
.expect
(
"NvExt builder should not fail"
)
});
// Check if this is aggregated mode (same worker for both)
let
is_aggregated
=
prefill_worker_id
==
decode_worker_id
;
if
is_aggregated
{
// Aggregated: use backend_instance_id for direct routing
if
let
Some
(
id
)
=
decode_worker_id
{
nvext
.backend_instance_id
=
Some
(
id
as
u64
);
tracing
::
debug!
(
backend_instance_id
=
id
,
"GAIE Stage 2 Aggregated: Setting backend_instance_id"
);
}
}
else
{
// Disaggregated: use separate prefill and decode worker IDs
if
let
Some
(
id
)
=
prefill_worker_id
{
nvext
.prefill_worker_id
=
Some
(
id
as
u64
);
}
if
let
Some
(
id
)
=
decode_worker_id
{
nvext
.decode_worker_id
=
Some
(
id
as
u64
);
}
tracing
::
debug!
(
prefill_worker_id
=
?
prefill_worker_id
,
decode_worker_id
=
?
decode_worker_id
,
"GAIE Stage 2 Disaggregated: Setting prefill and decode worker IDs"
);
}
request
}
///
Utility function to add token_data annotation to an OpenAI request
pub
fn
add
_token_data_
annotation
<
'a
>
(
///
Set token_data directly on the NvExt field for GAIE Stage 2
pub
fn
set
_token_data_
for_stage2
<
'a
>
(
request
:
&
'a
mut
NvCreateChatCompletionRequest
,
tokens
:
&
[
u32
],
)
->
&
'a
mut
NvCreateChatCompletionRequest
{
let
tokens_json
=
serde_json
::
to_string
(
tokens
)
.unwrap_or_default
();
set_kv_annotation
(
request
,
"token_data"
.to_string
(),
tokens_json
)
let
nvext
=
request
.nvext
.get_or_insert_with
(||
{
NvExt
::
builder
()
.build
()
.expect
(
"NvExt builder should not fail"
)
});
nvext
.token_data
=
Some
(
tokens
.to_vec
());
tracing
::
debug!
(
token_count
=
tokens
.len
(),
"GAIE Stage 2: Setting token_data"
);
request
}
/// Ensure `nvext` exists and return a mutable slice of annotations.
...
...
@@ -858,19 +900,6 @@ fn ensure_annotations(request: &mut NvCreateChatCompletionRequest) -> &mut Vec<S
nvext
.annotations
.get_or_insert_with
(
Vec
::
new
)
}
/// Add a plain annotation once.
fn
add_annotation_unique
(
request
:
&
mut
NvCreateChatCompletionRequest
,
annotation
:
impl
Into
<
String
>
,
)
->
&
mut
NvCreateChatCompletionRequest
{
let
ann
=
annotation
.into
();
let
annotations
=
ensure_annotations
(
request
);
if
!
annotations
.iter
()
.any
(|
a
|
a
==
&
ann
)
{
annotations
.push
(
ann
);
}
request
}
/// Set a `key:value` annotation.
fn
set_kv_annotation
(
request
:
&
mut
NvCreateChatCompletionRequest
,
...
...
@@ -885,38 +914,153 @@ fn set_kv_annotation(
request
}
/// Wrapper function that queries worker selection and annotates the original request
/// Wrapper function that queries worker selection and prepares the request for GAIE Stage 2
///
/// This function performs the complete flow:
/// 1. Clones the original request and adds "query_instance_id" annotation
/// This function performs the complete GAIE Stage 1 flow:
/// 1. Clones the original request and adds "query_instance_id:" (empty) annotation
/// 2. Calls engine.generate() with the modified request
/// 3. Extracts worker_instance_id and tokens from the response stream
/// 4. Adds worker_instance_id and token_data annotations to the original request
/// 5. Returns (worker_id, tokens, annotated_original_request)
/// 3. Extracts worker_id info and tokens from the response stream
/// 4. Sets the appropriate NvExt fields on the original request for Stage 2:
/// - Disaggregated: prefill_worker_id, decode_worker_id, token_data
/// - Aggregated: backend_instance_id, token_data
/// 5. Returns WorkerSelectionResult and the modified request ready for Stage 2
///
/// # Parameters
/// - `engine`: The worker selection pipeline engine
/// - `original_request`: The original OpenAI request to process
///
/// # Returns
/// A tuple containing (worker_instance_id, tokens, modified_original_request)
/// where the modified_original_request has worker_instance_id and token_data annotations added
/// A tuple containing (WorkerSelectionResult, modified_original_request)
/// where the modified_original_request is ready for GAIE Stage 2 execution
pub
async
fn
query_worker_selection_and_annotate
(
engine
:
&
ServiceEngine
<
SingleIn
<
NvCreateChatCompletionRequest
>
,
ManyOut
<
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
>
,
mut
original_request
:
NvCreateChatCompletionRequest
,
)
->
anyhow
::
Result
<
(
i64
,
Vec
<
u32
>
,
NvCreateChatCompletionRequest
)
>
{
)
->
anyhow
::
Result
<
(
WorkerSelectionResult
,
NvCreateChatCompletionRequest
)
>
{
// GAIE Stage 1: Query for worker selection
let
mut
query_request
=
original_request
.clone
();
add_query_instance_id
(
&
mut
query_request
);
let
single_in
=
SingleIn
::
new
(
query_request
);
let
response_stream
=
engine
.generate
(
single_in
)
.await
?
;
let
(
worker_id
,
tokens
)
=
extract_worker_selection_from_stream
(
response_stream
)
.await
?
;
add_worker_instance_id_annotation
(
&
mut
original_request
,
worker_id
);
add_token_data_annotation
(
&
mut
original_request
,
&
tokens
);
let
result
=
extract_worker_selection_from_stream
(
response_stream
)
.await
?
;
Ok
((
worker_id
,
tokens
,
original_request
))
// Prepare request for GAIE Stage 2: Set NvExt fields directly
set_worker_ids_for_stage2
(
&
mut
original_request
,
result
.decode_worker_id
,
result
.prefill_worker_id
,
);
set_token_data_for_stage2
(
&
mut
original_request
,
&
result
.tokens
);
Ok
((
result
,
original_request
))
}
/// Spawn a background task to watch for prefill models and activate prefill routers.
/// This is a lightweight watcher that only handles prefill model discovery.
///
/// # Parameters
/// - `drt`: distributed runtime; supplies the discovery client and is used to
///   resolve namespace -> component -> endpoint for a discovered model.
/// - `model_manager`: receives `activate_prefill_router(model_name, endpoint)`
///   for every matching prefill model.
/// - `target_namespace`: only model instances whose namespace equals this
///   value are considered; all others are skipped.
///
/// # Behavior
/// - Returns immediately; the `tokio::spawn` handle is discarded, so the task
///   is detached and its outcome is never observed by the caller.
/// - If the discovery stream cannot be started, the error is logged and the
///   task exits without retrying.
/// - Per-event stream errors are logged and skipped; the loop keeps running.
/// - `Removed` events are only logged; nothing is deactivated here.
fn spawn_prefill_watcher(
    drt: DistributedRuntime,
    model_manager: Arc<ModelManager>,
    target_namespace: String,
) {
    use dynamo_llm::model_card::ModelDeploymentCard;
    use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery};
    use dynamo_runtime::protocols::EndpointId;
    use futures::StreamExt;

    // `async move` takes ownership of `drt`, `model_manager` and `target_namespace`.
    tokio::spawn(async move {
        let discovery = drt.discovery();
        let mut stream = match discovery.list_and_watch(DiscoveryQuery::AllModels, None).await {
            Ok(s) => s,
            Err(e) => {
                // Without a stream there is nothing to watch: log and end the task.
                tracing::error!(error = %e, "Failed to start prefill discovery stream");
                return;
            }
        };

        while let Some(result) = stream.next().await {
            let event = match result {
                Ok(e) => e,
                Err(e) => {
                    // A single bad item does not terminate the watcher.
                    tracing::error!(error = %e, "Error in prefill discovery stream");
                    continue;
                }
            };

            match event {
                DiscoveryEvent::Added(instance) => {
                    // Build the endpoint identity and decode the model card;
                    // any instance that does not qualify is skipped via `continue`.
                    let (endpoint_id, card) = match &instance {
                        DiscoveryInstance::Model { namespace, component, endpoint, .. } => {
                            // Filter by namespace
                            if namespace != &target_namespace {
                                continue;
                            }

                            let eid = EndpointId {
                                namespace: namespace.clone(),
                                component: component.clone(),
                                name: endpoint.clone(),
                            };

                            match instance.deserialize_model::<ModelDeploymentCard>() {
                                Ok(card) => (eid, card),
                                // Cards that fail to deserialize are dropped silently (no log).
                                Err(_) => continue,
                            }
                        }
                        // Non-model discovery instances are not relevant here.
                        _ => continue,
                    };

                    // Only handle prefill models
                    if !card.model_type.supports_prefill() {
                        continue;
                    }

                    tracing::info!(model_name = card.name(), "Prefill model discovered, activating prefill router");

                    // Get the endpoint and activate the prefill router.
                    // If the namespace or component lookup fails, the event is
                    // skipped without logging.
                    if let Ok(ns) = drt.namespace(&endpoint_id.namespace) && let Ok(comp) = ns.component(&endpoint_id.component) {
                        let endpoint = comp.endpoint(&endpoint_id.name);
                        if let Err(e) = model_manager.activate_prefill_router(card.name(), endpoint) {
                            // Activation failure is reported as a warning; the watcher keeps going.
                            tracing::warn!(model_name = card.name(), error = %e, "Failed to activate prefill router");
                        } else {
                            tracing::info!(model_name = card.name(), "Prefill router activated successfully");
                        }
                    }
                }
                DiscoveryEvent::Removed(instance_id) => {
                    // Log removal for observability
                    // Note: The PrefillRouter remains active - worker availability
                    // is handled dynamically by the underlying Client's instance tracking
                    tracing::debug!(instance_id = instance_id, "Prefill worker instance removed from discovery");
                }
            }
        }
    });
}
/// Create a worker selection pipeline for OpenAI Chat Completion requests
...
...
@@ -931,6 +1075,7 @@ pub async fn query_worker_selection_and_annotate(
/// - `router_mode`: How to route requests (KV, RoundRobin, etc.)
/// - `busy_threshold`: Optional threshold for busy worker detection
/// - `kv_router_config`: Optional KV router configuration (only used when router_mode is KV)
/// - `enforce_disagg`: If true, fail requests when disaggregated serving is unavailable
///
/// # Returns
/// A configured worker selection pipeline ready to use
...
...
@@ -941,12 +1086,15 @@ pub async fn create_worker_selection_pipeline_chat(
router_mode
:
RouterMode
,
busy_threshold
:
Option
<
f64
>
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
enforce_disagg
:
bool
,
)
->
anyhow
::
Result
<
ServiceEngine
<
SingleIn
<
NvCreateChatCompletionRequest
>
,
ManyOut
<
Annotated
<
NvCreateChatCompletionStreamResponse
>>
,
>
,
>
{
use
dynamo_llm
::
kv_router
::
PrefillRouter
;
let
runtime
=
Runtime
::
from_settings
()
?
;
let
dst_config
=
DistributedConfig
::
from_settings
();
let
drt_owned
=
DistributedRuntime
::
new
(
runtime
,
dst_config
)
.await
?
;
...
...
@@ -966,10 +1114,9 @@ pub async fn create_worker_selection_pipeline_chat(
let
router_config
=
dynamo_llm
::
entrypoint
::
RouterConfig
{
router_mode
,
kv_router_config
:
kv_router_config
.unwrap_or_default
(),
// C bindings only support active_decode_blocks_threshold for now (via busy_threshold param)
active_decode_blocks_threshold
:
busy_threshold
,
active_prefill_tokens_threshold
:
None
,
enforce_disagg
:
false
,
enforce_disagg
,
};
let
watcher
=
ModelWatcher
::
new
(
component
.drt
()
.clone
(),
...
...
@@ -999,6 +1146,34 @@ pub async fn create_worker_selection_pipeline_chat(
None
};
// Create prefill chooser for dynamic disaggregation support
// This registers the model and returns a receiver that will be activated
// when a prefill worker is discovered
let
prefill_chooser
=
model_manager
.register_prefill_router
(
model_name
.to_string
())
.map
(|
rx
|
{
// Create prefill-specific config with track_active_blocks disabled
let
mut
prefill_config
=
kv_router_config
.unwrap_or_default
();
prefill_config
.router_track_active_blocks
=
false
;
PrefillRouter
::
new
(
rx
,
model_manager
.clone
(),
router_mode
,
card
.kv_cache_block_size
,
Some
(
prefill_config
),
enforce_disagg
,
)
});
// Start background watcher for prefill model discovery
// This will activate the prefill router when prefill workers join
spawn_prefill_watcher
(
component
.drt
()
.clone
(),
model_manager
.clone
(),
namespace
.to_string
(),
);
// Download model config files from HuggingFace for EPP
// The backend's card has NATS URLs which aren't accessible from EPP
tracing
::
debug!
(
...
...
@@ -1034,7 +1209,6 @@ pub async fn create_worker_selection_pipeline_chat(
// Create worker monitor if busy_threshold is set
// Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this
// C bindings only support active_decode_blocks_threshold for now (active_prefill_tokens_threshold defaults to 1000000 tokens = effectively disabled)
let
worker_monitor
=
busy_threshold
.map
(|
t
|
KvWorkerMonitor
::
new
(
client
.clone
(),
t
,
1000000
));
let
engine
=
build_routed_pipeline
::
<
...
...
@@ -1047,8 +1221,8 @@ pub async fn create_worker_selection_pipeline_chat(
worker_monitor
,
chooser
,
hf_tokenizer
,
None
,
//
prefill_chooser
false
,
//
enforce_disagg
prefill_chooser
,
enforce_disagg
,
)
.await
?
;
...
...
lib/llm/src/kv_router.rs
View file @
15b49818
...
...
@@ -97,6 +97,46 @@ pub fn router_endpoint_id(namespace: String) -> EndpointId {
}
}
/// Worker pool targeted by a `query_instance_id` annotation.
///
/// The annotation is written as `query_instance_id:<type>`, where `<type>` is:
/// - "prefill" → select a prefill worker (disaggregated serving)
/// - "decode" → select a decode worker (disaggregated serving)
///
/// Serde (de)serializes the variants with the same lowercase spellings.
///
/// Note: Empty value ("query_instance_id:") has no variant here; it is handled by
/// PrefillRouter for disagg orchestration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum QueryInstanceType {
    /// Select from the prefill worker pool (disaggregated serving)
    Prefill,
    /// Select from the decode worker pool (disaggregated serving)
    Decode,
}
impl
std
::
fmt
::
Display
for
QueryInstanceType
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
match
self
{
QueryInstanceType
::
Prefill
=>
write!
(
f
,
"prefill"
),
QueryInstanceType
::
Decode
=>
write!
(
f
,
"decode"
),
}
}
}
impl
std
::
str
::
FromStr
for
QueryInstanceType
{
type
Err
=
String
;
fn
from_str
(
s
:
&
str
)
->
Result
<
Self
,
Self
::
Err
>
{
match
s
.to_lowercase
()
.as_str
()
{
"prefill"
=>
Ok
(
QueryInstanceType
::
Prefill
),
"decode"
=>
Ok
(
QueryInstanceType
::
Decode
),
_
=>
Err
(
format!
(
"Invalid QueryInstanceType: '{s}'. Expected 'prefill' or 'decode'"
)),
}
}
}
/// Creates a DiscoveryQuery for the KV router in the given namespace.
pub
fn
router_discovery_query
(
namespace
:
String
)
->
DiscoveryQuery
{
DiscoveryQuery
::
Endpoint
{
...
...
@@ -731,13 +771,34 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Extract context ID for request tracking
let
context_id
=
request
.context
()
.id
()
.to_string
();
// Check if this is a query_instance_id request first
let
query_instance_id
=
request
.has_annotation
(
"query_instance_id"
);
// Check if this is a query_instance_id request and parse its type
// Format: "query_instance_id:type" where type is "prefill", "decode", or "" (empty for aggregated)
// Empty value ("query_instance_id:") means GAIE Aggregated mode - return same worker as both prefill and decode
let
query_instance_annotation
=
request
.get_annotation_value
(
"query_instance_id"
);
let
is_gaie_agg_query
=
query_instance_annotation
.as_ref
()
.is_some_and
(|
s
|
s
.is_empty
());
let
query_instance_type
:
Option
<
QueryInstanceType
>
=
if
let
Some
(
type_str
)
=
&
query_instance_annotation
{
match
type_str
.parse
::
<
QueryInstanceType
>
()
{
Ok
(
t
)
=>
Some
(
t
),
Err
(
_
)
if
type_str
.is_empty
()
=>
{
// Empty value is valid for aggregated mode, not a warning
None
}
Err
(
e
)
=>
{
tracing
::
warn!
(
"Invalid query_instance_id type '{type_str}': {e}"
);
None
}
}
}
else
{
None
};
let
(
instance_id
,
dp_rank
,
overlap_amount
)
=
if
let
Some
(
id
)
=
request
.backend_instance_id
{
// If instance_id is set, use it and compute actual overlap
let
dp_rank
=
request
.dp_rank
.unwrap_or
(
0
);
if
query_instance_
id
{
if
query_instance_
type
.is_some
()
{
tracing
::
debug!
(
"backend_instance_id is set, routing to instance {id} with dp_rank {dp_rank} and ignoring query_instance_id annotation"
);
...
...
@@ -761,33 +822,80 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
(
id
,
dp_rank
,
overlap_blocks
)
}
else
{
// Otherwise, find the best match
// Don't update states if this is a query-only request (any query_instance_id annotation)
let
should_update_states
=
query_instance_annotation
.is_none
();
let
(
best_worker
,
overlap_amount
)
=
self
.chooser
.find_best_match
(
Some
(
&
context_id
),
&
request
.token_ids
,
request
.router_config_override
.as_ref
(),
!
query_instance_id
,
// Don't
update
states
if query_instance_id
should_
update
_
states
,
)
.await
?
;
(
best_worker
.worker_id
,
best_worker
.dp_rank
,
overlap_amount
)
};
//
i
f request has
the annotation "query_instance_id",
//
then the request will not be routed to the worker,
//
and instead the worker_instance_id will be returned
.
//
I
f request has
a query_instance_id annotation, return worker selection info
//
without routing to the actual worker. Returns LLMEngineOutput with disaggregated_params
//
containing worker_id info, same structure as normal execution for uniform extraction
.
let
stream_context
=
request
.context
()
.clone
();
if
query_instance_id
{
let
instance_id_str
=
instance_id
.to_string
();
let
response
=
Annotated
::
from_annotation
(
"worker_instance_id"
,
&
instance_id_str
)
?
;
// Return the tokens in nvext.token_data format
let
response_tokens
=
Annotated
::
from_annotation
(
"token_data"
,
&
request
.token_ids
)
?
;
// Handle query-only requests (GAIE Stage 1)
if
query_instance_type
.is_some
()
||
is_gaie_agg_query
{
let
worker_id_info
=
if
is_gaie_agg_query
{
// GAIE Aggregated mode: same worker serves both prefill and decode
tracing
::
trace!
(
query_type
=
"aggregated"
,
worker_id
=
instance_id
,
"Returning aggregated worker selection (same worker for prefill and decode)"
);
WorkerIdInfo
{
prefill_worker_id
:
Some
(
instance_id
),
decode_worker_id
:
Some
(
instance_id
),
}
}
else
{
match
query_instance_type
.unwrap
()
{
QueryInstanceType
::
Prefill
=>
{
tracing
::
trace!
(
query_type
=
"prefill"
,
prefill_worker_id
=
instance_id
,
"Returning prefill worker selection"
);
WorkerIdInfo
{
prefill_worker_id
:
Some
(
instance_id
),
decode_worker_id
:
None
,
}
}
QueryInstanceType
::
Decode
=>
{
// Get prefill_worker_id from annotation (set by caller after prefill selection)
let
prefill_worker_id
=
request
.get_annotation_value
(
"prefill_worker_id"
)
.and_then
(|
s
|
s
.parse
::
<
u64
>
()
.ok
());
tracing
::
trace!
(
"Tokens requested in the response through the query_instance_id annotation: {:?}"
,
response_tokens
query_type
=
"decode"
,
prefill_worker_id
=
?
prefill_worker_id
,
decode_worker_id
=
instance_id
,
"Returning decode worker selection"
);
let
stream
=
stream
::
iter
(
vec!
[
response
,
response_tokens
]);
WorkerIdInfo
{
prefill_worker_id
,
decode_worker_id
:
Some
(
instance_id
),
}
}
}
};
// Return as LLMEngineOutput with disaggregated_params (same structure as normal execution)
let
output
=
LLMEngineOutput
{
disaggregated_params
:
Some
(
json!
({
"worker_id"
:
worker_id_info
,
"token_ids"
:
request
.token_ids
})),
..
Default
::
default
()
};
let
response
=
Annotated
::
from_data
(
output
);
let
stream
=
stream
::
iter
(
vec!
[
response
]);
return
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
stream
),
stream_context
));
}
let
(
mut
backend_input
,
context
)
=
request
.into_parts
();
...
...
lib/llm/src/kv_router/prefill_router.rs
View file @
15b49818
...
...
@@ -20,9 +20,10 @@ use dynamo_runtime::{
use
crate
::{
discovery
::
ModelManager
,
kv_router
::{
KvPushRouter
,
KvRouterConfig
,
RouterConfigOverride
},
kv_router
::{
KvPushRouter
,
KvRouterConfig
,
QueryInstanceType
,
RouterConfigOverride
},
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
protocols
::
common
::
preprocessor
::{
BootstrapInfo
,
PrefillResult
},
protocols
::
openai
::
nvext
::
WorkerIdInfo
,
};
/// Errors that can occur during prefill routing
...
...
@@ -67,6 +68,11 @@ impl InnerPrefillRouter {
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
///
/// Supports regular Dynamo and GAIE integrated mode via query_instance_id state machine:
/// - GAIE Stage 1: query_instance_id transitions "" -> "prefill" -> "decode", returns only worker IDs
/// - GAIE Stage 2: target_prefill_worker_id/target_decode_worker_id are set, full execution with specified workers
/// - Non-GAIE: like GAIE Stage 2 but the worker ids have to be determined.
pub
struct
PrefillRouter
{
prefill_router
:
OnceLock
<
InnerPrefillRouter
>
,
cancel_token
:
CancellationToken
,
...
...
@@ -196,10 +202,13 @@ impl PrefillRouter {
rand
::
rng
()
.random
()
}
/// Query best worker upfront, build bootstrap_info, and spawn prefill in background
/// Build bootstrap_info for disaggregated serving
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker.
async
fn
build_bootstrap_info
(
&
self
,
req
:
&
PreprocessedRequest
,
preselected_worker
:
Option
<
u64
>
,
)
->
Option
<
(
u64
,
u32
,
BootstrapInfo
)
>
{
let
prefill_router
=
self
.prefill_router
.get
()
?
;
...
...
@@ -209,14 +218,24 @@ impl PrefillRouter {
InnerPrefillRouter
::
SimpleRouter
(
_
)
=>
return
None
,
};
// Query best worker without routing
let
(
worker_id
,
dp_rank
)
=
match
kv_router
// Use pre-selected worker (GAIE Stage 2) or query for best worker
let
(
worker_id
,
dp_rank
)
=
if
let
Some
(
id
)
=
preselected_worker
{
let
dp_rank
=
req
.dp_rank
.unwrap_or
(
0
);
tracing
::
debug!
(
worker_id
=
id
,
dp_rank
=
dp_rank
,
"Using pre-selected prefill worker for bootstrap"
);
(
id
,
dp_rank
)
}
else
{
match
kv_router
.chooser
.find_best_match
(
None
,
&
req
.token_ids
,
None
,
false
)
.await
{
Ok
((
worker
,
_
overlap
))
=>
(
worker
.worker_id
,
worker
.dp_rank
),
Err
(
_
)
=>
return
None
,
}
};
// Look up bootstrap endpoint from discovery
...
...
@@ -343,6 +362,56 @@ impl PrefillRouter {
}
}
/// GAIE helper functions for preparing prefill and decode requests
impl PrefillRouter {
    /// Drop every annotation that starts with `key`, then append a single `key:value` entry.
    ///
    /// Annotation readers return the first matching entry, so a stale entry left in place
    /// would shadow the value written here.
    fn replace_annotation(
        annotations: &mut Vec<String>,
        key: &str,
        value: impl std::fmt::Display,
    ) {
        annotations.retain(|a| !a.starts_with(key));
        annotations.push(format!("{key}:{value}"));
    }

    /// Prepare prefill request for GAIE flows
    /// - Stage 1: Replaces the query_instance_id annotation with query_instance_id:prefill
    /// - Stage 2: Sets backend_instance_id to target_prefill_worker_id, when one is set
    ///
    /// When neither applies the request is left untouched.
    fn prepare_prefill_for_gaie(prefill_req: &mut PreprocessedRequest, is_gaie_stage1: bool) {
        if is_gaie_stage1 {
            // GAIE Stage 1: Set query_instance_id to "prefill" for prefill worker selection
            Self::replace_annotation(
                &mut prefill_req.annotations,
                "query_instance_id",
                QueryInstanceType::Prefill,
            );
            return;
        }

        if let Some(prefill_worker_id) = prefill_req.target_prefill_worker_id {
            // GAIE Stage 2: Route to pre-selected prefill worker from the stage 1
            tracing::debug!(
                target_prefill_worker_id = prefill_worker_id,
                "GAIE Stage 2: Routing prefill to pre-selected worker"
            );
            prefill_req.backend_instance_id = Some(prefill_worker_id);
        }
    }

    /// Prepare decode request for GAIE Stage 1
    ///
    /// Reads `worker_id.prefill_worker_id` from the prefill result's disaggregated_params.
    /// When present, the request's annotations are rewritten to carry
    /// `query_instance_id:decode` and `prefill_worker_id:<id>`; any earlier entries for
    /// either key are removed first. When absent (or not parseable as `WorkerIdInfo`)
    /// the request is left unchanged.
    fn prepare_decode_for_gaie_stage1(
        decode_req: &mut PreprocessedRequest,
        prefill_result: &PrefillResult,
    ) {
        let prefill_worker_id = prefill_result
            .disaggregated_params
            .get("worker_id")
            .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
            .and_then(|info| info.prefill_worker_id);

        let Some(worker_id) = prefill_worker_id else {
            return;
        };

        Self::replace_annotation(
            &mut decode_req.annotations,
            "query_instance_id",
            QueryInstanceType::Decode,
        );
        // Replace rather than append: a pre-existing prefill_worker_id entry would
        // otherwise be found before the one selected in this request.
        Self::replace_annotation(&mut decode_req.annotations, "prefill_worker_id", worker_id);
    }
}
impl
Drop
for
PrefillRouter
{
fn
drop
(
&
mut
self
)
{
tracing
::
debug!
(
"Dropping PrefillRouter, cancelling background activation task"
);
...
...
@@ -369,6 +438,12 @@ impl
let
request_id
=
context
.id
()
.to_string
();
let
engine_ctx
=
context
.context
();
// GAIE Stage 1: the presence of the empty query_instance_id signals query-only mode
// State machine: "" -> "prefill" -> "decode" (disagg) OR "" -> aggregated worker (agg fallback)
let
is_gaie_stage1
=
req
.get_annotation_value
(
"query_instance_id"
)
.is_some_and
(|
s
|
s
.is_empty
());
// Save original max_tokens for decode
let
original_max_tokens
=
req
.stop_conditions.max_tokens
;
...
...
@@ -376,9 +451,16 @@ impl
let
mut
prefill_req
=
req
.clone
();
prefill_req
.stop_conditions.max_tokens
=
Some
(
1
);
// Try build_bootstrap_info optimization
let
prefill_result
=
if
let
Some
((
worker_id
,
dp_rank
,
bootstrap_info
))
=
self
.build_bootstrap_info
(
&
prefill_req
)
.await
// Prepare prefill request for GAIE flows (Stage 1 or Stage 2)
Self
::
prepare_prefill_for_gaie
(
&
mut
prefill_req
,
is_gaie_stage1
);
// Try build_bootstrap_info optimization (skip for GAIE Stage 1 which needs query-only flow)
// For GAIE Stage 2, use target_prefill_worker_id if provided
let
preselected_worker
=
prefill_req
.target_prefill_worker_id
;
let
prefill_result
=
if
!
is_gaie_stage1
{
if
let
Some
((
worker_id
,
dp_rank
,
bootstrap_info
))
=
self
.build_bootstrap_info
(
&
prefill_req
,
preselected_worker
)
.await
{
let
bootstrap_room
=
bootstrap_info
.bootstrap_room
;
...
...
@@ -408,6 +490,15 @@ impl
let
prefill_context
=
Context
::
with_id
(
prefill_req
,
request_id
.clone
());
engine_ctx
.link_child
(
prefill_context
.context
());
self
.call_prefill
(
prefill_context
)
.await
.map
(|(
result
,
worker_id
)|
(
Some
(
result
),
worker_id
,
None
))
}
}
else
{
// GAIE Stage 1: Use original path (no bootstrap optimization)
let
prefill_context
=
Context
::
with_id
(
prefill_req
,
request_id
.clone
());
engine_ctx
.link_child
(
prefill_context
.context
());
self
.call_prefill
(
prefill_context
)
.await
.map
(|(
result
,
worker_id
)|
(
Some
(
result
),
worker_id
,
None
))
...
...
@@ -429,8 +520,13 @@ impl
let
mut
decode_req
=
req
;
// Update request with prefill result if available (only in original path)
if
let
Some
(
prefill_result
)
=
maybe_prefill_result
{
// Update request with prefill result
if
is_gaie_stage1
{
if
let
Some
(
ref
prefill_result
)
=
maybe_prefill_result
{
Self
::
prepare_decode_for_gaie_stage1
(
&
mut
decode_req
,
prefill_result
);
}
}
else
if
let
Some
(
prefill_result
)
=
maybe_prefill_result
{
// Normal or GAIE Stage 2: Set prefill_result for decode
decode_req
.prefill_result
=
Some
(
prefill_result
);
}
...
...
@@ -449,6 +545,15 @@ impl
..
existing_override
.unwrap_or_default
()
});
// GAIE Stage 2: Route to pre-selected decode worker if specified
if
let
Some
(
decode_worker_id
)
=
decode_req
.target_decode_worker_id
{
decode_req
.backend_instance_id
=
Some
(
decode_worker_id
);
tracing
::
debug!
(
decode_worker_id
=
decode_worker_id
,
"GAIE Stage 2: Routing decode to pre-selected worker"
);
}
// Map the modified request through with preserved context
let
decode_request
=
context
.map
(|
_
|
decode_req
);
next
.generate
(
decode_request
)
.await
...
...
lib/llm/src/preprocessor.rs
View file @
15b49818
...
...
@@ -238,10 +238,13 @@ impl OpenAIPreprocessor {
builder
.annotations
(
request
.annotations
()
.unwrap_or_default
());
builder
.mdc_sum
(
Some
(
self
.mdcsum
.clone
()));
builder
.estimated_prefix_hit_num_blocks
(
None
);
// Extract backend_instance_id
and
extra_fields from nvext if present
// Extract backend_instance_id
,
extra_fields
, and worker IDs
from nvext if present
if
let
Some
(
nvext
)
=
request
.nvext
()
{
builder
.backend_instance_id
(
nvext
.backend_instance_id
);
builder
.extra_fields
(
nvext
.extra_fields
.clone
());
// GAIE Stage 2: Extract targeted worker IDs for disaggregated serving
builder
.target_prefill_worker_id
(
nvext
.prefill_worker_id
);
builder
.target_decode_worker_id
(
nvext
.decode_worker_id
);
}
Ok
(
builder
)
...
...
lib/llm/src/protocols/common/llm_backend.rs
View file @
15b49818
...
...
@@ -74,7 +74,7 @@ pub struct BackendOutput {
///
/// This is the minimal raw output from the LLM engine. The Backend may then apply multiple
/// levels of post-processing before the BackendOutput is returned
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq)]
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq
,
Default
)]
pub
struct
LLMEngineOutput
{
// new token_ids
pub
token_ids
:
Vec
<
TokenIdType
>
,
...
...
lib/llm/src/protocols/common/preprocessor.rs
View file @
15b49818
...
...
@@ -118,12 +118,34 @@ pub struct PreprocessedRequest {
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
extra_fields
:
Option
<
Vec
<
String
>>
,
/// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the prefill request will be routed to this specific worker.
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
target_prefill_worker_id
:
Option
<
u64
>
,
/// Targeted decode worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the decode request will be routed to this specific worker.
#[builder(default)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
target_decode_worker_id
:
Option
<
u64
>
,
}
impl
PreprocessedRequest
{
/// Returns `true` when `annotation` appears verbatim among the request's annotations.
///
/// The comparison is a whole-string match, so `"key"` does not match `"key:value"`.
pub fn has_annotation(&self, annotation: &str) -> bool {
    self.annotations
        .iter()
        .any(|entry| entry.as_str() == annotation)
}
/// Get the value of an annotation in the format "key:value"
/// Returns None if the annotation is not found or has no value
pub
fn
get_annotation_value
(
&
self
,
key
:
&
str
)
->
Option
<
String
>
{
let
prefix
=
format!
(
"{}:"
,
key
);
self
.annotations
.iter
()
.find
(|
a
|
a
.starts_with
(
&
prefix
))
.map
(|
a
|
a
[
prefix
.len
()
..
]
.to_string
())
}
}
impl
PreprocessedRequest
{
...
...
lib/llm/src/protocols/openai/chat_completions/delta.rs
View file @
15b49818
...
...
@@ -400,13 +400,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
tracker
.record_first_token
();
}
// Extract worker_id from disaggregated_params
// Extract worker_id
and token_ids
from disaggregated_params
let
worker_id_info
=
delta
.disaggregated_params
.as_ref
()
.and_then
(|
params
|
params
.get
(
"worker_id"
))
.and_then
(|
v
|
serde_json
::
from_value
::
<
WorkerIdInfo
>
(
v
.clone
())
.ok
());
let
token_ids
=
delta
.disaggregated_params
.as_ref
()
.and_then
(|
params
|
params
.get
(
"token_ids"
))
.and_then
(|
v
|
serde_json
::
from_value
::
<
Vec
<
u32
>>
(
v
.clone
())
.ok
());
// Get timing info if this is the final response (has finish_reason)
let
timing_info
:
Option
<
TimingInfo
>
=
if
finish_reason
.is_some
()
{
self
.timing_tracker
.as_ref
()
.map
(|
tracker
|
{
...
...
@@ -417,11 +423,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None
};
// Inject nvext if we have worker_id or timing
if
worker_id_info
.is_some
()
||
timing_info
.is_some
()
{
// Inject nvext if we have worker_id
, token_ids,
or timing
if
worker_id_info
.is_some
()
||
token_ids
.is_some
()
||
timing_info
.is_some
()
{
let
nvext_response
=
NvExtResponse
{
worker_id
:
worker_id_info
.clone
(),
timing
:
timing_info
,
token_ids
:
token_ids
.clone
(),
};
if
let
Ok
(
nvext_json
)
=
serde_json
::
to_value
(
&
nvext_response
)
{
...
...
@@ -433,6 +440,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
info
.decode_worker_id
);
}
if
let
Some
(
ref
tokens
)
=
token_ids
{
tracing
::
debug!
(
"Injected token_ids into chat completion nvext: {} tokens"
,
tokens
.len
()
);
}
}
}
...
...
lib/llm/src/protocols/openai/completions/delta.rs
View file @
15b49818
...
...
@@ -295,13 +295,19 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
tracker
.record_first_token
();
}
// Extract worker_id from disaggregated_params
// Extract worker_id
and token_ids
from disaggregated_params
let
worker_id_info
=
delta
.disaggregated_params
.as_ref
()
.and_then
(|
params
|
params
.get
(
"worker_id"
))
.and_then
(|
v
|
serde_json
::
from_value
::
<
WorkerIdInfo
>
(
v
.clone
())
.ok
());
let
token_ids
=
delta
.disaggregated_params
.as_ref
()
.and_then
(|
params
|
params
.get
(
"token_ids"
))
.and_then
(|
v
|
serde_json
::
from_value
::
<
Vec
<
u32
>>
(
v
.clone
())
.ok
());
// Get timing info if this is the final response (has finish_reason)
let
timing_info
:
Option
<
TimingInfo
>
=
if
finish_reason
.is_some
()
{
self
.timing_tracker
.as_ref
()
.map
(|
tracker
|
{
...
...
@@ -312,11 +318,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
None
};
// Inject nvext if we have worker_id or timing
if
worker_id_info
.is_some
()
||
timing_info
.is_some
()
{
// Inject nvext if we have worker_id
, token_ids,
or timing
if
worker_id_info
.is_some
()
||
token_ids
.is_some
()
||
timing_info
.is_some
()
{
let
nvext_response
=
NvExtResponse
{
worker_id
:
worker_id_info
.clone
(),
timing
:
timing_info
,
token_ids
:
token_ids
.clone
(),
};
if
let
Ok
(
nvext_json
)
=
serde_json
::
to_value
(
&
nvext_response
)
{
...
...
@@ -328,6 +335,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
info
.decode_worker_id
);
}
if
let
Some
(
ref
tokens
)
=
token_ids
{
tracing
::
debug!
(
"Injected token_ids into completions nvext: {} tokens"
,
tokens
.len
()
);
}
}
}
...
...
lib/llm/src/protocols/openai/nvext.rs
View file @
15b49818
...
...
@@ -35,6 +35,11 @@ pub struct NvExtResponse {
/// Populated when client requests `extra_fields: ["timing"]`
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
timing
:
Option
<
TimingInfo
>
,
/// Token IDs for GAIE Stage 1 query-only mode
/// Contains the tokenized prompt for reuse in Stage 2
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
token_ids
:
Option
<
Vec
<
u32
>>
,
}
/// NVIDIA LLM extensions to the OpenAI API
...
...
@@ -87,6 +92,18 @@ pub struct NvExt {
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(strip_option))]
pub
extra_fields
:
Option
<
Vec
<
String
>>
,
/// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the request will be routed to this specific prefill worker.
#[builder(default,
setter(strip_option))]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
prefill_worker_id
:
Option
<
u64
>
,
/// Targeted decode worker ID for disaggregated serving (GAIE Stage 2)
/// When set, the request will be routed to this specific decode worker.
#[builder(default,
setter(strip_option))]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
decode_worker_id
:
Option
<
u64
>
,
}
impl
Default
for
NvExt
{
...
...
@@ -133,6 +150,8 @@ mod tests {
assert_eq!
(
nv_ext
.token_data
,
None
);
assert_eq!
(
nv_ext
.max_thinking_tokens
,
None
);
assert_eq!
(
nv_ext
.extra_fields
,
None
);
assert_eq!
(
nv_ext
.prefill_worker_id
,
None
);
assert_eq!
(
nv_ext
.decode_worker_id
,
None
);
}
// Test valid builder configurations
...
...
@@ -157,4 +176,18 @@ mod tests {
// Validate the built struct
assert
!
(
nv_ext
.validate
()
.is_ok
());
}
// GAIE Stage 2: targeted prefill/decode worker IDs set through the builder are stored
// as-is and the resulting struct passes validation
#[test]
fn test_nv_ext_disagg_worker_ids() {
    let prefill_id: u64 = 100;
    let decode_id: u64 = 200;

    let built = NvExt::builder()
        .prefill_worker_id(prefill_id)
        .decode_worker_id(decode_id)
        .build();
    let nv_ext = built.unwrap();

    assert_eq!(nv_ext.prefill_worker_id, Some(prefill_id));
    assert_eq!(nv_ext.decode_worker_id, Some(decode_id));
    assert!(nv_ext.validate().is_ok());
}
}
tests/router/common.py
View file @
15b49818
...
...
@@ -1026,9 +1026,10 @@ def _test_router_query_instance_id(
asyncio
.
run
(
send_request_with_retry
(
url
,
test_payload
))
# Test payload with query_instance_id annotation
# Format: "query_instance_id:" (colon with empty value) for GAIE aggregated mode
annotated_payload
=
{
**
test_payload
,
"nvext"
:
{
"annotations"
:
[
"query_instance_id"
]},
"nvext"
:
{
"annotations"
:
[
"query_instance_id
:
"
]},
}
async
def
test_annotation_response
():
...
...
@@ -1053,100 +1054,80 @@ def _test_router_query_instance_id(
f
"Full SSE response (
{
len
(
full_response
)
}
bytes):
\n
{
full_response
}
"
)
# Parse and validate the response structure
events
=
[]
# Parse the SSE response to extract the first chunk with nvext data
# New format: nvext contains worker_id and token_ids
sse_parts
=
full_response
.
split
(
"
\n\n
"
)
worker_id_info
=
None
token_list
=
None
for
part
in
sse_parts
:
part
=
part
.
strip
()
if
not
part
:
if
not
part
or
not
part
.
startswith
(
"data:"
)
:
continue
if
part
.
startswith
(
"event:"
):
lines
=
part
.
split
(
"
\n
"
)
event_line
=
next
(
(
line
for
line
in
lines
if
line
.
startswith
(
"event:"
)),
None
,
)
data_line
=
next
(
(
line
for
line
in
lines
if
line
.
startswith
(
"data:"
)
or
line
.
startswith
(
":"
)
),
None
,
)
if
event_line
and
data_line
:
event_type
=
event_line
.
split
(
":"
,
1
)[
1
].
strip
()
if
data_line
.
startswith
(
"data:"
):
data_value
=
data_line
.
split
(
":"
,
1
)[
1
].
strip
()
else
:
data_value
=
data_line
.
split
(
":"
,
1
)[
1
].
strip
()
events
.
append
((
event_type
,
data_value
))
elif
part
.
startswith
(
"data:"
):
data_value
=
part
.
split
(
":"
,
1
)[
1
].
strip
()
data_str
=
part
.
split
(
"data:"
,
1
)[
1
].
strip
()
if
data_str
==
"[DONE]"
:
continue
logger
.
info
(
f
"Parsed events:
{
events
}
"
)
try
:
chunk
=
json
.
loads
(
data_str
)
logger
.
info
(
f
"Parsed chunk:
{
json
.
dumps
(
chunk
,
indent
=
2
)
}
"
)
# Validate worker_instance_id event
worker_event
=
next
(
(
e
for
e
in
events
if
e
[
0
]
==
"worker_instance_id"
),
None
# Extract nvext data containing worker_id and token_ids
nvext
=
chunk
.
get
(
"nvext"
,
{})
if
nvext
:
if
"worker_id"
in
nvext
:
worker_id_info
=
nvext
[
"worker_id"
]
logger
.
info
(
f
"Found worker_id info:
{
worker_id_info
}
"
)
assert
(
worker_event
is
not
None
),
f
"Missing worker_instance_id event in:
{
events
}
"
# Validate token_data event
token_event
=
next
(
(
e
for
e
in
events
if
e
[
0
]
==
"token_data"
),
None
if
"token_ids"
in
nvext
:
token_list
=
nvext
[
"token_ids"
]
logger
.
info
(
f
"Found token_ids:
{
len
(
token_list
)
}
tokens"
)
except
json
.
JSONDecodeError
:
continue
# Validate worker_id info
assert
(
token_event
is
not
None
),
f
"Missing
token_data event in:
{
events
}
"
worker_id_info
is
not
None
),
f
"Missing
worker_id in nvext. Response:
{
full_response
}
"
token_data_str
=
token_event
[
1
].
strip
(
'"'
)
try
:
token_list
=
json
.
loads
(
token_data_str
)
except
json
.
JSONDecodeError
as
e
:
raise
AssertionError
(
f
"token_data is not valid JSON:
{
token_data_str
}
, error:
{
e
}
"
)
# For aggregated mode, both prefill and decode should be the same
prefill_worker_id
=
worker_id_info
.
get
(
"prefill_worker_id"
)
decode_worker_id
=
worker_id_info
.
get
(
"decode_worker_id"
)
assert
(
prefill_worker_id
is
not
None
),
f
"Missing prefill_worker_id in worker_id:
{
worker_id_info
}
"
assert
(
decode_worker_id
is
not
None
),
f
"Missing decode_worker_id in worker_id:
{
worker_id_info
}
"
assert
(
prefill_worker_id
==
decode_worker_id
),
f
"For aggregated mode, prefill and decode worker should be same:
{
worker_id_info
}
"
# Validate token_ids
assert
(
token_list
is
not
None
),
f
"Missing token_ids in nvext. Response:
{
full_response
}
"
assert
isinstance
(
token_list
,
list
),
f
"token_
data
should be a list, got:
{
type
(
token_list
)
}
"
),
f
"token_
ids
should be a list, got:
{
type
(
token_list
)
}
"
assert
(
len
(
token_list
)
>
0
),
f
"token_
data
should not be empty:
{
token_list
}
"
),
f
"token_
ids
should not be empty:
{
token_list
}
"
assert
all
(
isinstance
(
token
,
int
)
for
token
in
token_list
),
f
"All tokens should be integers:
{
token_list
}
"
logger
.
info
(
f
"Valid token_data with
{
len
(
token_list
)
}
tokens:
{
token_list
[:
10
]
}{
'...'
if
len
(
token_list
)
>
10
else
''
}
"
)
# Validate that no actual generation happened (should only be metadata)
# This proves the early return worked correctly
generation_indicators
=
[
"choices"
,
"content"
,
"delta"
,
"finish_reason"
,
]
for
indicator
in
generation_indicators
:
assert
(
indicator
not
in
full_response
.
lower
()
),
f
"Found generation indicator '
{
indicator
}
' - request should not have been routed to worker"
logger
.
info
(
"No generation content found - early return worked correctly"
f
"Valid token_ids with
{
len
(
token_list
)
}
tokens:
{
token_list
[:
10
]
}{
'...'
if
len
(
token_list
)
>
10
else
''
}
"
)
return
{
"worker_instance_id"
:
worker_event
[
1
].
strip
(
'"'
),
"prefill_worker_id"
:
prefill_worker_id
,
"decode_worker_id"
:
decode_worker_id
,
"token_count"
:
len
(
token_list
),
"tokens"
:
token_list
,
}
...
...
@@ -1154,7 +1135,8 @@ def _test_router_query_instance_id(
result
=
asyncio
.
run
(
test_annotation_response
())
logger
.
info
(
"Successfully validated query_instance_id annotation response:"
)
logger
.
info
(
f
"Worker ID:
{
result
[
'worker_instance_id'
]
}
"
)
logger
.
info
(
f
"Prefill Worker ID:
{
result
[
'prefill_worker_id'
]
}
"
)
logger
.
info
(
f
"Decode Worker ID:
{
result
[
'decode_worker_id'
]
}
"
)
logger
.
info
(
f
"Token count:
{
result
[
'token_count'
]
}
"
)
finally
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment