Unverified Commit 01f77f2c authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: Book-keeping bindings [DEP-689] (#5036)


Signed-off-by: Anna Tchernych <atchernych@nvidia.com>
parent 06e6ff6d
...@@ -91,31 +91,58 @@ index dee7e99..d3f9ec7 100644 ...@@ -91,31 +91,58 @@ index dee7e99..d3f9ec7 100644
.PHONY: image-local-build .PHONY: image-local-build
image-local-build: ## Build the EPP image using Docker Buildx for local development. image-local-build: ## Build the EPP image using Docker Buildx for local development.
diff --git a/cmd/epp/main.go b/cmd/epp/main.go diff --git a/cmd/epp/main.go b/cmd/epp/main.go
index b5e0617..8592735 100644 index b5e0617..b5c0312 100644
--- a/cmd/epp/main.go --- a/cmd/epp/main.go
+++ b/cmd/epp/main.go +++ b/cmd/epp/main.go
@@ -22,6 +22,11 @@ import ( @@ -22,6 +22,12 @@ import (
ctrl "sigs.k8s.io/controller-runtime" ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/gateway-api-inference-extension/cmd/epp/runner" "sigs.k8s.io/gateway-api-inference-extension/cmd/epp/runner"
+ eppplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + eppplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+ +
+ // Dynamo plugins + // Dynamo plugins
+ dyncleanup "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol/plugins/dynamo_cleanup"
+ dynprereq "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid" + dynprereq "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid"
+ dynscorer "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/dynamo_kv_scorer" + dynscorer "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/dynamo_kv_scorer"
) )
func main() { func main() {
@@ -30,6 +35,9 @@ func main() { @@ -30,6 +36,10 @@ func main() {
// For adding out-of-tree plugins to the plugins registry, use the following: // For adding out-of-tree plugins to the plugins registry, use the following:
// plugins.Register(my-out-of-tree-plugin-name, my-out-of-tree-plugin-factory-function) // plugins.Register(my-out-of-tree-plugin-name, my-out-of-tree-plugin-factory-function)
+ eppplugins.Register("dynamo-inject-workerid", dynprereq.InjectWorkerIDPreRequestFactory) + eppplugins.Register("dynamo-inject-workerid", dynprereq.InjectWorkerIDPreRequestFactory)
+ eppplugins.Register("kv-aware-scorer", dynscorer.KVAwareScorerFactory) + eppplugins.Register("kv-aware-scorer", dynscorer.KVAwareScorerFactory)
+ eppplugins.Register("dynamo-cleanup", dyncleanup.DynamoCleanupPluginFactory)
+ +
if err := runner.NewRunner().Run(ctrl.SetupSignalHandler()); err != nil { if err := runner.NewRunner().Run(ctrl.SetupSignalHandler()); err != nil {
os.Exit(1) os.Exit(1)
} }
diff --git a/pkg/epp/requestcontrol/body_mutator.go b/pkg/epp/requestcontrol/body_mutator.go
new file mode 100644
index 0000000..de87445
--- /dev/null
+++ b/pkg/epp/requestcontrol/body_mutator.go
@@ -0,0 +1,19 @@
+package requestcontrol
+
+import (
+ "context"
+
+ schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+// RequestBodyMutator is an optional capability for pre-request plugins that
+// need to rewrite the outbound request body. Implementations are invoked
+// after the standard PreRequest hook completes.
+type RequestBodyMutator interface {
+	// MutateRequestBody receives the request, the scheduling result, the
+	// target port and the decoded request body. It returns nothing, so any
+	// change is presumably made by mutating body in place (confirm with the
+	// director that invokes it).
+	MutateRequestBody(ctx context.Context, request *schedtypes.LLMRequest, schedulingResult *schedtypes.SchedulingResult, targetPort int, body map[string]any)
+}
diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index 670d922..0cf04cb 100644 index 670d922..0cf04cb 100644
--- a/pkg/epp/requestcontrol/director.go --- a/pkg/epp/requestcontrol/director.go
...@@ -159,6 +186,98 @@ index 670d922..0cf04cb 100644 ...@@ -159,6 +186,98 @@ index 670d922..0cf04cb 100644
metrics.RecordRequestControlPluginProcessingLatency(PreRequestPluginType, plugin.TypedName().Type, time.Since(before)) metrics.RecordRequestControlPluginProcessingLatency(PreRequestPluginType, plugin.TypedName().Type, time.Since(before))
} }
} }
diff --git a/pkg/epp/requestcontrol/plugins/dynamo_cleanup/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_cleanup/plugin.go
new file mode 100644
index 0000000..a389372
--- /dev/null
+++ b/pkg/epp/requestcontrol/plugins/dynamo_cleanup/plugin.go
@@ -0,0 +1,86 @@
+package dynamo_cleanup
+
+import (
+ "context"
+ "encoding/json"
+
+ log "sigs.k8s.io/controller-runtime/pkg/log"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+ rc "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
+ schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+
+ dynamo "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/dynamo_kv_scorer"
+)
+
+// Identifiers used when constructing the cleanup plugin's TypedName.
+const (
+ // PluginName is the default instance name assigned by NewDynamoCleanupPlugin.
+ PluginName = "dynamo-cleanup"
+ // PluginType is the plugin type reported through TypedName.
+ PluginType = "dynamo-cleanup"
+)
+
+// DynamoCleanupPlugin is a PostResponse plugin that cleans up router state
+// when a request completes. Its PostResponse hook calls the scorer package's
+// CallFreeRequest, which wraps dynamo_router_free_request, to release the
+// bookkeeping resources associated with the request.
+type DynamoCleanupPlugin struct {
+ // typedName holds the plugin type and the (possibly customised) instance name.
+ typedName plugins.TypedName
+}
+
+// Compile-time assertions that *DynamoCleanupPlugin implements the plugin
+// and PostResponse interfaces.
+var _ plugins.Plugin = (*DynamoCleanupPlugin)(nil)
+var _ rc.PostResponse = (*DynamoCleanupPlugin)(nil)
+
+// NewDynamoCleanupPlugin returns a cleanup plugin whose type is PluginType
+// and whose instance name defaults to PluginName.
+func NewDynamoCleanupPlugin() *DynamoCleanupPlugin {
+	plugin := new(DynamoCleanupPlugin)
+	plugin.typedName = plugins.TypedName{Type: PluginType, Name: PluginName}
+	return plugin
+}
+
+// WithName overrides the plugin's instance name in place and returns the
+// same receiver so the call can be chained after the constructor.
+func (p *DynamoCleanupPlugin) WithName(name string) *DynamoCleanupPlugin {
+ p.typedName.Name = name
+ return p
+}
+
+// DynamoCleanupPluginFactory is the registry factory for the cleanup plugin.
+// Only the configured instance name is used; the raw parameters and the
+// plugin handle are ignored. It never returns an error.
+func DynamoCleanupPluginFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+	plugin := NewDynamoCleanupPlugin()
+	plugin.WithName(name)
+	return plugin, nil
+}
+
+// TypedName returns the plugin's type together with its current instance
+// name (the default, or whatever WithName last set).
+func (p *DynamoCleanupPlugin) TypedName() plugins.TypedName {
+ return p.typedName
+}
+
+// PostResponse releases the router bookkeeping entry for a finished request.
+//
+// It is a no-op (logged at DEBUG) when request is nil or carries no request
+// ID. Otherwise the request ID is handed to dynamo.CallFreeRequest; a failure
+// is logged and swallowed, since this hook has no way to report an error.
+// The response and targetPod arguments are not used.
+func (p *DynamoCleanupPlugin) PostResponse(
+	ctx context.Context,
+	request *schedtypes.LLMRequest,
+	response *rc.Response,
+	targetPod *backend.Pod,
+) {
+	logger := log.FromContext(ctx)
+
+	// Guard clauses: without a request ID there is nothing to free.
+	switch {
+	case request == nil:
+		logger.V(logutil.DEBUG).Info("DynamoCleanupPlugin: request is nil, skipping cleanup")
+		return
+	case request.RequestId == "":
+		logger.V(logutil.DEBUG).Info("DynamoCleanupPlugin: no request ID, skipping cleanup")
+		return
+	}
+
+	id := request.RequestId
+	err := dynamo.CallFreeRequest(id)
+	if err == nil {
+		logger.V(logutil.VERBOSE).Info("DynamoCleanupPlugin: freed request from router",
+			"requestID", id)
+		return
+	}
+
+	logger.V(logutil.DEFAULT).Error(err, "DynamoCleanupPlugin: failed to free request",
+		"requestID", id)
+}
+
diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
new file mode 100644 new file mode 100644
index 0000000..1c8f979 index 0000000..1c8f979
...@@ -338,10 +457,10 @@ index 0000000..1c8f979 ...@@ -338,10 +457,10 @@ index 0000000..1c8f979
+} +}
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml
new file mode 100644 new file mode 100644
index 0000000..b689c00 index 0000000..e94b72b
--- /dev/null --- /dev/null
+++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml +++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml
@@ -0,0 +1,21 @@ @@ -0,0 +1,24 @@
+# This is an example for configuring the EPP to use the dynamo token-aware kv router for scoring the pods +# This is an example for configuring the EPP to use the dynamo token-aware kv router for scoring the pods
+apiVersion: inference.networking.x-k8s.io/v1alpha1 +apiVersion: inference.networking.x-k8s.io/v1alpha1
+kind: EndpointPickerConfig +kind: EndpointPickerConfig
...@@ -357,6 +476,9 @@ index 0000000..b689c00 ...@@ -357,6 +476,9 @@ index 0000000..b689c00
+ parameters: {} + parameters: {}
+ - name: dyn-kv + - name: dyn-kv
+ type: kv-aware-scorer + type: kv-aware-scorer
+ # Cleanup: frees router bookkeeping when request completes
+ - name: dyn-cleanup
+ type: dynamo-cleanup
+schedulingProfiles: +schedulingProfiles:
+ - name: default + - name: default
+ plugins: + plugins:
...@@ -365,10 +487,10 @@ index 0000000..b689c00 ...@@ -365,10 +487,10 @@ index 0000000..b689c00
+ - pluginRef: picker + - pluginRef: picker
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
new file mode 100644 new file mode 100644
index 0000000..6ee6634 index 0000000..31af16e
--- /dev/null --- /dev/null
+++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go +++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
@@ -0,0 +1,446 @@ @@ -0,0 +1,587 @@
+package dynamo_kv_scorer +package dynamo_kv_scorer
+ +
+/* +/*
...@@ -435,6 +557,20 @@ index 0000000..6ee6634 ...@@ -435,6 +557,20 @@ index 0000000..6ee6634
+dynamo_llm_result_t dynamo_free_worker_selection_result(uint32_t *token_ids, +dynamo_llm_result_t dynamo_free_worker_selection_result(uint32_t *token_ids,
+ size_t token_count, + size_t token_count,
+ char *annotated_request_json); + char *annotated_request_json);
+
+// Router bookkeeping functions for GAIE integration
+dynamo_llm_result_t dynamo_router_add_request(WorkerSelectionPipeline *pipeline,
+ const char *request_id_c_str,
+ const uint32_t *token_ids,
+ size_t token_count,
+ uint64_t worker_id,
+ uint32_t dp_rank);
+
+dynamo_llm_result_t dynamo_router_mark_prefill_complete(WorkerSelectionPipeline *pipeline,
+ const char *request_id_c_str);
+
+dynamo_llm_result_t dynamo_router_free_request(WorkerSelectionPipeline *pipeline,
+ const char *request_id_c_str);
+*/ +*/
+import "C" +import "C"
+ +
...@@ -459,6 +595,7 @@ index 0000000..6ee6634 ...@@ -459,6 +595,7 @@ index 0000000..6ee6634
+ KVAwareScorerType = "kv-aware-scorer" + KVAwareScorerType = "kv-aware-scorer"
+ StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id") + StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id")
+ StateKeyPrefillWorkerID = schedtypes.StateKey("dynamo/prefill-worker-id") + StateKeyPrefillWorkerID = schedtypes.StateKey("dynamo/prefill-worker-id")
+ StateKeyRequestID = schedtypes.StateKey("dynamo/request-id")
+ WorkerIDHeader = "x-worker-instance-id" + WorkerIDHeader = "x-worker-instance-id"
+ PrefillWorkerIDHeader = "x-prefiller-host-port" + PrefillWorkerIDHeader = "x-prefiller-host-port"
+ tokenDataAnnotationKey = "dynamo/token-data" + tokenDataAnnotationKey = "dynamo/token-data"
...@@ -685,6 +822,19 @@ index 0000000..6ee6634 ...@@ -685,6 +822,19 @@ index 0000000..6ee6634
+ copy(copied, tokenData) + copy(copied, tokenData)
+ req.Annotations[tokenDataAnnotationKey] = copied + req.Annotations[tokenDataAnnotationKey] = copied
+ } + }
+
+ // GAIE Stage 1: Register request with router bookkeeping
+ // The request ID comes from Envoy's request ID header
+ requestID := req.RequestId
+ if requestID != "" {
+ cycle.Write(StateKeyRequestID, stateString(requestID))
+ if addErr := k.callAddRequest(ctx, requestID, tokenData, workerID, prefillWorkerID); addErr != nil {
+ logger.V(logutil.DEFAULT).Error(addErr, "Failed to add request to router bookkeeping",
+ "requestID", requestID)
+ }
+ } else {
+ logger.V(logutil.VERBOSE).Info("No request ID available, skipping router bookkeeping")
+ }
+ } + }
+ +
+ out := make(map[schedtypes.Pod]float64, len(pods)) + out := make(map[schedtypes.Pod]float64, len(pods))
...@@ -794,6 +944,119 @@ index 0000000..6ee6634 ...@@ -794,6 +944,119 @@ index 0000000..6ee6634
+ return requestBody + return requestBody
+} +}
+ +
+// --------------------------- router bookkeeping ---------------------------
+
+// callAddRequest registers a request with the router's bookkeeping.
+// This should be called after worker selection to track active requests.
+//
+// Parameters:
+//   - requestID: identifier under which the request is tracked; the same ID
+//     must later be passed to CallFreeRequest.
+//   - tokenData: token IDs of the prompt; each must fit in a uint32.
+//   - workerID: decimal worker instance ID (the decode worker in disagg mode).
+//   - prefillWorkerID: currently unused; bookkeeping is keyed on workerID only.
+//
+// It returns an error when the runtime or pipeline is not ready, when
+// workerID cannot be parsed, when a token ID is outside the uint32 range, or
+// when the native call does not report DYNAMO_OK.
+func (k *KVAwareScorer) callAddRequest(
+	ctx context.Context,
+	requestID string,
+	tokenData []int64,
+	workerID string,
+	prefillWorkerID string,
+) error {
+	logger := log.FromContext(ctx)
+
+	if !runtimeInitialized {
+		return fmt.Errorf("dynamo runtime not initialized")
+	}
+
+	// Snapshot the pipeline handle under the read lock so the native call
+	// below does not hold the mutex.
+	pipelineMutex.RLock()
+	currentPipeline := pipeline
+	pipelineMutex.RUnlock()
+
+	if currentPipeline == nil {
+		return fmt.Errorf("dynamo worker selection pipeline not created")
+	}
+
+	// Parse worker ID (use decode worker for bookkeeping in disagg mode).
+	// NOTE(review): Sscanf stops at the first non-digit, so "12abc" parses as
+	// 12; strconv.ParseUint would be stricter if that import is available.
+	var workerIDUint uint64
+	if _, err := fmt.Sscanf(workerID, "%d", &workerIDUint); err != nil {
+		return fmt.Errorf("invalid worker ID %q: %w", workerID, err)
+	}
+
+	// Convert token data from int64 to uint32, rejecting values that would be
+	// silently truncated by the narrowing conversion and thereby register the
+	// wrong token sequence with the router.
+	const maxTokenID = int64(^uint32(0))
+	tokens := make([]uint32, len(tokenData))
+	for i, t := range tokenData {
+		if t < 0 || t > maxTokenID {
+			return fmt.Errorf("token ID %d at index %d does not fit in uint32", t, i)
+		}
+		tokens[i] = uint32(t)
+	}
+
+	cRequestID := C.CString(requestID)
+	defer C.free(unsafe.Pointer(cRequestID))
+
+	// A nil pointer with a zero count is passed for an empty token list. The
+	// slice memory is only lent to the native side for the duration of the call.
+	var cTokens *C.uint32_t
+	if len(tokens) > 0 {
+		cTokens = (*C.uint32_t)(unsafe.Pointer(&tokens[0]))
+	}
+
+	rc := C.dynamo_router_add_request(
+		currentPipeline,
+		cRequestID,
+		cTokens,
+		C.size_t(len(tokens)),
+		C.uint64_t(workerIDUint),
+		C.uint32_t(0), // dp_rank = 0 for now
+	)
+
+	if rc != C.DYNAMO_OK {
+		return fmt.Errorf("dynamo_router_add_request failed")
+	}
+
+	logger.V(logutil.VERBOSE).Info("Added request to router bookkeeping",
+		"requestID", requestID, "workerID", workerID, "tokenCount", len(tokens))
+	return nil
+}
+
+// CallMarkPrefillComplete tells the router that prefill has finished for the
+// request identified by requestID. It is exported so response handlers in
+// other packages can call it.
+//
+// An error is returned when the dynamo runtime is not initialized, when no
+// worker selection pipeline exists, or when the native call does not report
+// DYNAMO_OK.
+func CallMarkPrefillComplete(requestID string) error {
+	if !runtimeInitialized {
+		return fmt.Errorf("dynamo runtime not initialized")
+	}
+
+	// Copy the handle under the read lock; the native call runs unlocked.
+	pipelineMutex.RLock()
+	active := pipeline
+	pipelineMutex.RUnlock()
+	if active == nil {
+		return fmt.Errorf("dynamo worker selection pipeline not created")
+	}
+
+	cID := C.CString(requestID)
+	defer C.free(unsafe.Pointer(cID))
+
+	if status := C.dynamo_router_mark_prefill_complete(active, cID); status != C.DYNAMO_OK {
+		return fmt.Errorf("dynamo_router_mark_prefill_complete failed")
+	}
+	return nil
+}
+
+// CallFreeRequest removes the router's bookkeeping entry for a request that
+// has completed or been cancelled. It is exported so response handlers in
+// other packages can call it.
+//
+// An error is returned when the dynamo runtime is not initialized, when no
+// worker selection pipeline exists, or when the native call does not report
+// DYNAMO_OK.
+func CallFreeRequest(requestID string) error {
+	if !runtimeInitialized {
+		return fmt.Errorf("dynamo runtime not initialized")
+	}
+
+	// Copy the handle under the read lock; the native call runs unlocked.
+	pipelineMutex.RLock()
+	active := pipeline
+	pipelineMutex.RUnlock()
+	if active == nil {
+		return fmt.Errorf("dynamo worker selection pipeline not created")
+	}
+
+	cID := C.CString(requestID)
+	defer C.free(unsafe.Pointer(cID))
+
+	if status := C.dynamo_router_free_request(active, cID); status != C.DYNAMO_OK {
+		return fmt.Errorf("dynamo_router_free_request failed")
+	}
+	return nil
+}
+
+// --------------------------- shutdown --------------------------- +// --------------------------- shutdown ---------------------------
+ +
+func cleanupDynamo() error { +func cleanupDynamo() error {
......
...@@ -379,6 +379,8 @@ pub struct WorkerSelectionPipeline { ...@@ -379,6 +379,8 @@ pub struct WorkerSelectionPipeline {
SingleIn<NvCreateChatCompletionRequest>, SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>, >,
/// KV router for bookkeeping operations (only present when router_mode is KV)
kv_router: Option<Arc<dynamo_llm::kv_router::KvRouter>>,
} }
/// Create a worker-selection pipeline ("generate" endpoint). /// Create a worker-selection pipeline ("generate" endpoint).
...@@ -485,7 +487,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline( ...@@ -485,7 +487,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
.await .await
}; };
let engine = match wk.runtime().secondary().block_on(make_engine()) { let (engine, kv_router) = match wk.runtime().secondary().block_on(make_engine()) {
Ok(p) => p, Ok(p) => p,
Err(e) => { Err(e) => {
tracing::error!(error = ?e, "create_worker_selection_pipeline_chat failed"); tracing::error!(error = ?e, "create_worker_selection_pipeline_chat failed");
...@@ -493,7 +495,11 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline( ...@@ -493,7 +495,11 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
} }
}; };
let handle = Box::new(WorkerSelectionPipeline { wk, engine }); let handle = Box::new(WorkerSelectionPipeline {
wk,
engine,
kv_router,
});
unsafe { unsafe {
*pipeline_out = Box::into_raw(handle); *pipeline_out = Box::into_raw(handle);
} }
...@@ -743,6 +749,248 @@ pub unsafe extern "C" fn dynamo_free_worker_selection_result( ...@@ -743,6 +749,248 @@ pub unsafe extern "C" fn dynamo_free_worker_selection_result(
DynamoLlmResult::OK DynamoLlmResult::OK
} }
/// Default timeout for GAIE bookkeeping operations (30 seconds).
/// `run_bookkeeping_with_timeout` applies it to every bookkeeping future it drives.
const GAIE_BOOKKEEPING_TIMEOUT_SECS: u64 = 30;
/// Helper to validate pipeline pointer and extract request_id from C string.
/// Returns `Err(DynamoLlmResult::ERR)` on validation failure, `Ok((pipeline_ref, request_id))` on success.
unsafe fn validate_pipeline_and_request_id(
pipeline: *mut WorkerSelectionPipeline,
request_id_c_str: *const c_char,
operation: &str,
) -> Result<(&'static WorkerSelectionPipeline, String), DynamoLlmResult> {
if pipeline.is_null() {
tracing::error!("[GAIE] {} failed: pipeline pointer is null", operation);
return Err(DynamoLlmResult::ERR);
}
let request_id = match unsafe { CStr::from_ptr(request_id_c_str) }.to_str() {
Ok(s) => s.to_owned(),
Err(e) => {
tracing::error!(error = ?e, "[GAIE] {} failed: bad request_id", operation);
return Err(DynamoLlmResult::ERR);
}
};
// SAFETY: Caller guarantees pipeline is valid for the duration of the call
let pl: &'static WorkerSelectionPipeline = unsafe { &*pipeline };
Ok((pl, request_id))
}
/// Helper to run an async bookkeeping operation with timeout.
///
/// Builds the future with `f`, then blocks the calling thread on the
/// pipeline's secondary runtime until the future finishes or
/// `GAIE_BOOKKEEPING_TIMEOUT_SECS` elapses.
///
/// Always returns `OK`: both completion and timeout map to `OK`, and this
/// function has no `ERR` path. Validation failures are reported by the caller
/// before this helper is reached. A timeout is only visible as a warning log
/// tagged with `operation` and `request_id`.
fn run_bookkeeping_with_timeout<F, Fut>(
pl: &WorkerSelectionPipeline,
operation: &'static str,
request_id: &str,
f: F,
) -> DynamoLlmResult
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = ()>,
{
use std::time::Duration;
let timeout_duration = Duration::from_secs(GAIE_BOOKKEEPING_TIMEOUT_SECS);
let fut = f();
// Blocks the current (FFI caller's) thread for at most the timeout duration.
let result = pl
.wk
.runtime()
.secondary()
.block_on(async { tokio::time::timeout(timeout_duration, fut).await });
match result {
Ok(()) => DynamoLlmResult::OK,
Err(_elapsed) => {
tracing::warn!(
request_id = %request_id,
timeout_secs = GAIE_BOOKKEEPING_TIMEOUT_SECS,
"[GAIE] {} timed out",
operation
);
// Return OK so a slow bookkeeping call does not fail the caller.
// NOTE(review): tokio's timeout drops the wrapped future when it elapses,
// so whatever part of the operation had not finished is cancelled rather
// than completed later; the router state may therefore be left un-updated.
DynamoLlmResult::OK
}
}
}
// Router bookkeeping functions for GAIE integration

/// Add a request to the router's bookkeeping after worker selection.
/// Call this from GAIE Stage 1 after `dynamo_query_worker_selection_and_annotate`.
///
/// This function computes the overlap_blocks internally by querying the indexer,
/// so the caller doesn't need to provide it. If that query fails, a warning is
/// logged and an overlap of 0 is used.
///
/// Returns `ERR` only when pointer/request-ID validation fails. Returns `OK`
/// when the pipeline has no KV router (no-op) and whenever the bookkeeping
/// helper returns, including on timeout.
///
/// # Safety
/// - `pipeline` must be a valid, non-null pointer from `dynamo_create_worker_selection_pipeline`
/// - `request_id_c_str` must be a valid NUL-terminated UTF-8 C string
/// - `token_ids` must point to at least `token_count` valid u32 values
/// - Must not be called concurrently on the same pipeline without synchronization
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_router_add_request(
pipeline: *mut WorkerSelectionPipeline,
request_id_c_str: *const c_char,
token_ids: *const u32,
token_count: usize,
worker_id: u64,
dp_rank: u32,
) -> DynamoLlmResult {
let (pl, request_id) = match unsafe {
validate_pipeline_and_request_id(pipeline, request_id_c_str, "add_request")
} {
Ok(v) => v,
Err(e) => return e,
};
let Some(ref kv_router) = pl.kv_router else {
tracing::debug!(
"[GAIE] KV router not available (router_mode is not KV), skipping add_request (no-op)"
);
return DynamoLlmResult::OK;
};
// Log after kv_router check to reduce noise
tracing::debug!(
request_id = %request_id,
worker_id = worker_id,
dp_rank = dp_rank,
token_count = token_count,
"[GAIE] dynamo_router_add_request processing"
);
// Copy the caller's tokens into an owned Vec so the async block does not
// borrow foreign memory. A null pointer or a zero count yields an empty
// list; a null pointer with a non-zero count is NOT reported as an error.
let tokens: Vec<u32> = if token_count > 0 && !token_ids.is_null() {
unsafe { std::slice::from_raw_parts(token_ids, token_count) }.to_vec()
} else {
Vec::new()
};
// Owned clones are moved into the future below.
let kv_router = kv_router.clone();
let request_id_clone = request_id.clone();
run_bookkeeping_with_timeout(pl, "add_request", &request_id, || async move {
let worker = dynamo_llm::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank);
// Compute overlap_blocks using the public method
let overlap_blocks = match kv_router.get_overlap_blocks(&tokens, worker).await {
Ok(overlap) => overlap,
Err(e) => {
tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
0
}
};
kv_router
.add_request(request_id_clone.clone(), &tokens, overlap_blocks, worker)
.await;
tracing::debug!(
request_id = %request_id_clone,
worker_id = worker_id,
dp_rank = dp_rank,
overlap_blocks = overlap_blocks,
token_count = tokens.len(),
"[GAIE] dynamo_router_add_request completed - request registered in router bookkeeping"
);
})
}
/// Mark prefill as completed for a request.
/// Call this from GAIE hook when the first token is generated.
///
/// # Safety
/// - `pipeline` must be a valid, non-null pointer from `dynamo_create_worker_selection_pipeline`
/// - `request_id_c_str` must be a valid NUL-terminated UTF-8 C string
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_router_mark_prefill_complete(
pipeline: *mut WorkerSelectionPipeline,
request_id_c_str: *const c_char,
) -> DynamoLlmResult {
let (pl, request_id) = match unsafe {
validate_pipeline_and_request_id(pipeline, request_id_c_str, "mark_prefill_complete")
} {
Ok(v) => v,
Err(e) => return e,
};
let Some(ref kv_router) = pl.kv_router else {
tracing::debug!(
"[GAIE] KV router not available (router_mode is not KV), skipping mark_prefill_complete (no-op)"
);
return DynamoLlmResult::OK;
};
// Log after kv_router check to reduce noise
tracing::debug!(
request_id = %request_id,
"[GAIE] dynamo_router_mark_prefill_complete processing"
);
let kv_router = kv_router.clone();
let request_id_clone = request_id.clone();
run_bookkeeping_with_timeout(pl, "mark_prefill_complete", &request_id, || async move {
if let Err(e) = kv_router.mark_prefill_completed(&request_id_clone).await {
tracing::warn!(
"Failed to mark prefill completed for {}: {}",
request_id_clone,
e
);
} else {
tracing::debug!(
request_id = %request_id_clone,
"[GAIE] dynamo_router_mark_prefill_complete completed - prefill tokens released"
);
}
})
}
/// Free a request from the router's bookkeeping.
/// Call this from GAIE hook when the stream is closed (completed or cancelled).
///
/// # Safety
/// - `pipeline` must be a valid, non-null pointer from `dynamo_create_worker_selection_pipeline`
/// - `request_id_c_str` must be a valid NUL-terminated UTF-8 C string
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_router_free_request(
pipeline: *mut WorkerSelectionPipeline,
request_id_c_str: *const c_char,
) -> DynamoLlmResult {
let (pl, request_id) = match unsafe {
validate_pipeline_and_request_id(pipeline, request_id_c_str, "free_request")
} {
Ok(v) => v,
Err(e) => return e,
};
let Some(ref kv_router) = pl.kv_router else {
tracing::debug!(
"[GAIE] KV router not available (router_mode is not KV), skipping free_request (no-op)"
);
return DynamoLlmResult::OK;
};
// Log after kv_router check to reduce noise
tracing::debug!(
request_id = %request_id,
"[GAIE] dynamo_router_free_request processing"
);
let kv_router = kv_router.clone();
let request_id_clone = request_id.clone();
run_bookkeeping_with_timeout(pl, "free_request", &request_id, || async move {
if let Err(e) = kv_router.free(&request_id_clone).await {
tracing::warn!("Failed to free request {}: {}", request_id_clone, e);
} else {
tracing::debug!(
request_id = %request_id_clone,
"[GAIE] dynamo_router_free_request completed - request removed from bookkeeping"
);
}
})
}
/// Result of worker selection extraction /// Result of worker selection extraction
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct WorkerSelectionResult { pub struct WorkerSelectionResult {
...@@ -1081,7 +1329,7 @@ fn spawn_prefill_watcher( ...@@ -1081,7 +1329,7 @@ fn spawn_prefill_watcher(
/// - `enforce_disagg`: If true, fail requests when disaggregated serving is unavailable /// - `enforce_disagg`: If true, fail requests when disaggregated serving is unavailable
/// ///
/// # Returns /// # Returns
/// A configured worker selection pipeline ready to use /// A tuple of (engine, kv_router) where kv_router is Some when router_mode is KV
pub async fn create_worker_selection_pipeline_chat( pub async fn create_worker_selection_pipeline_chat(
namespace: &str, namespace: &str,
component_name: &str, component_name: &str,
...@@ -1090,12 +1338,13 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -1090,12 +1338,13 @@ pub async fn create_worker_selection_pipeline_chat(
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
enforce_disagg: bool, enforce_disagg: bool,
) -> anyhow::Result< ) -> anyhow::Result<(
ServiceEngine< ServiceEngine<
SingleIn<NvCreateChatCompletionRequest>, SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>, >,
> { Option<Arc<dynamo_llm::kv_router::KvRouter>>,
)> {
use dynamo_llm::kv_router::PrefillRouter; use dynamo_llm::kv_router::PrefillRouter;
let runtime = Runtime::from_settings()?; let runtime = Runtime::from_settings()?;
...@@ -1217,6 +1466,9 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -1217,6 +1466,9 @@ pub async fn create_worker_selection_pipeline_chat(
// Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this // Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this
let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(client.clone(), t, 1000000)); let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(client.clone(), t, 1000000));
// Clone chooser before passing to build_routed_pipeline (which takes ownership)
let kv_router = chooser.clone();
let engine = build_routed_pipeline::< let engine = build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
...@@ -1233,5 +1485,5 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -1233,5 +1485,5 @@ pub async fn create_worker_selection_pipeline_chat(
) )
.await?; .await?;
Ok(engine) Ok((engine, kv_router))
} }
...@@ -561,6 +561,18 @@ impl KvRouter { ...@@ -561,6 +561,18 @@ impl KvRouter {
self.block_size self.block_size
} }
/// Report how many blocks of `tokens` the indexer already associates with `worker`.
///
/// The token sequence is hashed into blocks using this router's block size and
/// the hashes are looked up in the indexer. A worker with no entry in the
/// returned scores counts as 0. Indexer errors are propagated to the caller.
pub async fn get_overlap_blocks(
    &self,
    tokens: &[u32],
    worker: WorkerWithDpRank,
) -> Result<u32, KvRouterError> {
    let hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
    let matches = self.indexer.find_matches(hashes).await?;
    let cached = match matches.scores.get(&worker) {
        Some(count) => *count,
        None => 0,
    };
    Ok(cached)
}
/// Get the disaggregated endpoint for a worker, if available. /// Get the disaggregated endpoint for a worker, if available.
/// Used to look up bootstrap host/port for prefill workers. /// Used to look up bootstrap host/port for prefill workers.
pub async fn get_disaggregated_endpoint( pub async fn get_disaggregated_endpoint(
...@@ -691,6 +703,13 @@ pub struct KvPushRouter { ...@@ -691,6 +703,13 @@ pub struct KvPushRouter {
pub chooser: Arc<KvRouter>, pub chooser: Arc<KvRouter>,
} }
/// Result of worker selection containing instance ID, dp_rank, and overlap amount.
struct WorkerSelection {
// Worker ID: either the pre-selected worker or the best match's `worker_id`.
instance_id: u64,
// Data-parallel rank: from the routing hints (default 0) or the best match.
dp_rank: u32,
// Overlap reported by `find_best_match` or `get_overlap_blocks` for this worker.
overlap_amount: u32,
}
impl KvPushRouter { impl KvPushRouter {
pub fn new( pub fn new(
inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>, inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
...@@ -698,6 +717,91 @@ impl KvPushRouter { ...@@ -698,6 +717,91 @@ impl KvPushRouter {
) -> Self { ) -> Self {
KvPushRouter { inner, chooser } KvPushRouter { inner, chooser }
} }
/// Select a worker for the request, either using a preselected worker or finding the best match.
///
/// Pre-selected worker path: when `is_query_only` is false and `handle_local_updates`
/// is true, this also registers the request with the scheduler via `add_request`.
/// Best-match path: `handle_local_updates` is not consulted; `find_best_match` is
/// passed `!is_query_only` as its state-update flag.
///
/// Errors from `find_best_match` and from the overlap query are propagated.
async fn select_worker(
&self,
context_id: &str,
request: &PreprocessedRequest,
phase: RequestPhase,
is_query_only: bool,
handle_local_updates: bool,
) -> Result<WorkerSelection, Error> {
let routing = request.routing.as_ref();
// Get pre-selected worker based on phase, with backend_instance_id as fallback
let Some(id) = (match phase {
RequestPhase::Prefill => {
routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
}
RequestPhase::Decode => {
routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
}
RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
}) else {
// No preselected worker - find the best match
// Don't update states if this is a query-only request
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
Some(context_id),
&request.token_ids,
request.router_config_override.as_ref(),
!is_query_only,
)
.await?;
return Ok(WorkerSelection {
instance_id: best_worker.worker_id,
dp_rank: best_worker.dp_rank,
overlap_amount,
});
};
// Route to pre-selected or explicitly specified worker
// dp_rank comes from the routing hints and defaults to 0 when absent.
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0);
tracing::debug!(
worker_id = id,
dp_rank = dp_rank,
?phase,
"Routing to specified worker"
);
// Compute actual overlap blocks by querying the indexer
// (done even for query-only requests, since the overlap is part of the result).
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = self
.chooser
.get_overlap_blocks(&request.token_ids, worker)
.await?;
// Perform add_request if this router handles local updates
if !is_query_only && handle_local_updates {
self.chooser
.add_request(
context_id.to_string(),
&request.token_ids,
overlap_blocks,
worker,
)
.await;
} else {
tracing::debug!(
request_id = %context_id,
worker_id = id,
dp_rank = dp_rank,
"Skipping add_request - query or handled externally"
);
}
Ok(WorkerSelection {
instance_id: id,
dp_rank,
overlap_amount: overlap_blocks,
})
}
} }
#[async_trait] #[async_trait]
...@@ -733,6 +837,18 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -733,6 +837,18 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Simple query-only detection: presence of query_instance_id annotation means query-only mode // Simple query-only detection: presence of query_instance_id annotation means query-only mode
let is_query_only = request.get_annotation_value("query_instance_id").is_some(); let is_query_only = request.get_annotation_value("query_instance_id").is_some();
// Determine if this router should handle local state updates (add_request, free, etc.)
// When routing hints are present, the external caller handles state tracking
// via separate API calls, so we skip local updates here.
let routing = request.routing.as_ref();
let handle_local_updates = routing
.map(|r| {
// Handle updates locally unless backend_instance_id is set, or both prefill_worker_id and decode_worker_id are set
r.backend_instance_id.is_none()
&& (r.prefill_worker_id.is_none() || r.decode_worker_id.is_none())
})
.unwrap_or(true);
// Get phase from tracker (defaults to Aggregated if no tracker or phase not set) // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
let phase = request let phase = request
.tracker .tracker
...@@ -740,61 +856,21 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -740,61 +856,21 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.map(|t| t.phase()) .map(|t| t.phase())
.unwrap_or(RequestPhase::Aggregated); .unwrap_or(RequestPhase::Aggregated);
// Get pre-selected worker based on phase, with backend_instance_id as fallback
let routing = request.routing.as_ref();
let preselected = match phase {
RequestPhase::Prefill => {
routing.and_then(|r| r.prefill_worker_id.or(r.backend_instance_id))
}
RequestPhase::Decode => {
routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id))
}
RequestPhase::Aggregated => routing.and_then(|r| r.backend_instance_id),
};
let block_size = self.chooser.block_size() as usize; let block_size = self.chooser.block_size() as usize;
let (instance_id, dp_rank, overlap_amount) = if let Some(id) = preselected { let selection = self
// Route to pre-selected or explicitly specified worker .select_worker(
let dp_rank = routing.and_then(|r| r.dp_rank).unwrap_or(0); &context_id,
tracing::debug!( &request,
worker_id = id, phase,
dp_rank = dp_rank, is_query_only,
?phase, handle_local_updates,
"Routing to specified worker" )
); .await?;
let WorkerSelection {
// Compute actual overlap blocks by querying the indexer instance_id,
let block_hashes = dp_rank,
compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size(), None); overlap_amount,
let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?; } = selection;
let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);
if !is_query_only {
self.chooser
.add_request(
context_id.clone(),
&request.token_ids,
overlap_blocks,
worker,
)
.await;
}
(id, dp_rank, overlap_blocks)
} else {
// Find the best worker match
// Don't update states if this is a query-only request
let (best_worker, overlap_amount) = self
.chooser
.find_best_match(
Some(&context_id),
&request.token_ids,
request.router_config_override.as_ref(),
!is_query_only,
)
.await?;
(best_worker.worker_id, best_worker.dp_rank, overlap_amount)
};
// Record metrics in tracker: KV hit rate and worker ID based on phase // Record metrics in tracker: KV hit rate and worker ID based on phase
if let Some(ref tracker) = request.tracker { if let Some(ref tracker) = request.tracker {
...@@ -834,11 +910,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -834,11 +910,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
backend_input.routing_mut().dp_rank = Some(dp_rank); backend_input.routing_mut().dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
let chooser = self.chooser.clone();
let mut response_stream = self.inner.direct(updated_request, instance_id).await?; let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context(); let stream_context = response_stream.context();
let chooser = self.chooser.clone();
let context_for_monitoring = stream_context.clone(); let context_for_monitoring = stream_context.clone();
// TODO: When handle_local_updates=false, consider moving mark_prefill_completed
// to an external caller (e.g., sidecar) if they support a first-token hook.
// Currently mark_prefill_completed is called here for all flows.
let wrapped_stream = Box::pin(async_stream::stream! { let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false; let mut prefill_marked = false;
...@@ -868,6 +947,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -868,6 +947,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
} }
// Always call free() - it's idempotent and safe even if already freed or never added
if let Err(e) = chooser.free(&context_id).await { if let Err(e) = chooser.free(&context_id).await {
tracing::warn!("Failed to free request {context_id}: {e}"); tracing::warn!("Failed to free request {context_id}: {e}");
} }
......
...@@ -36,9 +36,8 @@ data: ...@@ -36,9 +36,8 @@ data:
parameters: {} parameters: {}
- name: dyn-kv - name: dyn-kv
type: kv-aware-scorer type: kv-aware-scorer
parameters: - name: dyn-cleanup
frontendURL: http://127.0.0.1:8000/v1/chat/completions type: dynamo-cleanup
timeoutMS: 10000
schedulingProfiles: schedulingProfiles:
- name: default - name: default
plugins: plugins:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment