Unverified Commit 52271760 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: Decomposed pipeline for EPP integration [DEP-730] (#5446)


Signed-off-by: default avatarAnna Tchernych <atchernych@nvidia.com>
parent 16a28058
...@@ -167,7 +167,7 @@ class FrontendArgGroup(ArgGroup): ...@@ -167,7 +167,7 @@ class FrontendArgGroup(ArgGroup):
env_var="DYN_ROUTER_MODE", env_var="DYN_ROUTER_MODE",
default="round-robin", default="round-robin",
help="How to route the request.", help="How to route the request.",
choices=["round-robin", "random", "kv"], choices=["round-robin", "random", "kv", "direct"],
) )
add_argument( add_argument(
g, g,
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
# - OpenAI HTTP server. # - OpenAI HTTP server.
# - Auto-discovery: Watches etcd for engine/worker registration (via `register_llm`). # - Auto-discovery: Watches etcd for engine/worker registration (via `register_llm`).
# - Pre-processor: Prompt templating and tokenization. # - Pre-processor: Prompt templating and tokenization.
# - Router, defaulting to round-robin. Use --router-mode to switch (round-robin, random, kv). # - Router, defaulting to round-robin. Use --router-mode to switch (round-robin, random, kv, direct).
# #
# Pass `--interactive` or `-i` for text chat instead of HTTP server. # Pass `--interactive` or `-i` for text chat instead of HTTP server.
# #
...@@ -197,6 +197,9 @@ async def async_main(): ...@@ -197,6 +197,9 @@ async def async_main():
elif config.router_mode == "random": elif config.router_mode == "random":
router_mode = RouterMode.Random router_mode = RouterMode.Random
kv_router_config = None kv_router_config = None
elif config.router_mode == "direct":
router_mode = RouterMode.Direct
kv_router_config = None
else: else:
router_mode = RouterMode.RoundRobin router_mode = RouterMode.RoundRobin
kv_router_config = None kv_router_config = None
......
...@@ -63,26 +63,6 @@ kubectl get gateway inference-gateway ...@@ -63,26 +63,6 @@ kubectl get gateway inference-gateway
### 3. Setup secrets ### ### 3. Setup secrets ###
Follow the steps in [model deployment](../../examples/backends/vllm/deploy/README.md) to deploy `Qwen/Qwen3-0.6B` model in aggregate mode using [agg.yaml](../../examples/backends/vllm/deploy/agg.yaml) in `my-model` kubernetes namespace.
Make sure to enable kv-routing by adding the env var in the FrontEnd.
```bash
mainContainer:
image: ...
env:
- name: DYN_ROUTER_MODE
value: "kv"
```
Sample commands to deploy model:
```bash
cd <dynamo-source-root>
cd examples/backends/vllm/deploy
kubectl apply -f agg.yaml -n my-model
```
Take a note of or change the DYNAMO_IMAGE in the model deployment file.
Do not forget docker registry secret if needed. Do not forget docker registry secret if needed.
```bash ```bash
...@@ -93,7 +73,7 @@ kubectl create secret docker-registry docker-imagepullsecret \ ...@@ -93,7 +73,7 @@ kubectl create secret docker-registry docker-imagepullsecret \
--namespace=$NAMESPACE --namespace=$NAMESPACE
``` ```
Do not forget to include the HuggingFace token if required. Do not forget to include the HuggingFace token.
```bash ```bash
export HF_TOKEN=your_hf_token export HF_TOKEN=your_hf_token
...@@ -139,13 +119,34 @@ make info # Check image tag ...@@ -139,13 +119,34 @@ make info # Check image tag
We recommend deploying Inference Gateway's Endpoint Picker as a Dynamo operator's managed component. Alternatively, We recommend deploying Inference Gateway's Endpoint Picker as a Dynamo operator's managed component. Alternatively,
you could deploy it as a standalone pod you could deploy it as a standalone pod
#### 5.a. Deploy as a DGD component #### 5.a. Deploy as a DGD component (recommended)
We provide an example for llama-3-70b vLLM below.
```bash ```bash
kubectl apply -f operator-managed/examples/agg.yaml -n ${NAMESPACE} # Deploy PVC, first Update `storageClassName` in recipes/llama-3-70b/model-cache/model-cache.yaml to match your cluster before deploying
kubectl apply -f operator-managed/examples/http-route.yaml -n ${NAMESPACE} kubectl apply -f recipes/llama-3-70b/model-cache/model-cache.yaml
kubectl apply -f recipes/llama-3-70b/model-cache/model-download.yaml
# Deploy your model
kubectl apply -f recipes/llama-3-70b/vllm/agg/gaie/deploy.yaml -n ${NAMESPACE}
# Deploy the GAIE http-route CR.
kubectl apply -f recipes/llama-3-70b/vllm/agg/gaie/http-route.yaml -n ${NAMESPACE}
``` ```
- When using GAIE the FrontEnd does not choose the workers. The routing is determined in the EPP.
- You must enable the flag in the FrontEnd cli as below.
```bash
command:
- python3
args:
- -m
- dynamo.frontend
- --router-mode
- direct
```
- The pre-selected worker (decode and prefill in case of the disaggregated serving) are passed in the request headers.
- The flag assures the routing respects this selection.
**Startup Probe Timeout:** The EPP has a default startup probe timeout of 30 minutes (10s × 180 failures). **Startup Probe Timeout:** The EPP has a default startup probe timeout of 30 minutes (10s × 180 failures).
If your model takes longer to load, increase the `failureThreshold` in the EPP's `startupProbe`. For example, If your model takes longer to load, increase the `failureThreshold` in the EPP's `startupProbe`. For example,
to allow 60 minutes for startup: to allow 60 minutes for startup:
...@@ -166,6 +167,18 @@ If you installed it into a different namespace, you need to adjust the HttpRoute ...@@ -166,6 +167,18 @@ If you installed it into a different namespace, you need to adjust the HttpRoute
##### 5.b.1 Deploy Your Model ### ##### 5.b.1 Deploy Your Model ###
We provide an example for Qwen vLLM below.
Before deploying you must enable the `--direct-route` flag in the FrontEnd cli in your Dynamo Graph.
```bash
command:
- python3
args:
- -m
- dynamo.frontend
- --router-mode
- direct
```
Follow the steps in [model deployment](../../examples/backends/vllm/deploy/README.md) to deploy `Qwen/Qwen3-0.6B` model in aggregate mode using [agg.yaml](../../examples/backends/vllm/deploy/agg.yaml) in `my-model` kubernetes namespace. Follow the steps in [model deployment](../../examples/backends/vllm/deploy/README.md) to deploy `Qwen/Qwen3-0.6B` model in aggregate mode using [agg.yaml](../../examples/backends/vllm/deploy/agg.yaml) in `my-model` kubernetes namespace.
Sample commands to deploy model: Sample commands to deploy model:
...@@ -176,10 +189,6 @@ cd examples/backends/vllm/deploy ...@@ -176,10 +189,6 @@ cd examples/backends/vllm/deploy
kubectl apply -f agg.yaml -n my-model kubectl apply -f agg.yaml -n my-model
``` ```
Take a note of or change the DYNAMO_IMAGE in the model deployment file.
Do not forget docker registry secret if needed.
##### 5.b.2 Install Dynamo GIE helm chart ### ##### 5.b.2 Install Dynamo GIE helm chart ###
```bash ```bash
...@@ -214,14 +223,14 @@ You can configure the plugin by setting environment variables in the EPP compone ...@@ -214,14 +223,14 @@ You can configure the plugin by setting environment variables in the EPP compone
Common Vars for Routing Configuration: Common Vars for Routing Configuration:
- Set `DYN_BUSY_THRESHOLD` to configure the upper bound on how "full" a worker can be (often derived from kv_active_blocks or other load metrics) before the router skips it. If the selected worker exceeds this value, routing falls back to the next best candidate. By default the value is negative meaning this is not enabled. - Set `DYN_BUSY_THRESHOLD` to configure the upper bound on how "full" a worker can be (often derived from kv_active_blocks or other load metrics) before the router skips it. If the selected worker exceeds this value, routing falls back to the next best candidate. By default the value is negative meaning this is not enabled.
- Set `DYN_ENFORCE_DISAGG=true` if you want to enforce every request being served in the disaggregated manner. By default it is false meaning if the the prefill worker is not available the request will be served in the aggregated manner. - Set `DYN_ENFORCE_DISAGG=true` if you want to enforce every request being served in the disaggregated manner. By default it is false meaning if the the prefill worker is not available the request will be served in the aggregated manner.
- By default the Dynamo plugin uses KV routing. You can expose `DYN_USE_KV_ROUTING=false` in your [values.yaml](standalone/helm/dynamo-gaie/values.yaml) if you prefer to route in the round-robin fashion. - Set `DYN_OVERLAP_SCORE_WEIGHT` to weigh how heavily the score uses token overlap (predicted KV cache hits) versus other factors (load, historical hit rate). Higher weight biases toward reusing workers with similar cached prefixes. (default: 1)
- If using kv-routing: - Set `DYN_ROUTER_TEMPERATURE` to soften or sharpen the selection curve when combining scores. Low temperature makes the router pick the top candidate deterministically; higher temperature lets lower-scoring workers through more often (exploration).
- Overwrite the `DYN_KV_BLOCK_SIZE` in your [values-dynamo-epp.yaml](./values-dynamo-epp.yaml) to match your model's block size.The `DYN_KV_BLOCK_SIZE` env var is ***MANDATORY*** to prevent silent KV routing failures. - Set `DYN_USE_KV_EVENTS=false` if you want to disable the workers sending KV events while using kv-routing (default: true)
- Set `DYNAMO_OVERLAP_SCORE_WEIGHT` to weigh how heavily the score uses token overlap (predicted KV cache hits) versus other factors (load, historical hit rate). Higher weight biases toward reusing workers with similar cached prefixes. - `DYN_ROUTER_TEMPERATURE` — Temperature for worker sampling via softmax (default: 0.0)
- Set `DYNAMO_ROUTER_TEMPERATURE` to soften or sharpen the selection curve when combining scores. Low temperature makes the router pick the top candidate deterministically; higher temperature lets lower-scoring workers through more often (exploration). - `DYN_ROUTER_REPLICA_SYNC` — Enable replica synchronization (default: false)
- Set `DYNAMO_USE_KV_EVENTS=false` if you want to disable the workers sending KV events while using kv-routing - `DYN_ROUTER_TRACK_ACTIVE_BLOCKS` — Track active blocks (default: true)
- See the [Router Guide](../../docs/pages/components/router/router-guide.md) for details. - `DYN_ROUTER_TRACK_OUTPUT_BLOCKS` — Track output blocks during generation (default: false)
- See the [KV cache routing design](../../docs/pages/design-docs/router-design.md) for details.
Stand-Alone installation only: Stand-Alone installation only:
- Overwrite the `DYN_NAMESPACE` env var if needed to match your model's dynamo namespace. - Overwrite the `DYN_NAMESPACE` env var if needed to match your model's dynamo namespace.
...@@ -272,7 +281,7 @@ b. use port-forward to expose the gateway to the host ...@@ -272,7 +281,7 @@ b. use port-forward to expose the gateway to the host
```bash ```bash
# in first terminal # in first terminal
kubectl port-forward svc/inference-gateway 8000:80 -n my-model kubectl port-forward svc/inference-gateway 8000:80 -n kgateway-system
# in second terminal where you want to send inference requests # in second terminal where you want to send inference requests
GATEWAY_URL=http://localhost:8000 GATEWAY_URL=http://localhost:8000
...@@ -359,6 +368,14 @@ Sample inference output: ...@@ -359,6 +368,14 @@ Sample inference output:
} }
``` ```
***If you have more than one HttpRoute running on the cluster***
Add the host to your HttpRoute.yaml and add the header `curl -H "Host: llama3-70b-agg.example.com" ...` to every request.
```bash
spec:
hostnames:
- llama3-70b-agg.example.com
```
### 8. Deleting the installation ### ### 8. Deleting the installation ###
If you need to uninstall run: If you need to uninstall run:
...@@ -407,4 +424,4 @@ The plugins set HTTP headers that are forwarded to the backend workers. ...@@ -407,4 +424,4 @@ The plugins set HTTP headers that are forwarded to the backend workers.
| Header | Description | Set By | | Header | Description | Set By |
|--------|-------------|--------| |--------|-------------|--------|
| `x-worker-instance-id` | Primary worker ID (decode worker in disagg mode) | kv-aware-scorer | | `x-worker-instance-id` | Primary worker ID (decode worker in disagg mode) | kv-aware-scorer |
| `x-prefill-instance-id` | Prefill worker ID (disaggregated mode only) | kv-aware-scorer | | `x-prefill-instance-id` | Prefill worker ID (disaggregated mode only) | kv-aware-scorer |
\ No newline at end of file
...@@ -26,74 +26,57 @@ package dynamo_kv_scorer ...@@ -26,74 +26,57 @@ package dynamo_kv_scorer
#include <stdlib.h> // for free #include <stdlib.h> // for free
#include <stdbool.h> #include <stdbool.h>
// enum underlying type is uint32_t; matches cbindgen output // Query router result codes (matches QueryRouterResult in Rust)
typedef uint32_t dynamo_llm_result_t; typedef uint32_t query_router_result_t;
enum { DYNAMO_OK = 0, DYNAMO_ERR = 1 }; enum {
QUERY_ROUTER_OK = 0,
// opaque handle forward-decl QUERY_ROUTER_ERR_INVALID_HANDLE = 1,
struct WorkerSelectionPipeline; QUERY_ROUTER_ERR_INVALID_PARAM = 2,
typedef struct WorkerSelectionPipeline WorkerSelectionPipeline; QUERY_ROUTER_ERR_INIT_FAILED = 3,
QUERY_ROUTER_ERR_QUERY_FAILED = 4,
// Prototypes (C-compatible) QUERY_ROUTER_ERR_DISAGG_ENFORCED = 5,
dynamo_llm_result_t dynamo_llm_init(const char *namespace_c_str, QUERY_ROUTER_ERR_TIMEOUT = 6,
const char *component_c_str, };
int64_t worker_id,
uint32_t kv_block_size); // opaque handle forward-decl for Router bindings
struct RouterHandles;
dynamo_llm_result_t dynamo_llm_shutdown(void); typedef struct RouterHandles RouterHandles;
dynamo_llm_result_t dynamo_llm_load_publisher_create(void);
// Routing result from route_chat_request
dynamo_llm_result_t dynamo_kv_event_publish_stored(uint64_t event_id, typedef struct {
const uint32_t *token_ids, bool is_disaggregated;
const uintptr_t *num_block_tokens, uint64_t prefill_worker_id;
const uint64_t *block_ids, uint64_t decode_worker_id;
size_t num_blocks, uint32_t *token_ids;
const uint64_t *parent_hash, size_t token_count;
uint64_t lora_id); } CRoutingResult;
dynamo_llm_result_t dynamo_kv_event_publish_removed(uint64_t event_id, // Router bindings API (replaces Pipeline API)
const uint64_t *block_ids, query_router_result_t create_routers(const char *namespace_c_str,
size_t num_blocks); const char *component_c_str,
bool enforce_disagg,
dynamo_llm_result_t dynamo_create_worker_selection_pipeline(const char *namespace_c_str, RouterHandles **out_handle);
const char *component_c_str,
const char *model_name_c_str, query_router_result_t route_request(RouterHandles *handle,
bool use_kv_routing, const char *request_json,
double busy_threshold, CRoutingResult *out_result);
double overlap_score_weight,
double router_temperature, query_router_result_t add_request(RouterHandles *handle,
bool use_kv_events, const char *request_id,
bool router_replica_sync, const uint32_t *token_ids,
bool enforce_disagg, size_t token_count,
WorkerSelectionPipeline **pipeline_out); uint64_t worker_id,
uint32_t dp_rank);
dynamo_llm_result_t dynamo_destroy_worker_selection_pipeline(WorkerSelectionPipeline *pipeline);
query_router_result_t mark_prefill_complete(RouterHandles *handle,
dynamo_llm_result_t dynamo_query_worker_selection_and_annotate(WorkerSelectionPipeline *pipeline, const char *request_id);
const char *request_json_c_str,
int64_t *decode_worker_id_out, query_router_result_t free_request(RouterHandles *handle,
int64_t *prefill_worker_id_out, const char *request_id);
uint32_t **token_ids_out,
size_t *token_count_out, void free_routing_result(CRoutingResult *result);
char **annotated_request_json_out);
void destroy(RouterHandles *handle);
dynamo_llm_result_t dynamo_free_worker_selection_result(uint32_t *token_ids,
size_t token_count,
char *annotated_request_json);
// Router bookkeeping functions for GAIE integration
dynamo_llm_result_t dynamo_router_add_request(WorkerSelectionPipeline *pipeline,
const char *request_id_c_str,
const uint32_t *token_ids,
size_t token_count,
uint64_t worker_id,
uint32_t dp_rank);
dynamo_llm_result_t dynamo_router_mark_prefill_complete(WorkerSelectionPipeline *pipeline,
const char *request_id_c_str);
dynamo_llm_result_t dynamo_router_free_request(WorkerSelectionPipeline *pipeline,
const char *request_id_c_str);
*/ */
import "C" import "C"
...@@ -121,9 +104,6 @@ const ( ...@@ -121,9 +104,6 @@ const (
WorkerIDHeader = "x-worker-instance-id" WorkerIDHeader = "x-worker-instance-id"
PrefillWorkerIDHeader = "x-prefill-instance-id" PrefillWorkerIDHeader = "x-prefill-instance-id"
RoutingModeHeader = "x-dynamo-routing-mode" RoutingModeHeader = "x-dynamo-routing-mode"
// EnableLocalUpdatesHeader controls router bookkeeping in the Dynamo frontend.
// Set to "false" for GAIE Stage 2 so the EPP handles bookkeeping via C FFI.
EnableLocalUpdatesHeader = "x-enable-local-updates"
// stateKey is the key used to store routing state in PluginState // stateKey is the key used to store routing state in PluginState
stateKey = "dynamo-routing-state" stateKey = "dynamo-routing-state"
...@@ -141,9 +121,8 @@ type params struct{} ...@@ -141,9 +121,8 @@ type params struct{}
type DynamoRoutingState struct { type DynamoRoutingState struct {
WorkerID string WorkerID string
PrefillWorkerID string PrefillWorkerID string
// TokenData holds the token IDs from the router. // TokenData holds the token IDs from the router, needed for add_request bookkeeping.
// Currently unused but stored for future implementation where tokens // These tokens are used to compute overlap blocks and track active blocks accurately.
// may be passed to the worker via request body instead of headers.
TokenData []int64 TokenData []int64
} }
...@@ -214,48 +193,24 @@ var ( ...@@ -214,48 +193,24 @@ var (
ffiOnce sync.Once ffiOnce sync.Once
ffiErr error ffiErr error
ffiNamespace string ffiNamespace string
ffiComponent string ffiComponent string
ffiModel string ffiEnforceDisagg bool
ffiOverlapScoreWeight float64
ffiRouterTemperature float64
ffiKvBlockSize uint32
ffiWorkerID int64
ffiEnforceDisagg bool
runtimeInitialized bool routerInitialized bool
// Boxed pipeline handle (owned on the Rust side, opaque here) // Router handles (owned on the Rust side, opaque here)
pipeline *C.struct_WorkerSelectionPipeline routerHandles *C.struct_RouterHandles
pipelineMutex sync.RWMutex routerHandlesMutex sync.RWMutex
) )
func loadDynamoConfig() { func loadDynamoConfig() {
ffiNamespace = getEnvOrDefault("DYN_NAMESPACE", "vllm-agg") ffiNamespace = getEnvOrDefault("DYN_NAMESPACE", "vllm-agg")
ffiComponent = "backend" // The pipeline uses backend not DYN_COMPONENT which is epp ffiComponent = "backend" // This is not the same as DYN_COMPONENT=epp (in this case)
ffiModel = getEnvOrDefault("DYN_MODEL", "Qwen/Qwen3-0.6B")
ffiWorkerID = getEnvInt64OrDefault("DYNAMO_WORKER_ID", 1)
ffiEnforceDisagg = getEnvBoolOrDefault("DYN_ENFORCE_DISAGG", false) ffiEnforceDisagg = getEnvBoolOrDefault("DYN_ENFORCE_DISAGG", false)
// Note: model name and kv_cache_block_size are now auto-discovered from the model card
ffiOverlapScoreWeight = getEnvFloatOrDefault("DYN_OVERLAP_SCORE_WEIGHT", -1.0) fmt.Printf("Dynamo KV Scorer: namespace=%s, component=%s, enforce_disagg=%v\n",
ffiRouterTemperature = getEnvFloatOrDefault("DYN_ROUTER_TEMPERATURE", -1.0) ffiNamespace, ffiComponent, ffiEnforceDisagg)
kvBlockSizeStr := os.Getenv("DYN_KV_BLOCK_SIZE")
if kvBlockSizeStr == "" {
panic("DYN_KV_BLOCK_SIZE is required and must match the model card's kv_cache_block_size")
}
var tmp int64
if n, err := fmt.Sscanf(kvBlockSizeStr, "%d", &tmp); err != nil || n != 1 {
panic(fmt.Sprintf("DYN_KV_BLOCK_SIZE='%s' is not a valid integer", kvBlockSizeStr))
}
ffiKvBlockSize = uint32(tmp)
if ffiKvBlockSize < 16 || ffiKvBlockSize > 8192 {
panic(fmt.Sprintf("DYN_KV_BLOCK_SIZE=%d outside [16,8192]", ffiKvBlockSize))
}
if (ffiKvBlockSize & (ffiKvBlockSize - 1)) != 0 {
panic(fmt.Sprintf("DYN_KV_BLOCK_SIZE=%d must be a power of 2", ffiKvBlockSize))
}
fmt.Printf("Dynamo KV Scorer: Loaded DYN_KV_BLOCK_SIZE=%d\n", ffiKvBlockSize)
} }
func getEnvOrDefault(key, def string) string { func getEnvOrDefault(key, def string) string {
...@@ -265,26 +220,6 @@ func getEnvOrDefault(key, def string) string { ...@@ -265,26 +220,6 @@ func getEnvOrDefault(key, def string) string {
return def return def
} }
func getEnvInt64OrDefault(key string, def int64) int64 {
if v := os.Getenv(key); v != "" {
var p int64
if n, err := fmt.Sscanf(v, "%d", &p); err == nil && n == 1 {
return p
}
}
return def
}
func getEnvFloatOrDefault(key string, def float64) float64 {
if v := os.Getenv(key); v != "" {
var p float64
if n, err := fmt.Sscanf(v, "%f", &p); err == nil && n == 1 {
return p
}
}
return def
}
func getEnvBoolOrDefault(key string, def bool) bool { func getEnvBoolOrDefault(key string, def bool) bool {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
switch strings.ToLower(v) { switch strings.ToLower(v) {
...@@ -297,46 +232,31 @@ func getEnvBoolOrDefault(key string, def bool) bool { ...@@ -297,46 +232,31 @@ func getEnvBoolOrDefault(key string, def bool) bool {
return def return def
} }
// initFFI: initialize runtime and create a persistent boxed pipeline. // initFFI: initialize router handles using the new Router bindings.
func initFFI() error { func initFFI() error {
ffiOnce.Do(func() { ffiOnce.Do(func() {
loadDynamoConfig() loadDynamoConfig()
ns := C.CString(ffiNamespace) ns := C.CString(ffiNamespace)
cm := C.CString(ffiComponent) cm := C.CString(ffiComponent)
model := C.CString(ffiModel)
defer C.free(unsafe.Pointer(ns)) defer C.free(unsafe.Pointer(ns))
defer C.free(unsafe.Pointer(cm)) defer C.free(unsafe.Pointer(cm))
defer C.free(unsafe.Pointer(model))
// Init Dynamo runtime
if rc := C.dynamo_llm_init(ns, cm, C.int64_t(ffiWorkerID), C.uint32_t(ffiKvBlockSize)); rc != C.DYNAMO_OK {
ffiErr = fmt.Errorf("dynamo_llm_init failed")
return
}
runtimeInitialized = true
// Create persistent pipeline // Create router handles
pipelineMutex.Lock() routerHandlesMutex.Lock()
defer pipelineMutex.Unlock() defer routerHandlesMutex.Unlock()
rc := C.dynamo_create_worker_selection_pipeline( rc := C.create_routers(
ns, ns,
cm, cm,
model,
C.bool(getEnvBoolOrDefault("DYN_USE_KV_ROUTING", true)),
C.double(getEnvFloatOrDefault("DYN_BUSY_THRESHOLD", -1.0)),
C.double(ffiOverlapScoreWeight),
C.double(ffiRouterTemperature),
C.bool(getEnvBoolOrDefault("DYN_USE_KV_EVENTS", true)),
C.bool(getEnvBoolOrDefault("DYNAMO_ROUTER_REPLICA_SYNC", false)), // no need as long as we call the Router Book keeping operations from the EPP.
C.bool(ffiEnforceDisagg), C.bool(ffiEnforceDisagg),
&pipeline, &routerHandles,
) )
if rc != C.DYNAMO_OK { if rc != C.QUERY_ROUTER_OK {
ffiErr = fmt.Errorf("dynamo_create_worker_selection_pipeline failed") ffiErr = fmt.Errorf("create_routers failed with code %d", rc)
return return
} }
routerInitialized = true
}) })
return ffiErr return ffiErr
} }
...@@ -362,16 +282,12 @@ func (k *KVAwareScorer) Score( ...@@ -362,16 +282,12 @@ func (k *KVAwareScorer) Score(
"tokenDataCount", len(tokenData), "tokenDataCount", len(tokenData),
) )
// Store in request headers for the Lua filter at the gateway // Store in request headers
if req.Headers == nil { if req.Headers == nil {
req.Headers = map[string]string{} req.Headers = map[string]string{}
} }
req.Headers[WorkerIDHeader] = workerID req.Headers[WorkerIDHeader] = workerID
// Disable local updates in the Dynamo frontend router.
// EPP handles bookkeeping via C FFI (add_request, mark_prefill_complete, free_request).
req.Headers[EnableLocalUpdatesHeader] = "false"
// Set routing mode and prefill worker ID based on disaggregated vs aggregated // Set routing mode and prefill worker ID based on disaggregated vs aggregated
if prefillWorkerID != "" && prefillWorkerID != workerID { if prefillWorkerID != "" && prefillWorkerID != workerID {
// Disaggregated mode: separate prefill and decode workers // Disaggregated mode: separate prefill and decode workers
...@@ -383,15 +299,13 @@ func (k *KVAwareScorer) Score( ...@@ -383,15 +299,13 @@ func (k *KVAwareScorer) Score(
} }
// Store routing state for PreRequest to register with router bookkeeping. // Store routing state for PreRequest to register with router bookkeeping.
// This is the correct place to store state - PreRequest is called AFTER // PreRequest is called AFTER scheduling is finalized, ensuring we only
// scheduling is finalized, ensuring we only register committed requests. // register committed requests (avoiding phantom bookkeeping entries).
if req.RequestId != "" { if req.RequestId != "" {
routingState := &DynamoRoutingState{ routingState := &DynamoRoutingState{
WorkerID: workerID, WorkerID: workerID,
PrefillWorkerID: prefillWorkerID, PrefillWorkerID: prefillWorkerID,
// TokenData is stored for future use. Currently not passed to workers TokenData: tokenData,
// via headers (too large). May be passed via request body in future.
TokenData: tokenData,
} }
k.pluginState.Write(req.RequestId, plugins.StateKey(stateKey), routingState) k.pluginState.Write(req.RequestId, plugins.StateKey(stateKey), routingState)
} }
...@@ -405,8 +319,8 @@ func (k *KVAwareScorer) Score( ...@@ -405,8 +319,8 @@ func (k *KVAwareScorer) Score(
} }
// PreRequest is called after scheduling is finalized and before the request is sent to the worker. // PreRequest is called after scheduling is finalized and before the request is sent to the worker.
// This is the correct place to register the request with the Dynamo router's bookkeeping, // This registers the request with the Dynamo router's bookkeeping (add_request), passing the
// as we know the request WILL be dispatched (avoiding phantom bookkeeping entries). // token data obtained during Score(). This ensures only committed requests are tracked.
func (k *KVAwareScorer) PreRequest( func (k *KVAwareScorer) PreRequest(
ctx context.Context, ctx context.Context,
request *schedtypes.LLMRequest, request *schedtypes.LLMRequest,
...@@ -432,8 +346,16 @@ func (k *KVAwareScorer) PreRequest( ...@@ -432,8 +346,16 @@ func (k *KVAwareScorer) PreRequest(
return return
} }
// Parse worker ID
var workerIDUint uint64
if _, parseErr := fmt.Sscanf(state.WorkerID, "%d", &workerIDUint); parseErr != nil {
logger.V(logutil.DEFAULT).Error(parseErr, "PreRequest: invalid worker ID",
"requestID", request.RequestId, "workerID", state.WorkerID)
return
}
// Register request with router bookkeeping now that scheduling is committed // Register request with router bookkeeping now that scheduling is committed
if addErr := k.callAddRequest(ctx, request.RequestId, state.TokenData, state.WorkerID, state.PrefillWorkerID); addErr != nil { if addErr := CallAddRequest(request.RequestId, state.TokenData, workerIDUint, 0); addErr != nil {
logger.V(logutil.DEFAULT).Error(addErr, "PreRequest: failed to add request to router bookkeeping", logger.V(logutil.DEFAULT).Error(addErr, "PreRequest: failed to add request to router bookkeeping",
"requestID", request.RequestId) "requestID", request.RequestId)
return return
...@@ -476,7 +398,7 @@ func (k *KVAwareScorer) ResponseStreaming( ...@@ -476,7 +398,7 @@ func (k *KVAwareScorer) ResponseStreaming(
// ResponseComplete is called after the complete response is sent to the client. // ResponseComplete is called after the complete response is sent to the client.
// It cleans up the router bookkeeping state for the completed request by calling // It cleans up the router bookkeeping state for the completed request by calling
// dynamo_router_free_request to release resources associated with the request. // free_request to release resources associated with the request.
func (k *KVAwareScorer) ResponseComplete( func (k *KVAwareScorer) ResponseComplete(
ctx context.Context, ctx context.Context,
request *schedtypes.LLMRequest, request *schedtypes.LLMRequest,
...@@ -510,7 +432,7 @@ func (k *KVAwareScorer) ResponseComplete( ...@@ -510,7 +432,7 @@ func (k *KVAwareScorer) ResponseComplete(
"requestID", requestID) "requestID", requestID)
} }
// --------------------------- router call (persistent only) --------------------------- // --------------------------- router call ---------------------------
func (k *KVAwareScorer) callDynamoRouter( func (k *KVAwareScorer) callDynamoRouter(
ctx context.Context, ctx context.Context,
...@@ -522,20 +444,25 @@ func (k *KVAwareScorer) callDynamoRouter( ...@@ -522,20 +444,25 @@ func (k *KVAwareScorer) callDynamoRouter(
logger.V(logutil.DEFAULT).Error(err, "FFI init failed") logger.V(logutil.DEFAULT).Error(err, "FFI init failed")
return "", "", nil, err return "", "", nil, err
} }
if !runtimeInitialized { if !routerInitialized {
return "", "", nil, fmt.Errorf("dynamo runtime not initialized") return "", "", nil, fmt.Errorf("dynamo router not initialized")
} }
pipelineMutex.RLock() routerHandlesMutex.RLock()
currentPipeline := pipeline router := routerHandles
pipelineMutex.RUnlock() routerHandlesMutex.RUnlock()
if currentPipeline == nil { if router == nil {
return "", "", nil, fmt.Errorf("dynamo worker selection pipeline not created") return "", "", nil, fmt.Errorf("dynamo router handles not created")
} }
// Build OpenAI-compatible JSON request from the new LLMRequest structure // Build OpenAI-compatible JSON request from the GAIE LLMRequest structure
requestBody := buildOpenAIRequest(req) requestBody, err := buildOpenAIRequest(req)
if err != nil {
logger.V(logutil.DEFAULT).Info("Invalid/empty request body for router; refusing to route",
"err", err.Error())
return "", "", nil, err
}
requestJSON, jsonErr := json.Marshal(requestBody) requestJSON, jsonErr := json.Marshal(requestBody)
if jsonErr != nil { if jsonErr != nil {
logger.V(logutil.DEFAULT).Error(jsonErr, "Failed to marshal OpenAI request") logger.V(logutil.DEFAULT).Error(jsonErr, "Failed to marshal OpenAI request")
...@@ -544,115 +471,104 @@ func (k *KVAwareScorer) callDynamoRouter( ...@@ -544,115 +471,104 @@ func (k *KVAwareScorer) callDynamoRouter(
cRequestJSON := C.CString(string(requestJSON)) cRequestJSON := C.CString(string(requestJSON))
defer C.free(unsafe.Pointer(cRequestJSON)) defer C.free(unsafe.Pointer(cRequestJSON))
// Output variables var result C.CRoutingResult
var cDecodeWorkerID C.int64_t rc := C.route_request(router, cRequestJSON, &result)
var cPrefillWorkerID C.int64_t if rc != C.QUERY_ROUTER_OK {
var cTokens *C.uint32_t return "", "", nil, fmt.Errorf("route_request failed with code %d", rc)
var cTokenCount C.size_t
var cAnnotatedJSON *C.char
// Call the worker selection pipeline
rc := C.dynamo_query_worker_selection_and_annotate(
currentPipeline,
cRequestJSON,
&cDecodeWorkerID,
&cPrefillWorkerID,
&cTokens,
&cTokenCount,
&cAnnotatedJSON,
)
if rc != C.DYNAMO_OK {
return "", "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed")
} }
// Copy tokens into Go memory and free C memory // Copy token IDs into Go memory before freeing the Rust-allocated result.
count := int(uintptr(cTokenCount)) // These tokens are needed for add_request bookkeeping (overlap + active block tracking).
count := int(result.token_count)
var tokens64 []int64 var tokens64 []int64
if count > 0 && cTokens != nil { if count > 0 && result.token_ids != nil {
src := unsafe.Slice((*uint32)(unsafe.Pointer(cTokens)), count) src := unsafe.Slice((*uint32)(unsafe.Pointer(result.token_ids)), count)
tokens64 = make([]int64, count) tokens64 = make([]int64, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
tokens64[i] = int64(src[i]) tokens64[i] = int64(src[i])
} }
} }
C.dynamo_free_worker_selection_result(cTokens, cTokenCount, cAnnotatedJSON)
workerIDStr := fmt.Sprintf("%d", int64(cDecodeWorkerID)) // Copy scalar result fields before freeing the struct
isDisaggregated := result.is_disaggregated
decodeWorkerID := uint64(result.decode_worker_id)
prefillWorkerIDVal := uint64(result.prefill_worker_id)
// Free the Rust-allocated routing result (including token_ids)
C.free_routing_result(&result)
workerIDStr := fmt.Sprintf("%d", decodeWorkerID)
prefillWorkerIDStr := "" prefillWorkerIDStr := ""
// Rust returns -1 for prefill_worker_id when not in disaggregated mode if isDisaggregated {
if int64(cPrefillWorkerID) >= 0 { prefillWorkerIDStr = fmt.Sprintf("%d", prefillWorkerIDVal)
prefillWorkerIDStr = fmt.Sprintf("%d", int64(cPrefillWorkerID))
} }
logger.V(logutil.DEFAULT).Info("Worker selection completed", logger.V(logutil.DEFAULT).Info("Worker selection completed",
"workerID", workerIDStr, "prefillWorkerID", prefillWorkerIDStr, "tokenCount", count) "workerID", workerIDStr, "prefillWorkerID", prefillWorkerIDStr,
"isDisaggregated", isDisaggregated, "tokenCount", count)
return workerIDStr, prefillWorkerIDStr, tokens64, nil return workerIDStr, prefillWorkerIDStr, tokens64, nil
} }
// buildOpenAIRequest constructs an OpenAI-compatible request from the new LLMRequest structure // buildOpenAIRequest constructs an OpenAI-compatible request from the GAIE LLMRequest structure.
func buildOpenAIRequest(req *schedtypes.LLMRequest) map[string]any { // Preserves message roles for correct chat template application and tokenization.
func buildOpenAIRequest(req *schedtypes.LLMRequest) (map[string]any, error) {
requestBody := make(map[string]any) requestBody := make(map[string]any)
// Extract prompt from the new Body structure // Preserve the original message structure for correct chat template application
userText := "default prompt" if req == nil || req.Body == nil {
if req != nil && req.Body != nil { return nil, fmt.Errorf("missing request body")
if req.Body.ChatCompletions != nil && len(req.Body.ChatCompletions.Messages) > 0 { }
// Extract text from chat completions messages
var sb strings.Builder if req.Body.ChatCompletions != nil && len(req.Body.ChatCompletions.Messages) > 0 {
for _, msg := range req.Body.ChatCompletions.Messages { messages := make([]map[string]any, 0, len(req.Body.ChatCompletions.Messages))
sb.WriteString(msg.Content.PlainText()) anyNonEmpty := false
sb.WriteString(" ") for _, msg := range req.Body.ChatCompletions.Messages {
content := msg.Content.PlainText()
if strings.TrimSpace(content) != "" {
anyNonEmpty = true
} }
userText = strings.TrimSpace(sb.String()) messages = append(messages, map[string]any{
} else if req.Body.Completions != nil && req.Body.Completions.Prompt != "" { "role": msg.Role,
userText = req.Body.Completions.Prompt "content": content,
})
}
if !anyNonEmpty {
return nil, fmt.Errorf("empty chat messages")
}
requestBody["messages"] = messages
} else if req.Body.Completions != nil && strings.TrimSpace(req.Body.Completions.Prompt) != "" {
// Legacy completions format - wrap as single user message
requestBody["messages"] = []map[string]any{
{"role": "user", "content": req.Body.Completions.Prompt},
} }
} else {
return nil, fmt.Errorf("no messages or prompt provided")
} }
requestBody["messages"] = []map[string]any{{"role": "user", "content": userText}} // Model field is required by OpenAI spec but not used by the router's tokenizer
// (tokenizer is determined by the discovered model card)
if req != nil && strings.TrimSpace(req.TargetModel) != "" { if req != nil && strings.TrimSpace(req.TargetModel) != "" {
requestBody["model"] = req.TargetModel requestBody["model"] = req.TargetModel
} else { } else {
requestBody["model"] = ffiModel requestBody["model"] = "default"
}
requestBody["max_tokens"] = 1
requestBody["temperature"] = 0.0
requestBody["stream"] = true
requestBody["nvext"] = map[string]any{
"annotations": []string{"query_instance_id"},
} }
return requestBody return requestBody, nil
} }
// --------------------------- router bookkeeping --------------------------- // --------------------------- router bookkeeping ---------------------------
// callAddRequest registers a request with the router's bookkeeping. // CallAddRequest registers a request with the router's bookkeeping.
// This should be called after worker selection to track active requests. func CallAddRequest(requestID string, tokenData []int64, workerID uint64, dpRank uint32) error {
func (k *KVAwareScorer) callAddRequest( if !routerInitialized {
ctx context.Context, return fmt.Errorf("dynamo router not initialized")
requestID string,
tokenData []int64,
workerID string,
prefillWorkerID string,
) error {
logger := log.FromContext(ctx)
if !runtimeInitialized {
return fmt.Errorf("dynamo runtime not initialized")
} }
pipelineMutex.RLock() routerHandlesMutex.RLock()
currentPipeline := pipeline router := routerHandles
pipelineMutex.RUnlock() routerHandlesMutex.RUnlock()
if currentPipeline == nil {
return fmt.Errorf("dynamo worker selection pipeline not created")
}
// Parse worker ID (use decode worker for bookkeeping in disagg mode) if router == nil {
var workerIDUint uint64 return fmt.Errorf("dynamo router handles not created")
if _, err := fmt.Sscanf(workerID, "%d", &workerIDUint); err != nil {
return fmt.Errorf("invalid worker ID: %s", workerID)
} }
// Convert token data from int64 to uint32 // Convert token data from int64 to uint32
...@@ -669,69 +585,66 @@ func (k *KVAwareScorer) callAddRequest( ...@@ -669,69 +585,66 @@ func (k *KVAwareScorer) callAddRequest(
cTokens = (*C.uint32_t)(unsafe.Pointer(&tokens[0])) cTokens = (*C.uint32_t)(unsafe.Pointer(&tokens[0]))
} }
rc := C.dynamo_router_add_request( rc := C.add_request(
currentPipeline, router,
cRequestID, cRequestID,
cTokens, cTokens,
C.size_t(len(tokens)), C.size_t(len(tokens)),
C.uint64_t(workerIDUint), C.uint64_t(workerID),
C.uint32_t(0), // dp_rank = 0 for now C.uint32_t(dpRank),
) )
if rc != C.DYNAMO_OK { if rc != C.QUERY_ROUTER_OK {
return fmt.Errorf("dynamo_router_add_request failed") return fmt.Errorf("add_request failed with code %d", rc)
} }
logger.V(logutil.VERBOSE).Info("Added request to router bookkeeping",
"requestID", requestID, "workerID", workerID, "tokenCount", len(tokens))
return nil return nil
} }
// CallMarkPrefillComplete marks prefill as completed for a request. // CallMarkPrefillComplete marks prefill as completed for a request.
// Exported for use by response handlers. // Exported for use by response handlers.
func CallMarkPrefillComplete(requestID string) error { func CallMarkPrefillComplete(requestID string) error {
if !runtimeInitialized { if !routerInitialized {
return fmt.Errorf("dynamo runtime not initialized") return fmt.Errorf("dynamo router not initialized")
} }
pipelineMutex.RLock() routerHandlesMutex.RLock()
currentPipeline := pipeline router := routerHandles
pipelineMutex.RUnlock() routerHandlesMutex.RUnlock()
if currentPipeline == nil { if router == nil {
return fmt.Errorf("dynamo worker selection pipeline not created") return fmt.Errorf("dynamo router handles not created")
} }
cRequestID := C.CString(requestID) cRequestID := C.CString(requestID)
defer C.free(unsafe.Pointer(cRequestID)) defer C.free(unsafe.Pointer(cRequestID))
rc := C.dynamo_router_mark_prefill_complete(currentPipeline, cRequestID) rc := C.mark_prefill_complete(router, cRequestID)
if rc != C.DYNAMO_OK { if rc != C.QUERY_ROUTER_OK {
return fmt.Errorf("dynamo_router_mark_prefill_complete failed") return fmt.Errorf("mark_prefill_complete failed with code %d", rc)
} }
return nil return nil
} }
// callFreeRequestInternal cleans up router state for a completed/cancelled request. // callFreeRequestInternal cleans up router state for a completed/cancelled request.
func callFreeRequestInternal(requestID string) error { func callFreeRequestInternal(requestID string) error {
if !runtimeInitialized { if !routerInitialized {
return fmt.Errorf("dynamo runtime not initialized") return fmt.Errorf("dynamo router not initialized")
} }
pipelineMutex.RLock() routerHandlesMutex.RLock()
currentPipeline := pipeline router := routerHandles
pipelineMutex.RUnlock() routerHandlesMutex.RUnlock()
if currentPipeline == nil { if router == nil {
return fmt.Errorf("dynamo worker selection pipeline not created") return fmt.Errorf("dynamo router handles not created")
} }
cRequestID := C.CString(requestID) cRequestID := C.CString(requestID)
defer C.free(unsafe.Pointer(cRequestID)) defer C.free(unsafe.Pointer(cRequestID))
rc := C.dynamo_router_free_request(currentPipeline, cRequestID) rc := C.free_request(router, cRequestID)
if rc != C.DYNAMO_OK { if rc != C.QUERY_ROUTER_OK {
return fmt.Errorf("dynamo_router_free_request failed") return fmt.Errorf("free_request failed with code %d", rc)
} }
return nil return nil
} }
...@@ -739,21 +652,14 @@ func callFreeRequestInternal(requestID string) error { ...@@ -739,21 +652,14 @@ func callFreeRequestInternal(requestID string) error {
// --------------------------- shutdown --------------------------- // --------------------------- shutdown ---------------------------
func cleanupDynamo() error { func cleanupDynamo() error {
pipelineMutex.Lock() routerHandlesMutex.Lock()
defer pipelineMutex.Unlock() defer routerHandlesMutex.Unlock()
if pipeline != nil { if routerHandles != nil {
if rc := C.dynamo_destroy_worker_selection_pipeline(pipeline); rc != C.DYNAMO_OK { C.destroy(routerHandles)
fmt.Printf("Warning: dynamo_destroy_worker_selection_pipeline failed\n") routerHandles = nil
}
pipeline = nil
} }
if runtimeInitialized { routerInitialized = false
if rc := C.dynamo_llm_shutdown(); rc != C.DYNAMO_OK {
return fmt.Errorf("dynamo_llm_shutdown failed")
}
runtimeInitialized = false
}
return nil return nil
} }
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
{{- $useDynamo := default false .Values.epp.useDynamo -}} {{- $useDynamo := default false .Values.epp.useDynamo -}}
{{- $resolvedDynNs := (include "dynamo-gaie.dynamoNamespace" .) | trim -}} {{- $resolvedDynNs := (include "dynamo-gaie.dynamoNamespace" .) | trim -}}
{{- $ns := ternary (required "set dynamoGraphDeploymentName when epp.useDynamo=true" $resolvedDynNs) "" $useDynamo -}} {{- $ns := ternary (required "set dynamoGraphDeploymentName when epp.useDynamo=true" $resolvedDynNs) "" $useDynamo -}}
{{- $kv := default "16" .Values.epp.dynamo.kvBlockSize -}}
{{- $useEtcd := default false .Values.epp.dynamo.useEtcd -}} {{- $useEtcd := default false .Values.epp.dynamo.useEtcd -}}
{{- $eppImage := required "extension.image is required - set via --set-string extension.image=$EPP_IMAGE or in values file" .Values.extension.image }} {{- $eppImage := required "extension.image is required - set via --set-string extension.image=$EPP_IMAGE or in values file" .Values.extension.image }}
...@@ -113,8 +112,6 @@ spec: ...@@ -113,8 +112,6 @@ spec:
value: "nats://{{ $platformName }}-nats.{{ $platformNs }}:4222" value: "nats://{{ $platformName }}-nats.{{ $platformNs }}:4222"
- name: DYN_NAMESPACE - name: DYN_NAMESPACE
value: "{{ $ns }}" value: "{{ $ns }}"
- name: DYN_KV_BLOCK_SIZE
value: "{{ $kv }}"
- name: USE_STREAMING - name: USE_STREAMING
value: "true" value: "true"
# HuggingFace token for downloading model config files # HuggingFace token for downloading model config files
......
...@@ -72,7 +72,6 @@ epp: ...@@ -72,7 +72,6 @@ epp:
# Dynamo-specific settings (only used when useDynamo: true) # Dynamo-specific settings (only used when useDynamo: true)
configFile: "/etc/epp/epp-config-dynamo.yaml" configFile: "/etc/epp/epp-config-dynamo.yaml"
dynamo: dynamo:
kvBlockSize: "16"
# Use ETCD for discovery instead of Kubernetes (default: false) # Use ETCD for discovery instead of Kubernetes (default: false)
# Set to true via --set epp.dynamo.useEtcd=true to enable ETCD discovery # Set to true via --set epp.dynamo.useEtcd=true to enable ETCD discovery
useEtcd: false useEtcd: false
......
...@@ -87,10 +87,6 @@ func (e *EPPDefaults) GetBaseContainer(context ComponentContext) (corev1.Contain ...@@ -87,10 +87,6 @@ func (e *EPPDefaults) GetBaseContainer(context ComponentContext) (corev1.Contain
// EPP-specific environment variables // EPP-specific environment variables
container.Env = append(container.Env, []corev1.EnvVar{ container.Env = append(container.Env, []corev1.EnvVar{
{
Name: "DYN_KV_BLOCK_SIZE",
Value: "16",
},
{ {
Name: "USE_STREAMING", Name: "USE_STREAMING",
Value: "true", Value: "true",
......
...@@ -6,15 +6,24 @@ use libc::c_char; ...@@ -6,15 +6,24 @@ use libc::c_char;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use std::borrow::Cow; use std::borrow::Cow;
use std::ffi::CStr; use std::ffi::CStr;
use std::ptr;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use dynamo_llm::{ use dynamo_llm::kv_router::{protocols::*, publisher::KvEventPublisher};
discovery::{KvWorkerMonitor, ModelWatcher}, use dynamo_llm::preprocessor::OpenAIPreprocessor;
kv_router::{protocols::*, publisher::KvEventPublisher},
};
use dynamo_runtime::discovery::DiscoveryQuery; use dynamo_runtime::discovery::DiscoveryQuery;
use dynamo_runtime::{DistributedRuntime, Worker}; use dynamo_runtime::{DistributedRuntime, Worker};
use dynamo_runtime::Runtime;
use dynamo_llm::discovery::{ModelManager, WORKER_TYPE_DECODE};
use dynamo_llm::kv_router::KvRouterConfig;
use dynamo_llm::kv_router::protocols::WorkerWithDpRank;
use dynamo_llm::kv_router::{KvRouter, PrefillRouter, RouterConfigOverride};
use dynamo_runtime::pipeline::RouterMode;
static WK: OnceCell<Worker> = OnceCell::new(); static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new(); static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls? // [FIXME] shouldn't the publisher be instance passing between API calls?
...@@ -354,1156 +363,866 @@ pub extern "C" fn dynamo_kv_event_publish_removed( ...@@ -354,1156 +363,866 @@ pub extern "C" fn dynamo_kv_event_publish_removed(
} }
} }
// Need to setup etcd and nats to run these tests
// #[cfg(test)]
// mod tests {
// use super::*;
// use std::ffi::CString;
// #[test]
// fn test_dynamo_llm_init() {
// // Create C-compatible strings
// let namespace = CString::new("test_namespace").unwrap();
// let component = CString::new("test_component").unwrap();
// // Call the init function
// let result = unsafe {
// dynamo_llm_init(
// namespace.as_ptr(),
// component.as_ptr(),
// 1, // worker_id
// 32, // kv_block_size
// )
// };
// assert_eq!(result as u32, DynamoLlmResult::OK as u32);
// assert!(WK.get().is_some());
// let shutdown_result = dynamo_llm_shutdown();
// assert_eq!(shutdown_result as u32, DynamoLlmResult::OK as u32);
// }
// }
/* ------------------------------------------------------------------------ /* ------------------------------------------------------------------------
* Worker selection pipeline * Router Bindings for GAIE EPP
* ------------------------------------------------------------------------ */ * ------------------------------------------------------------------------ */
use std::pin::Pin;
const GENERATE_ENDPOINT: &str = "generate";
use anyhow::Context;
use dynamo_runtime::{Runtime, traits::DistributedRuntimeProvider};
use dynamo_llm::discovery::ModelManager; // Default timeout for bookkeeping operations
use dynamo_llm::entrypoint::build_routed_pipeline; const BOOKKEEPING_TIMEOUT_SEC: u64 = 5;
use dynamo_llm::http::service::metrics::Metrics; /// Complete routing result for a chat completion request (C-compatible)
use dynamo_llm::kv_router::KvRouterConfig; #[repr(C)]
use dynamo_llm::model_card::ModelDeploymentCard; pub struct CRoutingResult {
use dynamo_llm::protocols::openai::nvext::NvExt; /// Whether disaggregated mode is active
use dynamo_llm::types::{ pub is_disaggregated: bool,
Annotated, /// Prefill worker ID (only valid if is_disaggregated is true)
openai::chat_completions::{ pub prefill_worker_id: u64,
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, /// Decode worker ID
}, pub decode_worker_id: u64,
}; /// Token IDs (needed for add_request callback)
use dynamo_runtime::{ pub token_ids: *mut u32,
engine::AsyncEngineStream, /// Number of tokens in the request
pipeline::{ManyOut, RouterMode, ServiceEngine, SingleIn}, pub token_count: usize,
};
/// Opaque handle exposed to C — it owns its own Worker/runtime and engine.
pub struct WorkerSelectionPipeline {
wk: Worker,
engine: ServiceEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>,
/// KV router for bookkeeping operations (only present when router_mode is KV)
kv_router: Option<Arc<dynamo_llm::kv_router::KvRouter>>,
} }
/// Create a worker-selection pipeline ("generate" endpoint). impl Default for CRoutingResult {
/// fn default() -> Self {
/// # Safety Self {
/// - `namespace_c_str`, `component_c_str`, and `model_name_c_str` must be **non-null** pointers to is_disaggregated: false,
/// **NUL-terminated** C strings that contain **valid UTF-8**. They must remain valid for the prefill_worker_id: 0,
/// duration of this call. decode_worker_id: 0,
/// - `pipeline_out` must be **non-null** and point to writable memory for a `*mut WorkerSelectionPipeline`. token_ids: ptr::null_mut(),
/// On success this function writes exactly once to `*pipeline_out`. The caller becomes the owner of token_count: 0,
/// that pointer and **must** later free it by calling `dynamo_destroy_worker_selection_pipeline`.
/// - Must be called **after** a successful `dynamo_llm_init()`; otherwise behavior is undefined.
/// - This function is not signal-safe and must not be called from a signal handler.
/// - This function may block internally; do not call it from contexts that forbid blocking.
///
/// # Errors
/// Returns `DynamoLlmResult::ERR` on failure and does not write to `pipeline_out`.
/// # Safety
/// See detailed safety docs above. Additional parameter:
/// - `enforce_disagg`: If true, requests fail when disaggregated serving is unavailable.
/// If false, falls back to aggregated serving.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
namespace_c_str: *const c_char,
component_c_str: *const c_char,
model_name_c_str: *const c_char,
use_kv_routing: bool,
busy_threshold: f64,
overlap_score_weight: f64,
router_temperature: f64,
use_kv_events: bool,
router_replica_sync: bool,
enforce_disagg: bool,
pipeline_out: *mut *mut WorkerSelectionPipeline,
) -> DynamoLlmResult {
if pipeline_out.is_null() {
tracing::error!("pipeline_out pointer is null");
return DynamoLlmResult::ERR;
}
let wk = match WK.get() {
Some(w) => w.clone(),
None => {
tracing::error!("Worker not initialized. Call dynamo_llm_init first.");
return DynamoLlmResult::ERR;
}
};
let namespace = match unsafe { CStr::from_ptr(namespace_c_str) }.to_str() {
Ok(s) => s.to_owned(),
Err(e) => {
tracing::error!(error = ?e, "bad namespace");
return DynamoLlmResult::ERR;
} }
};
let component_cow = unsafe { cstr_or_default(component_c_str, "backend") };
if let Cow::Borrowed("backend") = &component_cow {
tracing::info!("defaulting to \"backend\" for component");
} }
let component: String = component_cow.into_owned(); }
let model = match unsafe { CStr::from_ptr(model_name_c_str) }.to_str() { /// Container holding routers and preprocessor for query routing
Ok(s) => s.to_owned(), pub struct RouterHandles {
Err(e) => { prefill_router: Arc<PrefillRouter>,
tracing::error!(error = ?e, "bad model"); decode_router: Arc<KvRouter>,
return DynamoLlmResult::ERR; #[allow(dead_code)]
} model_manager: Arc<ModelManager>,
}; #[allow(dead_code)]
namespace: String,
/// Cached runtime for executing async operations (avoids creating new runtime per call)
runtime: Runtime,
/// Preprocessor for tokenization and template application (fetched via discovery)
preprocessor: Option<Arc<OpenAIPreprocessor>>,
}
let make_engine = || async { impl RouterHandles {
let router_mode = if use_kv_routing { /// Query optimal prefill worker for a request.
RouterMode::KV /// Returns worker_id on success.
} else { async fn query_prefill_worker(
RouterMode::RoundRobin &self,
}; tokens: &[u32],
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
) -> Result<u64, QueryRouterResult> {
self.prefill_router
.query_prefill_worker(tokens, update_states, lora_name, priority_jump)
.await
.map(|(worker_id, _dp_rank)| worker_id)
.map_err(|e| {
tracing::error!(error = ?e, "Prefill query failed");
QueryRouterResult::ErrQueryFailed
})
}
let kv_router_config = if use_kv_routing { /// Query optimal decode worker for a request.
Some(KvRouterConfig { /// For disaggregated mode, set `is_disaggregated` to true to use overlap_score_weight=0
overlap_score_weight, /// (since KV cache is being transferred from prefill, not reused).
router_temperature, ///
use_kv_events, /// Note: The C bindings are query-only and must not mutate router state during worker
router_replica_sync, /// selection. State updates require a `context_id` (request id) and are managed via the
..KvRouterConfig::default() /// explicit bookkeeping APIs (`add_request`, `mark_prefill_complete`, `free_request`).
/// Returns (worker, overlap_blocks) on success.
async fn query_decode_worker(
&self,
tokens: &[u32],
is_disaggregated: bool,
) -> Result<(WorkerWithDpRank, u32), QueryRouterResult> {
// For decode phase in disaggregated mode, use overlap_score_weight=0
// This matches prefill_router.rs
let config_override = if is_disaggregated {
Some(RouterConfigOverride {
overlap_score_weight: Some(0.0),
..Default::default()
}) })
} else { } else {
None None
}; };
create_worker_selection_pipeline_chat( self.decode_router
&namespace, .find_best_match(None, tokens, config_override.as_ref(), false, None, 0.0)
&component, .await
&model, .map_err(|e| {
router_mode, tracing::error!(error = ?e, "Decode query failed");
(busy_threshold >= 0.0).then_some(busy_threshold), QueryRouterResult::ErrQueryFailed
kv_router_config, })
enforce_disagg, }
) }
.await
};
let (engine, kv_router) = match wk.runtime().secondary().block_on(make_engine()) { /// Opaque handle for the router pair
Ok(p) => p, pub type RouterHandlesPtr = *mut RouterHandles;
Err(e) => {
tracing::error!(error = ?e, "create_worker_selection_pipeline_chat failed");
return DynamoLlmResult::ERR;
}
};
let handle = Box::new(WorkerSelectionPipeline { /// Result codes for query router C FFI
wk, #[repr(u32)]
engine, pub enum QueryRouterResult {
kv_router, Ok = 0,
}); ErrInvalidHandle = 1,
unsafe { ErrInvalidParam = 2,
*pipeline_out = Box::into_raw(handle); ErrInitFailed = 3,
ErrQueryFailed = 4,
ErrDisaggEnforced = 5,
ErrTimeout = 6,
}
/// Build a `KvRouterConfig` from defaults, overridden by optional `DYN_*` environment variables.
///
/// Supported env vars (all optional — unset or empty values are ignored):
/// - `DYN_OVERLAP_SCORE_WEIGHT` — Weight for overlap score in worker selection (default: 1.0)
/// - `DYN_ROUTER_TEMPERATURE` — Temperature for worker sampling via softmax (default: 0.0)
/// - `DYN_USE_KV_EVENTS` — Use KV events for cache tracking (default: true)
/// - `DYN_ROUTER_REPLICA_SYNC` — Enable replica synchronization (default: false)
/// - `DYN_ROUTER_TRACK_ACTIVE_BLOCKS` — Track active blocks (default: true)
/// - `DYN_ROUTER_TRACK_OUTPUT_BLOCKS` — Track output blocks during generation (default: false)
fn kv_router_config_from_env() -> KvRouterConfig {
let mut cfg = KvRouterConfig::default();
fn env_f64(key: &str) -> Option<f64> {
std::env::var(key).ok().and_then(|v| v.parse().ok())
} }
DynamoLlmResult::OK fn env_bool(key: &str) -> Option<bool> {
std::env::var(key)
.ok()
.and_then(|v| match v.to_lowercase().as_str() {
"true" | "1" | "yes" | "on" => Some(true),
"false" | "0" | "no" | "off" => Some(false),
_ => None,
})
}
if let Some(v) = env_f64("DYN_OVERLAP_SCORE_WEIGHT") {
cfg.overlap_score_weight = v;
}
if let Some(v) = env_f64("DYN_ROUTER_TEMPERATURE") {
cfg.router_temperature = v;
}
if let Some(v) = env_bool("DYN_USE_KV_EVENTS") {
cfg.use_kv_events = v;
}
if let Some(v) = env_bool("DYN_ROUTER_REPLICA_SYNC") {
cfg.router_replica_sync = v;
}
if let Some(v) = env_bool("DYN_ROUTER_TRACK_ACTIVE_BLOCKS") {
cfg.router_track_active_blocks = v;
}
if let Some(v) = env_bool("DYN_ROUTER_TRACK_OUTPUT_BLOCKS") {
cfg.router_track_output_blocks = v;
}
tracing::info!(
overlap_score_weight = cfg.overlap_score_weight,
router_temperature = cfg.router_temperature,
use_kv_events = cfg.use_kv_events,
router_replica_sync = cfg.router_replica_sync,
router_track_active_blocks = cfg.router_track_active_blocks,
router_track_output_blocks = cfg.router_track_output_blocks,
"KvRouterConfig initialized (DYN_* env overrides applied)"
);
cfg
} }
/// Query worker selection on an existing pipeline and return: /// Create router handles for query-only routing
/// - `decode_worker_id_out` (`i64`): The decode worker ID (primary worker)
/// - `prefill_worker_id_out` (`i64`): The prefill worker ID (-1 if not in disaggregated mode)
/// - `token_ids_out` (heap-allocated `*mut u32`; caller must free via
/// `dynamo_free_worker_selection_result`)
/// - `token_count_out` (`usize`)
/// - `annotated_request_json_out` (`*mut c_char` to a NUL-terminated C string;
/// caller frees via the same free function)
/// ///
/// # Safety /// This function waits for at least one decode worker to be discovered before returning.
/// - `pipeline` /// It auto-detects disaggregated mode by checking if prefill workers are present.
/// - Must be a **non-null** pointer previously returned by /// The KV cache block size is automatically fetched from the model card via discovery.
/// `dynamo_create_worker_selection_pipeline` and not yet passed to
/// `dynamo_destroy_worker_selection_pipeline`.
/// - Must remain valid for the entire duration of this call.
/// - **Do not** call this function concurrently on the same `pipeline` pointer
/// from multiple threads unless the surrounding code guarantees synchronization.
/// - `request_json_c_str`
/// - Must be a **non-null**, **NUL-terminated** C string containing **valid UTF-8**.
/// - The JSON must represent a valid `NvCreateChatCompletionRequest`; otherwise this
/// function returns `DynamoLlmResult::ERR`.
/// - Must remain valid for the duration of this call.
/// - Output pointers:
/// - `decode_worker_id_out`, `prefill_worker_id_out`, `token_ids_out`, `token_count_out`,
/// and `annotated_request_json_out` must each be **non-null** and point to
/// writable memory for their respective types. On success, this function
/// writes to all five outputs exactly once.
/// - On **error**, outputs are left unmodified.
/// - Ownership & deallocation:
/// - On success, if there are zero tokens, `*token_ids_out` may be set to `NULL`
/// and `*token_count_out` set to `0`.
/// - If non-null, the buffer written to `*token_ids_out` is allocated with the
/// Rust global allocator and **must** be freed by calling
/// `dynamo_free_worker_selection_result` with the same `token_count_out` value.
/// - The pointer written to `*annotated_request_json_out` is a `CString` allocated
/// by Rust and **must** be freed by calling `dynamo_free_worker_selection_result`.
/// - **Do not** free these with `free(3)` or any other allocator; doing so is
/// undefined behavior.
/// - Blocking & context:
/// - This function may **block** internally while it performs async work; do not
/// call it from contexts that forbid blocking (e.g., signal handlers).
/// - Process/ABI assumptions:
/// - The caller and callee must run in the same process and use the same Rust
/// global allocator for the paired allocation/free described above.
/// - This function is not signal-safe.
/// ///
/// # Errors /// # Arguments
/// Returns `DynamoLlmResult::ERR` if any precondition fails (null/invalid pointers, /// - `namespace`: Namespace for the model
/// malformed UTF-8/JSON, pipeline errors, allocation failures, etc.). On error, no /// - `component`: Component name (defaults to "backend" if NULL or empty)
/// output pointer is written. /// - `enforce_disagg`: If true, disaggregated mode is required (fails if no prefill workers found)
/// - `out_handle`: Output handle
/// ///
/// # Output values /// # Safety
/// - `decode_worker_id_out`: The decode worker ID (primary worker in aggregated mode) /// - All string parameters must be valid null-terminated C strings
/// - `prefill_worker_id_out`: The prefill worker ID (only set in disaggregated mode, -1 if not present) /// - The returned handle must be freed with `destroy`
/// - `token_ids_out`, `token_count_out`: Token IDs and count
/// - `annotated_request_json_out`: The annotated request JSON
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_query_worker_selection_and_annotate( pub unsafe extern "C" fn create_routers(
pipeline: *mut WorkerSelectionPipeline, namespace: *const c_char,
request_json_c_str: *const c_char, component: *const c_char,
decode_worker_id_out: *mut i64, enforce_disagg: bool,
prefill_worker_id_out: *mut i64, out_handle: *mut RouterHandlesPtr,
token_ids_out: *mut *mut u32, ) -> QueryRouterResult {
token_count_out: *mut usize, if namespace.is_null() || out_handle.is_null() {
annotated_request_json_out: *mut *mut c_char, return QueryRouterResult::ErrInvalidParam;
) -> DynamoLlmResult {
if pipeline.is_null() {
tracing::error!("Pipeline pointer is null");
return DynamoLlmResult::ERR;
}
if decode_worker_id_out.is_null()
|| prefill_worker_id_out.is_null()
|| token_ids_out.is_null()
|| token_count_out.is_null()
|| annotated_request_json_out.is_null()
{
tracing::error!("One or more output pointers are null");
return DynamoLlmResult::ERR;
} }
let req_str = match unsafe { CStr::from_ptr(request_json_c_str) }.to_str() { let namespace_str = match unsafe { CStr::from_ptr(namespace) }.to_str() {
Ok(s) => s, Ok(s) => s.to_owned(),
Err(e) => { Err(_) => return QueryRouterResult::ErrInvalidParam,
tracing::error!(error = ?e, "bad request json");
return DynamoLlmResult::ERR;
}
};
let request: NvCreateChatCompletionRequest = match serde_json::from_str(req_str) {
Ok(r) => r,
Err(e) => {
tracing::error!(error = ?e, "parse request failed");
return DynamoLlmResult::ERR;
}
};
let pl = unsafe { &*pipeline };
let fut = async { query_worker_selection_and_annotate(&pl.engine, request).await };
let (result, annotated_req) = match pl.wk.runtime().secondary().block_on(fut) {
Ok(v) => v,
Err(e) => {
tracing::error!(error = ?e, "query_worker_selection_and_annotate failed");
return DynamoLlmResult::ERR;
}
}; };
let tokens_ptr = if result.tokens.is_empty() { let component_str = if component.is_null() {
std::ptr::null_mut() "backend".to_string()
} else { } else {
let len = result.tokens.len(); match unsafe { CStr::from_ptr(component) }.to_str() {
let layout = std::alloc::Layout::array::<u32>(len).unwrap(); Ok(s) if !s.is_empty() => s.to_owned(),
let ptr = unsafe { std::alloc::alloc(layout) as *mut u32 }; _ => "backend".to_string(),
if ptr.is_null() {
tracing::error!("alloc tokens failed");
return DynamoLlmResult::ERR;
}
unsafe {
std::ptr::copy_nonoverlapping(result.tokens.as_ptr(), ptr, len);
} }
ptr
}; };
let annotated_json = match serde_json::to_string(&annotated_req) { // Create the runtime once - it will be stored in RouterHandles and reused
Ok(s) => s, let runtime = match Runtime::from_settings() {
Ok(rt) => rt,
Err(e) => { Err(e) => {
if !tokens_ptr.is_null() { tracing::error!(error = ?e, "Failed to create runtime");
let layout = std::alloc::Layout::array::<u32>(result.tokens.len()).unwrap(); return QueryRouterResult::ErrInitFailed;
unsafe {
std::alloc::dealloc(tokens_ptr as *mut u8, layout);
}
tracing::error!(error = ?e, "serialize annotated request failed");
}
return DynamoLlmResult::ERR;
} }
}; };
let cjson = match std::ffi::CString::new(annotated_json) {
Ok(c) => c, // Clone for use inside the async block (the original will be moved into handles)
Err(e) => { let runtime_for_async = runtime.clone();
tracing::error!(error = ?e, "CString::new for annotated JSON failed");
if !tokens_ptr.is_null() { let result = runtime_for_async.secondary().block_on(async {
let layout = std::alloc::Layout::array::<u32>(result.tokens.len()).unwrap(); let drt = match DistributedRuntime::from_settings(runtime_for_async.clone()).await {
unsafe { Ok(drt) => drt,
std::alloc::dealloc(tokens_ptr as *mut u8, layout); Err(e) => {
} tracing::error!(error = ?e, "Failed to create distributed runtime");
return Err(QueryRouterResult::ErrInitFailed);
} }
return DynamoLlmResult::ERR; };
// Wait for at least one worker to be discovered before proceeding
// This ensures the decode router can be created successfully
let instance_count = wait_for_discovery_sync(&drt).await;
if instance_count == 0 {
tracing::error!(
"Discovery sync failed: no worker instances found. Is the backend running?"
);
return Err(QueryRouterResult::ErrInitFailed);
} }
}; tracing::info!(
unsafe { "Discovery sync complete, {} worker(s) found",
*decode_worker_id_out = result.decode_worker_id.unwrap_or(0); instance_count
*prefill_worker_id_out = result.prefill_worker_id.unwrap_or(-1); );
*token_ids_out = tokens_ptr;
*token_count_out = result.tokens.len();
*annotated_request_json_out = cjson.into_raw();
}
DynamoLlmResult::OK
}
/// Destroy a previously created pipeline. let kv_router_config = kv_router_config_from_env();
///
/// # Safety
/// - `pipeline`
/// - **Must** be a non-null pointer that was **originally returned by**
/// `dynamo_create_worker_selection_pipeline` (i.e., obtained via
/// `Box::into_raw` on a `WorkerSelectionPipeline`).
/// - **Must not** have been passed to this function (or otherwise freed)
/// before. Passing the same pointer twice is a **double free** and is
/// undefined behavior.
/// - **Must not** be used by any other thread while this function runs.
/// Ensure no concurrent calls are in flight that read or write through
/// this handle (e.g., `dynamo_query_worker_selection_and_annotate`).
/// - After a successful call, the pointer is **invalid** and must not be
/// dereferenced or used again in any way.
/// - Allocator/ABI
/// - The caller and callee must be in the same process and share the same
/// allocator; this function reclaims the allocation that was created by
/// Rust for the handle.
/// - Lifetime/FFI
/// - Do not call from contexts that forbid blocking or running destructors
/// (e.g., signal handlers).
///
/// # Errors
/// - Returns `DynamoLlmResult::ERR` if `pipeline` is null.
/// - On `OK`, ownership of `pipeline` is taken and the underlying resources
/// are dropped; using the pointer after return is undefined behavior.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_destroy_worker_selection_pipeline(
pipeline: *mut WorkerSelectionPipeline,
) -> DynamoLlmResult {
if pipeline.is_null() {
tracing::error!("Pipeline pointer is null");
return DynamoLlmResult::ERR;
}
let _boxed: Box<WorkerSelectionPipeline> = unsafe { Box::from_raw(pipeline) };
DynamoLlmResult::OK
}
/// Free buffers allocated by `dynamo_query_worker_selection_and_annotate`. // Get component and endpoint
/// let component_handle = match drt.namespace(&namespace_str) {
/// # Safety Ok(ns) => match ns.component(&component_str) {
/// - `token_ids` and `annotated_request_json` **must come from this library**: Ok(c) => c,
/// - `token_ids` must be the exact pointer previously returned by Err(e) => {
/// `dynamo_query_worker_selection_and_annotate` for the tokens buffer, tracing::error!(error = ?e, "Failed to get component");
/// allocated with Rust’s global allocator in this process. return Err(QueryRouterResult::ErrInitFailed);
/// - `annotated_request_json` must be the exact pointer previously returned by }
/// `CString::into_raw` inside `dynamo_query_worker_selection_and_annotate`.
/// - **Call at most once** per pointer. Passing the same pointer again is a
/// double-free and is undefined behavior.
/// - Pointer/length invariants:
/// - If `token_ids` is non-null, `token_count` **must** be the exact length
/// originally returned. Mismatched lengths cause invalid deallocation.
/// - If `token_ids` is null, `token_count` should be `0`.
/// - Passing a non-null `token_ids` with `token_count == 0` will leak in this
/// implementation (we only dealloc when `token_count > 0`).
/// - After return, the pointers are **invalid** and must not be used again.
/// - The caller and callee must be in the same process and share the same
/// allocator/ABI (these deallocations use Rust’s global allocator).
/// - Ensure no other threads are concurrently reading/writing these buffers when
/// freeing them.
/// - Do not call from contexts that forbid running destructors (e.g., signal handlers).
///
/// Returns `DynamoLlmResult::OK` on success.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_free_worker_selection_result(
token_ids: *mut u32,
token_count: usize,
annotated_request_json: *mut c_char,
) -> DynamoLlmResult {
if token_count > 0 {
match std::alloc::Layout::array::<u32>(token_count) {
Ok(layout) if !token_ids.is_null() => unsafe {
std::alloc::dealloc(token_ids as *mut u8, layout);
}, },
_ => {} Err(e) => {
} tracing::error!(error = ?e, "Failed to get namespace");
} return Err(QueryRouterResult::ErrInitFailed);
if !annotated_request_json.is_null() { }
unsafe { };
drop(std::ffi::CString::from_raw(annotated_request_json)); let endpoint = component_handle.endpoint("generate");
}
}
DynamoLlmResult::OK
}
/// Default timeout for GAIE bookkeeping operations (30 seconds) let model_manager = Arc::new(ModelManager::new());
const GAIE_BOOKKEEPING_TIMEOUT_SECS: u64 = 30;
/// Helper to validate pipeline pointer and extract request_id from C string.
/// Returns `Err(DynamoLlmResult::ERR)` on validation failure, `Ok((pipeline_ref, request_id))` on success.
unsafe fn validate_pipeline_and_request_id(
pipeline: *mut WorkerSelectionPipeline,
request_id_c_str: *const c_char,
operation: &str,
) -> Result<(&'static WorkerSelectionPipeline, String), DynamoLlmResult> {
if pipeline.is_null() {
tracing::error!("[GAIE] {} failed: pipeline pointer is null", operation);
return Err(DynamoLlmResult::ERR);
}
let request_id = match unsafe { CStr::from_ptr(request_id_c_str) }.to_str() { // Fetch model card via discovery and create preprocessor + get block_size
Ok(s) => s.to_owned(), let (preprocessor, block_size, model_name) =
Err(e) => { match fetch_preprocessor_from_discovery(&drt, &namespace_str).await {
tracing::error!(error = ?e, "[GAIE] {} failed: bad request_id", operation); Ok((prep, bs, name)) => {
return Err(DynamoLlmResult::ERR); tracing::info!(
} kv_cache_block_size = bs,
}; "Preprocessor created from discovery"
);
(Some(prep), bs, name)
}
Err(e) => {
tracing::error!(
error = %e,
"Failed to fetch model card from discovery - cannot determine block_size"
);
return Err(QueryRouterResult::ErrInitFailed);
}
};
// SAFETY: Caller guarantees pipeline is valid for the duration of the call // Create decode router
let pl: &'static WorkerSelectionPipeline = unsafe { &*pipeline }; let decode_router = match model_manager
Ok((pl, request_id)) .kv_chooser_for(
} &endpoint,
block_size,
Some(kv_router_config),
WORKER_TYPE_DECODE,
)
.await
{
Ok(r) => r,
Err(e) => {
tracing::error!(error = ?e, "Failed to create decode router");
return Err(QueryRouterResult::ErrInitFailed);
}
};
/// Helper to run an async bookkeeping operation with timeout. // Create PrefillRouter based on one-time discovery of prefill workers
/// Returns `OK` on success or timeout, `ERR` only on validation failures (handled by caller). // Auto-detects disaggregated mode by checking if prefill workers are present
fn run_bookkeeping_with_timeout<F, Fut>( // The prefill workers have to be created before the epp is created.
pl: &WorkerSelectionPipeline, // Given that we wait first for the decode worker to show up it is reasonable to assume the prefill will be up as well.
operation: &'static str, let prefill_router = match find_prefill_endpoint(&drt, &namespace_str).await {
request_id: &str, Some(prefill_endpoint) => {
f: F, tracing::info!("Prefill worker found, running in disaggregated mode");
) -> DynamoLlmResult let mut prefill_config = kv_router_config;
where prefill_config.router_track_active_blocks = false;
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = ()>, // Create immediately-resolved channel to activate router
{ let (tx, rx) = tokio::sync::oneshot::channel();
use std::time::Duration; let _ = tx.send(prefill_endpoint);
let timeout_duration = Duration::from_secs(GAIE_BOOKKEEPING_TIMEOUT_SECS); PrefillRouter::new(
let fut = f(); rx,
model_manager.clone(),
let result = pl RouterMode::KV,
.wk block_size,
.runtime() Some(prefill_config),
.secondary() enforce_disagg,
.block_on(async { tokio::time::timeout(timeout_duration, fut).await }); model_name.clone(),
)
}
None if enforce_disagg => {
tracing::error!("Prefill workers required (enforce_disagg=true) but none found");
return Err(QueryRouterResult::ErrDisaggEnforced);
}
None => {
tracing::info!("No prefill workers found, running in aggregated mode");
PrefillRouter::disabled(model_manager.clone(), RouterMode::KV, enforce_disagg)
}
};
Ok((
prefill_router,
decode_router,
model_manager,
namespace_str,
preprocessor,
))
});
match result { match result {
Ok(()) => DynamoLlmResult::OK, Ok((prefill_router, decode_router, model_manager, namespace_str, preprocessor)) => {
Err(_elapsed) => { let handles = RouterHandles {
tracing::warn!( prefill_router,
request_id = %request_id, decode_router,
timeout_secs = GAIE_BOOKKEEPING_TIMEOUT_SECS, model_manager,
"[GAIE] {} timed out", namespace: namespace_str,
operation runtime, // Store the runtime for reuse
); preprocessor,
// Return OK to avoid blocking the caller - the operation may still complete };
DynamoLlmResult::OK unsafe { *out_handle = Box::into_raw(Box::new(handles)) };
QueryRouterResult::Ok
} }
Err(code) => code,
} }
} }
/// Router bookkeeping functions for GAIE integration
/// Add a request to the router's bookkeeping after worker selection. /// Add a request to the router's bookkeeping after worker selection.
/// Call this from GAIE Stage 1 after `dynamo_query_worker_selection_and_annotate`.
/// ///
/// This function computes the overlap_blocks internally by querying the indexer, /// Register the request with the KvRouter's scheduler for tracking active blocks
/// so the caller doesn't need to provide it. /// and managing prefill/decode lifecycle. Call this after `query_decode` returns
/// worker IDs and before sending the request to the worker.
/// ///
/// # Safety /// # Safety
/// - `pipeline` must be a valid, non-null pointer from `dynamo_create_worker_selection_pipeline` /// - `handle` must be a valid RouterHandles handle
/// - `request_id_c_str` must be a valid NUL-terminated UTF-8 C string /// - `request_id` must be a valid null-terminated C string
/// - `token_ids` must point to at least `token_count` valid u32 values /// - `token_ids` must point to at least `token_count` valid u32 values
/// - Must not be called concurrently on the same pipeline without synchronization
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_router_add_request( pub unsafe extern "C" fn add_request(
pipeline: *mut WorkerSelectionPipeline, handle: RouterHandlesPtr,
request_id_c_str: *const c_char, request_id: *const c_char,
token_ids: *const u32, token_ids: *const u32,
token_count: usize, token_count: usize,
worker_id: u64, worker_id: u64,
dp_rank: u32, dp_rank: u32,
) -> DynamoLlmResult { ) -> QueryRouterResult {
let (pl, request_id) = match unsafe { if handle.is_null() || request_id.is_null() {
validate_pipeline_and_request_id(pipeline, request_id_c_str, "add_request") return QueryRouterResult::ErrInvalidParam;
} { }
Ok(v) => v,
Err(e) => return e,
};
let Some(ref kv_router) = pl.kv_router else { let handles = unsafe { &*handle };
tracing::debug!( let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
"[GAIE] KV router not available (router_mode is not KV), skipping add_request (no-op)" Ok(s) => s.to_owned(),
); Err(_) => return QueryRouterResult::ErrInvalidParam,
return DynamoLlmResult::OK;
}; };
// Log after kv_router check to reduce noise
tracing::debug!(
request_id = %request_id,
worker_id = worker_id,
dp_rank = dp_rank,
token_count = token_count,
"[GAIE] dynamo_router_add_request processing"
);
let tokens: Vec<u32> = if token_count > 0 && !token_ids.is_null() { let tokens: Vec<u32> = if token_count > 0 && !token_ids.is_null() {
unsafe { std::slice::from_raw_parts(token_ids, token_count) }.to_vec() unsafe { std::slice::from_raw_parts(token_ids, token_count) }.to_vec()
} else { } else {
Vec::new() Vec::new()
}; };
let kv_router = kv_router.clone(); let decode_router = handles.decode_router.clone();
let request_id_clone = request_id.clone();
run_bookkeeping_with_timeout(pl, "add_request", &request_id, || async move { let result = handles.runtime.secondary().block_on(async {
let worker = dynamo_llm::kv_router::protocols::WorkerWithDpRank::new(worker_id, dp_rank); let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
// Compute overlap_blocks using the public method tokio::time::timeout(timeout_duration, async {
let overlap_blocks = match kv_router.get_overlap_blocks(&tokens, worker).await { let worker = WorkerWithDpRank::new(worker_id, dp_rank);
Ok(overlap) => overlap,
Err(e) => {
tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
0
}
};
kv_router // Compute overlap_blocks using the public method
.add_request( let overlap_blocks = match decode_router.get_overlap_blocks(&tokens, worker).await {
request_id_clone.clone(), Ok(overlap) => overlap,
&tokens, Err(e) => {
overlap_blocks, tracing::warn!(error = ?e, "Failed to compute overlap, using 0");
None, 0
worker, }
None, // lora_name not exposed in C API yet };
None, // router_config_override not exposed in C API yet
) decode_router
.await; .add_request(
request_id_str.clone(),
tracing::debug!( &tokens,
request_id = %request_id_clone, overlap_blocks,
worker_id = worker_id, None,
dp_rank = dp_rank, worker,
overlap_blocks = overlap_blocks, None, // lora_name
token_count = tokens.len(), None, // router_config_override
"[GAIE] dynamo_router_add_request completed - request registered in router bookkeeping" )
); .await;
})
tracing::debug!(
request_id = %request_id_str,
worker_id = worker_id,
dp_rank = dp_rank,
overlap_blocks = overlap_blocks,
token_count = tokens.len(),
"add_request completed"
);
})
.await
});
match result {
Ok(()) => QueryRouterResult::Ok,
Err(_elapsed) => {
tracing::warn!(
request_id = %request_id_str,
timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
"add_request timed out"
);
QueryRouterResult::ErrTimeout
}
}
} }
/// Mark prefill as completed for a request. /// Mark prefill as completed for a request.
/// Call this from the EPP extension point when the first token is generated. ///
/// Call when the first token is generated to release prefill tokens from decode worker's load
/// ///
/// # Safety /// # Safety
/// - `pipeline` must be a valid, non-null pointer from `dynamo_create_worker_selection_pipeline` /// - `handle` must be a valid RouterHandles handle
/// - `request_id_c_str` must be a valid NUL-terminated UTF-8 C string /// - `request_id` must be a valid null-terminated C string
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_router_mark_prefill_complete( pub unsafe extern "C" fn mark_prefill_complete(
pipeline: *mut WorkerSelectionPipeline, handle: RouterHandlesPtr,
request_id_c_str: *const c_char, request_id: *const c_char,
) -> DynamoLlmResult { ) -> QueryRouterResult {
let (pl, request_id) = match unsafe { if handle.is_null() || request_id.is_null() {
validate_pipeline_and_request_id(pipeline, request_id_c_str, "mark_prefill_complete") return QueryRouterResult::ErrInvalidParam;
} { }
Ok(v) => v,
Err(e) => return e,
};
let Some(ref kv_router) = pl.kv_router else { let handles = unsafe { &*handle };
tracing::debug!( let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
"[GAIE] KV router not available (router_mode is not KV), skipping mark_prefill_complete (no-op)" Ok(s) => s.to_owned(),
); Err(_) => return QueryRouterResult::ErrInvalidParam,
return DynamoLlmResult::OK;
}; };
// Log after kv_router check to reduce noise let decode_router = handles.decode_router.clone();
tracing::debug!(
request_id = %request_id,
"[GAIE] dynamo_router_mark_prefill_complete processing"
);
let kv_router = kv_router.clone(); let result = handles.runtime.secondary().block_on(async {
let request_id_clone = request_id.clone(); let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
tokio::time::timeout(timeout_duration, async {
if let Err(e) = decode_router.mark_prefill_completed(&request_id_str).await {
tracing::warn!(
request_id = %request_id_str,
error = %e,
"Failed to mark prefill complete"
);
} else {
tracing::debug!(
request_id = %request_id_str,
"mark_prefill_complete completed"
);
}
})
.await
});
run_bookkeeping_with_timeout(pl, "mark_prefill_complete", &request_id, || async move { match result {
if let Err(e) = kv_router.mark_prefill_completed(&request_id_clone).await { Ok(()) => QueryRouterResult::Ok,
Err(_elapsed) => {
tracing::warn!( tracing::warn!(
"Failed to mark prefill completed for {}: {}", request_id = %request_id_str,
request_id_clone, timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
e "mark_prefill_complete timed out"
);
} else {
tracing::debug!(
request_id = %request_id_clone,
"[GAIE] dynamo_router_mark_prefill_complete completed - prefill tokens released"
); );
QueryRouterResult::ErrTimeout
} }
}) }
} }
/// Free a request from the router's bookkeeping. /// Free a request from the router's bookkeeping.
/// Call this from GAIE hook when the stream is closed (completed or cancelled). ///
/// Call this when the stream is closed (completed or cancelled) to release all resources.
/// ///
/// # Safety /// # Safety
/// - `pipeline` must be a valid, non-null pointer from `dynamo_create_worker_selection_pipeline` /// - `handle` must be a valid RouterHandles handle
/// - `request_id_c_str` must be a valid NUL-terminated UTF-8 C string /// - `request_id` must be a valid null-terminated C string
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
pub unsafe extern "C" fn dynamo_router_free_request( pub unsafe extern "C" fn free_request(
pipeline: *mut WorkerSelectionPipeline, handle: RouterHandlesPtr,
request_id_c_str: *const c_char, request_id: *const c_char,
) -> DynamoLlmResult { ) -> QueryRouterResult {
let (pl, request_id) = match unsafe { if handle.is_null() || request_id.is_null() {
validate_pipeline_and_request_id(pipeline, request_id_c_str, "free_request") return QueryRouterResult::ErrInvalidParam;
} { }
Ok(v) => v,
Err(e) => return e,
};
let Some(ref kv_router) = pl.kv_router else { let handles = unsafe { &*handle };
tracing::debug!( let request_id_str = match unsafe { CStr::from_ptr(request_id) }.to_str() {
"[GAIE] KV router not available (router_mode is not KV), skipping free_request (no-op)" Ok(s) => s.to_owned(),
); Err(_) => return QueryRouterResult::ErrInvalidParam,
return DynamoLlmResult::OK;
}; };
// Log after kv_router check to reduce noise let decode_router = handles.decode_router.clone();
tracing::debug!(
request_id = %request_id,
"[GAIE] dynamo_router_free_request processing"
);
let kv_router = kv_router.clone(); let result = handles.runtime.secondary().block_on(async {
let request_id_clone = request_id.clone(); let timeout_duration = Duration::from_secs(BOOKKEEPING_TIMEOUT_SEC);
run_bookkeeping_with_timeout(pl, "free_request", &request_id, || async move { tokio::time::timeout(timeout_duration, async {
if let Err(e) = kv_router.free(&request_id_clone).await { if let Err(e) = decode_router.free(&request_id_str).await {
tracing::warn!("Failed to free request {}: {}", request_id_clone, e); tracing::warn!(
} else { request_id = %request_id_str,
tracing::debug!( error = %e,
request_id = %request_id_clone, "Failed to free request"
"[GAIE] dynamo_router_free_request completed - request removed from bookkeeping"
);
}
})
}
/// Result of worker selection extraction
#[derive(Debug, Clone, Default)]
pub struct WorkerSelectionResult {
/// Decode worker ID (primary worker for aggregated, decode-only for disaggregated)
pub decode_worker_id: Option<i64>,
/// Prefill worker ID (only present in disaggregated mode)
pub prefill_worker_id: Option<i64>,
/// Token IDs from tokenization
pub tokens: Vec<u32>,
}
/// Helper function to extract worker selection information from the annotation stream
///
/// The response format (from disaggregated_params in nvext):
/// - worker_id: {"prefill_worker_id": 123, "decode_worker_id": 456}
/// - token_ids: [1, 2, 3, ...]
pub async fn extract_worker_selection_from_stream(
mut stream: Pin<Box<dyn AsyncEngineStream<Annotated<NvCreateChatCompletionStreamResponse>>>>,
) -> anyhow::Result<WorkerSelectionResult> {
use dynamo_llm::protocols::openai::nvext::WorkerIdInfo;
use futures::StreamExt;
let mut result = WorkerSelectionResult::default();
while let Some(response) = stream.next().await {
// Check for data in nvext (worker_id and token_ids are direct fields)
// nvext is a serde_json::Value, so we access it as a JSON object
if let Some(data) = &response.data
&& let Some(nvext) = &data.nvext
{
// Extract worker_id
if let Some(worker_id_value) = nvext.get("worker_id")
&& let Ok(worker_info) =
serde_json::from_value::<WorkerIdInfo>(worker_id_value.clone())
{
result.decode_worker_id = worker_info.decode_worker_id.map(|id| id as i64);
result.prefill_worker_id = worker_info.prefill_worker_id.map(|id| id as i64);
tracing::debug!(
decode_worker_id = ?result.decode_worker_id,
prefill_worker_id = ?result.prefill_worker_id,
"Parsed worker_id from nvext"
); );
} } else {
// Extract token_ids
if let Some(token_ids_value) = nvext.get("token_ids")
&& let Ok(parsed_tokens) =
serde_json::from_value::<Vec<u32>>(token_ids_value.clone())
{
result.tokens = parsed_tokens;
tracing::debug!( tracing::debug!(
"Successfully parsed {} tokens from nvext", request_id = %request_id_str,
result.tokens.len() "free_request completed"
); );
} }
})
.await
});
match result {
Ok(()) => QueryRouterResult::Ok,
Err(_elapsed) => {
tracing::warn!(
request_id = %request_id_str,
timeout_secs = BOOKKEEPING_TIMEOUT_SEC,
"free_request timed out"
);
QueryRouterResult::ErrTimeout
} }
} }
}
tracing::info!( /// Destroy router handles
decode_worker_id = ?result.decode_worker_id, ///
prefill_worker_id = ?result.prefill_worker_id, /// # Safety
token_count = result.tokens.len(), /// - `handle` must be a valid RouterHandles handle or null
"Worker selection extraction complete" /// - After this call, `handle` must not be used
); #[unsafe(no_mangle)]
Ok(result) pub unsafe extern "C" fn destroy(handle: RouterHandlesPtr) {
if !handle.is_null() {
drop(unsafe { Box::from_raw(handle) });
}
} }
/// Utility function to add the "query_instance_id" annotation to an OpenAI request /// Route a chat completion request in a single call.
/// ///
/// This function modifies the request to include the annotation that signals the KV router /// This is the main function for EPP to route a `/v1/chat/completions` request.
/// to return worker selection information (worker_fid and token_data) instead of /// It combines tokenization and worker selection in one call:
/// performing actual inference. /// 1. Applies the chat template to the request JSON
/// 2. Tokenizes the formatted prompt
/// 3. Queries the prefill router (if disaggregated mode)
/// 4. Queries the decode router
/// 5. Returns worker IDs and token_ids
/// ///
/// # Parameters /// After this call, EPP should:
/// - `request`: Mutable reference to the OpenAI chat completion request /// - Call `add_request()` to register the request for bookkeeping
/// - Set worker ID headers and forward to backend
/// - Call `mark_prefill_complete()` on first token
/// - Call `free_request()` when the stream ends
/// - Call `free_routing_result()` to free the result
/// ///
/// # Returns /// # Safety
/// The same request with the "query_instance_id" annotation added /// - `handle` must be a valid RouterHandles handle
pub fn add_query_instance_id( /// - `request_json` must be a valid null-terminated C string containing JSON
request: &mut NvCreateChatCompletionRequest, /// - `out_result` must be a valid pointer
) -> &mut NvCreateChatCompletionRequest { #[unsafe(no_mangle)]
// Send empty value - router treats empty as aggregated / aggregated worker selection pub unsafe extern "C" fn route_request(
set_kv_annotation(request, "query_instance_id".to_string(), "") handle: RouterHandlesPtr,
} request_json: *const c_char,
out_result: *mut CRoutingResult,
) -> QueryRouterResult {
if handle.is_null() || request_json.is_null() || out_result.is_null() {
return QueryRouterResult::ErrInvalidParam;
}
// Note: set_worker_ids_for_stage2 and set_token_data_for_stage2 have been removed. let handles = unsafe { &*handle };
// The EPP now handles routing configuration via HTTP headers:
// - `x-worker-instance-id`: decode worker ID // Get preprocessor
// - `x-prefill-instance-id`: prefill worker ID (disaggregated mode only) let preprocessor = match &handles.preprocessor {
// - `x-enable-local-updates`: set to "false" to disable router bookkeeping Some(p) => p,
// None => {
// Body modifications are NOT sent to the inference engine (only headers are forwarded), tracing::error!("Preprocessor not available");
// so these functions were ineffective. return QueryRouterResult::ErrInitFailed;
}
/// Ensure `nvext` exists and return a mutable slice of annotations. };
fn ensure_annotations(request: &mut NvCreateChatCompletionRequest) -> &mut Vec<String> {
let nvext = request.nvext.get_or_insert_with(|| { let json_str = match unsafe { CStr::from_ptr(request_json) }.to_str() {
NvExt::builder() Ok(s) => s,
.build() Err(_) => return QueryRouterResult::ErrInvalidParam,
.expect("NvExt builder should not fail") };
// Parse JSON
let request: dynamo_llm::types::openai::chat_completions::NvCreateChatCompletionRequest =
match serde_json::from_str(json_str) {
Ok(req) => req,
Err(e) => {
tracing::error!(error = ?e, "Failed to parse request JSON");
return QueryRouterResult::ErrInvalidParam;
}
};
// Apply chat template
let formatted_prompt = match preprocessor.apply_template(&request) {
Ok(Some(prompt)) => prompt,
Ok(None) => String::new(),
Err(e) => {
tracing::error!(error = ?e, "Failed to apply chat template");
return QueryRouterResult::ErrQueryFailed;
}
};
// Tokenize
let encoding = match preprocessor.tokenize(&formatted_prompt) {
Ok(enc) => enc,
Err(e) => {
tracing::error!(error = ?e, "Failed to tokenize");
return QueryRouterResult::ErrQueryFailed;
}
};
let tokens = encoding.token_ids();
let token_count = tokens.len();
let is_disaggregated = handles.prefill_router.is_activated();
// Query workers
let result = handles.runtime.secondary().block_on(async {
let prefill_worker_id = if is_disaggregated {
handles
.query_prefill_worker(tokens, false, None, 0.0)
.await?
} else {
0
};
let (decode_worker, _overlap_blocks) = handles
.query_decode_worker(tokens, is_disaggregated)
.await?;
tracing::info!(
is_disaggregated = is_disaggregated,
prefill_worker_id = prefill_worker_id,
decode_worker_id = decode_worker.worker_id,
decode_dp_rank = decode_worker.dp_rank,
token_count = token_count,
"Routed chat request"
);
Ok((prefill_worker_id, decode_worker))
}); });
nvext.annotations.get_or_insert_with(Vec::new)
}
/// Set a `key:value` annotation. match result {
fn set_kv_annotation( Ok((prefill_worker_id, decode_worker)) => {
request: &mut NvCreateChatCompletionRequest, // Allocate and copy token IDs for caller (needed for add_request bookkeeping)
key: String, // <- owned, only one borrowed param remains let token_vec: Vec<u32> = tokens.to_vec();
value: impl Into<String>, let mut tokens_boxed = token_vec.into_boxed_slice();
) -> &mut NvCreateChatCompletionRequest { let token_ptr = tokens_boxed.as_mut_ptr();
let prefix = format!("{}:", key); std::mem::forget(tokens_boxed);
let kv = format!("{}{}", prefix, value.into());
let annotations = ensure_annotations(request); unsafe {
annotations.retain(|a| !a.starts_with(&prefix)); *out_result = CRoutingResult {
annotations.push(kv); is_disaggregated,
request prefill_worker_id,
decode_worker_id: decode_worker.worker_id,
token_ids: token_ptr,
token_count,
};
}
QueryRouterResult::Ok
}
Err(code) => code,
}
} }
/// Wrapper function that queries worker selection for GAIE Stage 1 /// Free a routing result.
///
/// This function performs the complete GAIE Stage 1 flow:
/// 1. Clones the original request and adds "query_instance_id:" (empty) annotation
/// 2. Calls engine.generate() with the modified request
/// 3. Extracts worker_id info and tokens from the response stream
/// 4. Returns WorkerSelectionResult and the original request
///
/// Note: The EPP (caller) is responsible for setting HTTP headers for Stage 2:
/// - `x-worker-instance-id`: decode worker ID
/// - `x-prefill-instance-id`: prefill worker ID (disaggregated mode only)
/// - `x-enable-local-updates`: "false" to disable router bookkeeping
///
/// Body modifications are NOT forwarded to the inference engine, so this function
/// does not modify the request body.
///
/// # Parameters
/// - `engine`: The worker selection pipeline engine
/// - `original_request`: The original OpenAI request to process
/// ///
/// # Returns /// # Safety
/// A tuple containing (WorkerSelectionResult, original_request) /// - `result` must be a valid pointer to a CRoutingResult previously returned by route functions
pub async fn query_worker_selection_and_annotate( #[unsafe(no_mangle)]
engine: &ServiceEngine< pub unsafe extern "C" fn free_routing_result(result: *mut CRoutingResult) {
SingleIn<NvCreateChatCompletionRequest>, if result.is_null() {
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, return;
>, }
original_request: NvCreateChatCompletionRequest,
) -> anyhow::Result<(WorkerSelectionResult, NvCreateChatCompletionRequest)> { let res = unsafe { &mut *result };
// GAIE Stage 1: Query for worker selection
let mut query_request = original_request.clone(); // Free token IDs
add_query_instance_id(&mut query_request); if !res.token_ids.is_null() && res.token_count > 0 {
let single_in = SingleIn::new(query_request); drop(unsafe {
let response_stream = engine.generate(single_in).await?; Box::from_raw(std::slice::from_raw_parts_mut(
let result = extract_worker_selection_from_stream(response_stream).await?; res.token_ids,
res.token_count,
// Return the original request unchanged. ))
// The EPP sets routing headers (worker IDs, enable_local_updates) which the });
// Dynamo frontend reads via apply_header_routing_overrides(). res.token_ids = ptr::null_mut();
Ok((result, original_request)) res.token_count = 0;
}
} }
/// Spawn a background task to watch for prefill models and activate prefill routers. /// Fetch model card via discovery and create preprocessor.
/// This is a lightweight watcher that only handles prefill model discovery. ///
fn spawn_prefill_watcher( /// This function:
drt: DistributedRuntime, /// 1. Lists all models via discovery
model_manager: Arc<ModelManager>, /// 2. Finds the first model in the target namespace (decode workers only)
target_namespace: String, /// 3. Downloads the model config (tokenizer files) if needed
) { /// 4. Creates an OpenAIPreprocessor from the model card
/// 5. Returns the preprocessor, the kv_cache_block_size, and model_name from the model card
async fn fetch_preprocessor_from_discovery(
drt: &DistributedRuntime,
target_namespace: &str,
) -> anyhow::Result<(Arc<OpenAIPreprocessor>, u32, String)> {
use dynamo_llm::model_card::ModelDeploymentCard; use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryQuery}; use dynamo_runtime::discovery::DiscoveryInstance;
use dynamo_runtime::protocols::EndpointId;
use futures::StreamExt;
tokio::spawn(async move {
let discovery = drt.discovery();
let mut stream = match discovery
.list_and_watch(DiscoveryQuery::AllModels, None)
.await
{
Ok(s) => s,
Err(e) => {
tracing::error!(error = %e, "Failed to start prefill discovery stream");
return;
}
};
while let Some(result) = stream.next().await { let discovery = drt.discovery();
let event = match result {
Ok(e) => e,
Err(e) => {
tracing::error!(error = %e, "Error in prefill discovery stream");
continue;
}
};
match event { // List all models
DiscoveryEvent::Added(instance) => { let instances = discovery.list(DiscoveryQuery::AllModels).await?;
let (endpoint_id, card) = match &instance {
DiscoveryInstance::Model {
namespace,
component,
endpoint,
..
} => {
// Filter by namespace
if namespace != &target_namespace {
continue;
}
let eid = EndpointId {
namespace: namespace.clone(),
component: component.clone(),
name: endpoint.clone(),
};
match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => (eid, card),
Err(_) => continue,
}
}
_ => continue,
};
// Only handle prefill models
if !card.model_type.supports_prefill() {
continue;
}
tracing::info!( // Find first model card in the target namespace (decode workers only)
model_name = card.name(), let mut model_card: Option<ModelDeploymentCard> = None;
"Prefill model discovered, activating prefill router"
);
// Get the endpoint and activate the prefill router for instance in instances {
if let Ok(ns) = drt.namespace(&endpoint_id.namespace) if let DiscoveryInstance::Model { namespace, .. } = &instance {
&& let Ok(comp) = ns.component(&endpoint_id.component) // Filter by namespace
if namespace != target_namespace {
continue;
}
match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => {
// Skip prefill-only workers, we want decode workers for routing
if card.model_type.supports_prefill()
&& !card.model_type.supports_chat()
&& !card.model_type.supports_completions()
{ {
let endpoint = comp.endpoint(&endpoint_id.name); continue;
if let Err(e) = model_manager.activate_prefill_router(card.name(), endpoint)
{
tracing::warn!(
model_name = card.name(),
error = %e,
"Failed to activate prefill router"
);
} else {
tracing::info!(
model_name = card.name(),
"Prefill router activated successfully"
);
}
} }
model_card = Some(card);
break;
} }
DiscoveryEvent::Removed(id) => { Err(e) => {
// Log removal for observability tracing::debug!(error = %e, "Failed to deserialize model card, skipping");
// Note: The PrefillRouter remains active - worker availability continue;
// is handled dynamically by the underlying Client's instance tracking
tracing::debug!(
instance_id = id.instance_id(),
"Prefill worker instance removed from discovery"
);
} }
} }
} }
});
}
/// Create a worker selection pipeline for OpenAI Chat Completion requests
///
/// This is a concrete implementation that works specifically with NvCreateChatCompletionRequest
/// and is designed for use with C bindings. Uses the "generate" endpoint by default.
///
/// # Parameters
/// - `namespace`: namespace name
/// - `component_name`: component name
/// - `model_name`: Name/slug of the model to load
/// - `router_mode`: How to route requests (KV, RoundRobin, etc.)
/// - `busy_threshold`: Optional threshold for busy worker detection
/// - `kv_router_config`: Optional KV router configuration (only used when router_mode is KV)
/// - `enforce_disagg`: If true, fail requests when disaggregated serving is unavailable
///
/// # Returns
/// A tuple of (engine, kv_router) where kv_router is Some when router_mode is KV
pub async fn create_worker_selection_pipeline_chat(
namespace: &str,
component_name: &str,
model_name: &str,
router_mode: RouterMode,
busy_threshold: Option<f64>,
kv_router_config: Option<KvRouterConfig>,
enforce_disagg: bool,
) -> anyhow::Result<(
ServiceEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>,
Option<Arc<dynamo_llm::kv_router::KvRouter>>,
)> {
use dynamo_llm::discovery::WORKER_TYPE_DECODE;
use dynamo_llm::kv_router::PrefillRouter;
// Use the global DRT singleton - initialize if not already done
// Check if already initialized (by dynamo_llm_init) to avoid redundant sync wait
let needs_sync = DRT.get().is_none();
let distributed_runtime = DRT
.get_or_try_init(async {
tracing::debug!("Initializing DistributedRuntime singleton (standalone mode)");
DistributedRuntime::from_settings(Runtime::from_settings()?).await
})
.await
.map_err(|e| anyhow::anyhow!("Failed to initialize DistributedRuntime: {}", e))?;
// Only wait for discovery sync if we just initialized the DRT
// (dynamo_llm_init already does this when it initializes)
// Note: This waits indefinitely - the K8s StartupProbe is the timeout mechanism.
if needs_sync {
wait_for_discovery_sync(distributed_runtime).await;
} }
let component = distributed_runtime let mut card = model_card.ok_or_else(|| {
.namespace(namespace)? anyhow::anyhow!(
.component(component_name)?; "No model found in namespace '{}' via discovery",
let endpoint = component.endpoint(GENERATE_ENDPOINT); target_namespace
let client = endpoint.client().await?; )
})?;
// Discover the model card by searching all instances with this model name
tracing::debug!("Looking for model: {}", model_name); let kv_cache_block_size = card.kv_cache_block_size;
tracing::debug!("Namespace: {}", namespace); let model_name = card.name().to_string();
tracing::info!(
let model_manager = Arc::new(ModelManager::new()); model_name = model_name,
let router_config = dynamo_llm::entrypoint::RouterConfig { kv_cache_block_size = kv_cache_block_size,
router_mode, "Found model card via discovery"
kv_router_config: kv_router_config.unwrap_or_default(),
load_threshold_config: dynamo_llm::discovery::LoadThresholdConfig {
active_decode_blocks_threshold: busy_threshold,
active_prefill_tokens_threshold: None,
active_prefill_tokens_threshold_frac: None,
},
enforce_disagg,
};
// Create metrics for migration tracking (not exposed via /metrics in C bindings)
let metrics = Arc::new(Metrics::new());
let watcher = ModelWatcher::new(
component.drt().clone(),
model_manager.clone(),
router_config,
0, // migration_limit - default to 0 for C bindings
None,
metrics.clone(),
); );
let cards = watcher
.cards_for_model(model_name, Some(namespace), false)
.await
.with_context(|| format!("Failed to discover model: {}", model_name))?;
tracing::debug!("Found {} cards for model {}", cards.len(), model_name); // Download config (tokenizer files) if not local
card.download_config().await?;
let card = cards.into_iter().next().ok_or_else(|| { // Create preprocessor
tracing::error!("No ModelDeploymentCard found for model: {}", model_name); let preprocessor = OpenAIPreprocessor::new(card)?;
anyhow::anyhow!("ModelDeploymentCard not found for model: {}", model_name) Ok((preprocessor, kv_cache_block_size, model_name))
})?; }
let chooser = if router_mode == RouterMode::KV { /// Find a prefill endpoint from already-discovered instances (one-time filter).
Some( /// Returns the endpoint if a prefill worker is found in the target namespace.
model_manager async fn find_prefill_endpoint(
.kv_chooser_for( drt: &DistributedRuntime,
&endpoint, target_namespace: &str,
card.kv_cache_block_size, ) -> Option<dynamo_runtime::component::Endpoint> {
kv_router_config, use dynamo_llm::model_card::ModelDeploymentCard;
WORKER_TYPE_DECODE, use dynamo_runtime::discovery::DiscoveryInstance;
)
.await?, let discovery = drt.discovery();
) let instances = match discovery.list(DiscoveryQuery::AllModels).await {
} else { Ok(instances) => instances,
None Err(e) => {
tracing::warn!(error = %e, "Failed to list instances for prefill discovery");
return None;
}
}; };
// Create prefill chooser for dynamic disaggregation support for instance in instances {
// This registers the model and returns a receiver that will be activated if let DiscoveryInstance::Model {
// when a prefill worker is discovered namespace,
let prefill_chooser = model_manager component,
.register_prefill_router(model_name.to_string()) endpoint,
.map(|rx| { ..
// Create prefill-specific config with track_active_blocks disabled } = &instance
let mut prefill_config = kv_router_config.unwrap_or_default(); {
prefill_config.router_track_active_blocks = false; // Filter by namespace
if namespace != target_namespace {
PrefillRouter::new( continue;
rx, }
model_manager.clone(),
router_mode,
card.kv_cache_block_size,
Some(prefill_config),
enforce_disagg,
model_name.to_string(),
)
});
// Start background watcher for prefill model discovery let card = match instance.deserialize_model::<ModelDeploymentCard>() {
// This will activate the prefill router when prefill workers join Ok(card) => card,
spawn_prefill_watcher( Err(_) => continue,
component.drt().clone(), };
model_manager.clone(),
namespace.to_string(),
);
// Download model config files from HuggingFace for EPP // Only handle prefill models
// The backend's card has NATS URLs which aren't accessible from EPP if !card.model_type.supports_prefill() {
tracing::debug!( continue;
"Downloading model config files for EPP: {}", }
card.display_name
);
let local_path = dynamo_llm::hub::from_hf(&card.display_name, true) tracing::info!(
.await model_name = card.name(),
.with_context(|| { "Prefill worker found in discovered instances"
format!( );
"Failed to download model config files for: {}",
card.display_name // Build and return the endpoint
) if let Ok(ns) = drt.namespace(namespace)
})?; && let Ok(comp) = ns.component(component)
{
// Load a fresh card from local files, then copy runtime config from original card return Some(comp.endpoint(endpoint));
tracing::debug!("Loading ModelDeploymentCard from local path..."); }
let mut card_with_local_files = ModelDeploymentCard::load_from_disk(&local_path, None) }
.with_context(|| format!("Failed to load card from disk: {:?}", local_path))?; }
// Copy runtime settings from the backend's card
tracing::debug!("Copying runtime config from backend card...");
card_with_local_files.runtime_config = card.runtime_config.clone();
card_with_local_files.kv_cache_block_size = card.kv_cache_block_size;
card_with_local_files.context_length = card.context_length;
// Load the tokenizer from the downloaded files
tracing::debug!("Loading tokenizer from local files...");
let hf_tokenizer = card_with_local_files
.tokenizer_hf()
.with_context(|| format!("Failed to load tokenizer for: {}", card.display_name))?;
// Create worker monitor if busy_threshold is set
// Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this
let worker_monitor = busy_threshold.map(|t| {
KvWorkerMonitor::new(
client.clone(),
dynamo_llm::discovery::LoadThresholdConfig {
active_decode_blocks_threshold: Some(t),
active_prefill_tokens_threshold: None,
active_prefill_tokens_threshold_frac: None,
},
)
});
// Clone chooser before passing to build_routed_pipeline (which takes ownership) None
let kv_router = chooser.clone();
let engine = build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
&card_with_local_files,
&client,
model_manager.clone(),
router_mode,
worker_monitor,
chooser,
hf_tokenizer,
prefill_chooser,
enforce_disagg,
0, // migration_limit - default to 0 for C bindings
metrics,
)
.await?;
Ok((engine, kv_router))
} }
...@@ -1635,6 +1635,7 @@ dependencies = [ ...@@ -1635,6 +1635,7 @@ dependencies = [
"ahash", "ahash",
"aho-corasick", "aho-corasick",
"akin", "akin",
"aligned-vec",
"anyhow", "anyhow",
"async-nats", "async-nats",
"async-stream", "async-stream",
...@@ -1651,6 +1652,7 @@ dependencies = [ ...@@ -1651,6 +1652,7 @@ dependencies = [
"bytes", "bytes",
"candle-core", "candle-core",
"chrono", "chrono",
"cudarc",
"dashmap 5.5.3", "dashmap 5.5.3",
"derive-getters", "derive-getters",
"derive_builder", "derive_builder",
...@@ -1685,6 +1687,8 @@ dependencies = [ ...@@ -1685,6 +1687,8 @@ dependencies = [
"ndarray", "ndarray",
"ndarray-interp", "ndarray-interp",
"ndarray-npy", "ndarray-npy",
"nix 0.26.4",
"nixl-sys",
"object_store", "object_store",
"offset-allocator", "offset-allocator",
"oneshot", "oneshot",
...@@ -3962,6 +3966,15 @@ version = "0.3.3" ...@@ -3962,6 +3966,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "memoffset" name = "memoffset"
version = "0.9.1" version = "0.9.1"
...@@ -4255,6 +4268,19 @@ dependencies = [ ...@@ -4255,6 +4268,19 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if 1.0.4",
"libc",
"memoffset 0.7.1",
"pin-utils",
]
[[package]] [[package]]
name = "nix" name = "nix"
version = "0.29.0" version = "0.29.0"
...@@ -5501,7 +5527,7 @@ dependencies = [ ...@@ -5501,7 +5527,7 @@ dependencies = [
"cfg-if 1.0.4", "cfg-if 1.0.4",
"indoc", "indoc",
"libc", "libc",
"memoffset", "memoffset 0.9.1",
"once_cell", "once_cell",
"portable-atomic", "portable-atomic",
"pyo3-build-config", "pyo3-build-config",
......
...@@ -46,6 +46,9 @@ pub enum RouterMode { ...@@ -46,6 +46,9 @@ pub enum RouterMode {
RoundRobin, RoundRobin,
Random, Random,
KV, KV,
/// Direct routing - reads worker ID from each request's routing hints.
/// Used when an external orchestrator (e.g., EPP) handles worker selection.
Direct,
} }
impl From<RouterMode> for RsRouterMode { impl From<RouterMode> for RsRouterMode {
...@@ -54,6 +57,7 @@ impl From<RouterMode> for RsRouterMode { ...@@ -54,6 +57,7 @@ impl From<RouterMode> for RsRouterMode {
RouterMode::RoundRobin => Self::RoundRobin, RouterMode::RoundRobin => Self::RoundRobin,
RouterMode::Random => Self::Random, RouterMode::Random => Self::Random,
RouterMode::KV => Self::KV, RouterMode::KV => Self::KV,
RouterMode::Direct => Self::Direct,
} }
} }
} }
......
...@@ -950,6 +950,7 @@ class RouterMode: ...@@ -950,6 +950,7 @@ class RouterMode:
RoundRobin: "RouterMode" RoundRobin: "RouterMode"
Random: "RouterMode" Random: "RouterMode"
KV: "RouterMode" KV: "RouterMode"
Direct: "RouterMode"
... ...
class RouterConfig: class RouterConfig:
...@@ -968,7 +969,7 @@ class RouterConfig: ...@@ -968,7 +969,7 @@ class RouterConfig:
Create a RouterConfig. Create a RouterConfig.
Args: Args:
mode: The router mode (RoundRobin, Random, or KV) mode: The router mode (RoundRobin, Random, KV, or Direct)
config: Optional KV router configuration (used when mode is KV) config: Optional KV router configuration (used when mode is KV)
active_decode_blocks_threshold: Threshold percentage (0.0-1.0) for decode blocks busy detection active_decode_blocks_threshold: Threshold percentage (0.0-1.0) for decode blocks busy detection
active_prefill_tokens_threshold: Literal token count threshold for prefill busy detection active_prefill_tokens_threshold: Literal token count threshold for prefill busy detection
......
...@@ -9,7 +9,7 @@ use crate::{ ...@@ -9,7 +9,7 @@ use crate::{
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{EngineConfig, RouterConfig}, entrypoint::{EngineConfig, RouterConfig},
http::service::metrics::Metrics, http::service::metrics::Metrics,
kv_router::{KvPushRouter, KvRouter, PrefillRouter}, kv_router::{DirectRoutingRouter, KvPushRouter, KvRouter, PrefillRouter},
migration::Migration, migration::Migration,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter}, preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
...@@ -274,10 +274,10 @@ where ...@@ -274,10 +274,10 @@ where
.await?; .await?;
let service_backend = match router_mode { let service_backend = match router_mode {
RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => { RouterMode::Direct => {
// Non-KV routing: use PushRouter directly. ServiceBackend::from_engine(Arc::new(DirectRoutingRouter::new(router)))
// Note: Per-worker metrics (active_prefill_tokens, active_decode_blocks) are only }
// available in KV routing mode where the router has actual bookkeeping. RouterMode::Random | RouterMode::RoundRobin => {
ServiceBackend::from_engine(Arc::new(router)) ServiceBackend::from_engine(Arc::new(router))
} }
RouterMode::KV => { RouterMode::KV => {
......
...@@ -41,7 +41,7 @@ pub mod worker_query; ...@@ -41,7 +41,7 @@ pub mod worker_query;
pub use config::{KvRouterConfig, RouterConfigOverride}; pub use config::{KvRouterConfig, RouterConfigOverride};
pub use prefill_router::PrefillRouter; pub use prefill_router::PrefillRouter;
pub use push_router::KvPushRouter; pub use push_router::{DirectRoutingRouter, KvPushRouter};
use crate::{ use crate::{
discovery::RuntimeConfigWatch, discovery::RuntimeConfigWatch,
......
...@@ -81,14 +81,6 @@ impl InnerPrefillRouter { ...@@ -81,14 +81,6 @@ impl InnerPrefillRouter {
InnerPrefillRouter::KvRouter(_) => None, InnerPrefillRouter::KvRouter(_) => None,
} }
} }
/// Peek next worker without incrementing state (for non-KV modes only)
fn peek_next_worker(&self) -> Option<u64> {
match self {
InnerPrefillRouter::SimpleRouter(router) => router.peek_next_worker(),
InnerPrefillRouter::KvRouter(_) => None,
}
}
} }
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router. /// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
...@@ -273,7 +265,7 @@ impl PrefillRouter { ...@@ -273,7 +265,7 @@ impl PrefillRouter {
preselected_worker: Option<u64>, preselected_worker: Option<u64>,
) -> Option<(u64, u32, BootstrapInfo)> { ) -> Option<(u64, u32, BootstrapInfo)> {
let endpoint_id = self.endpoint_id.get()?; let endpoint_id = self.endpoint_id.get()?;
let prefill_router = self.prefill_router.get()?; let _prefill_router = self.prefill_router.get()?;
// Worker selection // Worker selection
let (worker_id, dp_rank) = if let Some(id) = preselected_worker { let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
...@@ -284,12 +276,8 @@ impl PrefillRouter { ...@@ -284,12 +276,8 @@ impl PrefillRouter {
"Using pre-selected prefill worker for bootstrap" "Using pre-selected prefill worker for bootstrap"
); );
(id, dp_rank) (id, dp_rank)
} else if self.router_mode.is_kv_routing() { } else {
// KV mode: use find_best_match // Use shared worker selection logic (update_states=false for peek behavior)
let kv_router = match prefill_router {
InnerPrefillRouter::KvRouter(r) => r,
_ => return None,
};
// Extract LORA name and priority jump from routing hints // Extract LORA name and priority jump from routing hints
let lora_name = req.routing.as_ref().and_then(|r| r.lora_name.clone()); let lora_name = req.routing.as_ref().and_then(|r| r.lora_name.clone());
let priority_jump = req let priority_jump = req
...@@ -297,24 +285,14 @@ impl PrefillRouter { ...@@ -297,24 +285,14 @@ impl PrefillRouter {
.as_ref() .as_ref()
.and_then(|r| r.priority_jump) .and_then(|r| r.priority_jump)
.unwrap_or(0.0); .unwrap_or(0.0);
match async { match self
kv_router .query_prefill_worker(&req.token_ids, false, lora_name, priority_jump)
.chooser .instrument(tracing::info_span!("query_prefill_worker"))
.find_best_match(None, &req.token_ids, None, false, lora_name, priority_jump) .await
.await
}
.instrument(tracing::info_span!("kv_find_best_match"))
.await
{ {
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank), Ok((worker_id, dp_rank)) => (worker_id, dp_rank),
Err(_) => return None, Err(_) => return None,
} }
} else {
// Non-KV mode: use PushRouter's stateful selection
// We use peek_next_worker instead of select_next_worker to avoid double-incrementing the counter
// if we fall back to the original path.
let worker_id = prefill_router.peek_next_worker()?;
(worker_id, 0)
}; };
// Get bootstrap info from ModelManager (works for ANY mode) // Get bootstrap info from ModelManager (works for ANY mode)
...@@ -489,6 +467,55 @@ impl PrefillRouter { ...@@ -489,6 +467,55 @@ impl PrefillRouter {
// No phase permit needed - we wait for completion before changing phase // No phase permit needed - we wait for completion before changing phase
Self::execute_prefill(self.prefill_router.get().cloned(), request, None, None).await Self::execute_prefill(self.prefill_router.get().cloned(), request, None, None).await
} }
/// Query the best prefill worker without executing a request.
/// Returns (worker_id, dp_rank).
///
/// This is the shared worker selection logic used by both `build_bootstrap_info`
/// and `query_route`.
pub async fn query_prefill_worker(
&self,
token_ids: &[u32],
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
) -> Result<(u64, u32)> {
let prefill_router = self
.prefill_router
.get()
.ok_or_else(|| anyhow::anyhow!(PrefillError::NotActivated))?;
match prefill_router {
InnerPrefillRouter::KvRouter(r) => {
let (worker, _overlap) = r
.chooser
.find_best_match(
None,
token_ids,
None,
update_states,
lora_name,
priority_jump,
)
.await?;
Ok((worker.worker_id, worker.dp_rank))
}
InnerPrefillRouter::SimpleRouter(r) => {
let worker_id = if update_states {
r.select_next_worker()
} else {
r.peek_next_worker()
}
.ok_or_else(|| anyhow::anyhow!("No workers available for prefill"))?;
Ok((worker_id, 0))
}
}
}
/// Check if disaggregated mode is currently active (prefill router activated)
pub fn is_activated(&self) -> bool {
self.prefill_router.get().is_some()
}
} }
impl Drop for PrefillRouter { impl Drop for PrefillRouter {
......
...@@ -48,7 +48,6 @@ struct WorkerSelection { ...@@ -48,7 +48,6 @@ struct WorkerSelection {
struct RequestGuard { struct RequestGuard {
chooser: Arc<KvRouter>, chooser: Arc<KvRouter>,
context_id: String, context_id: String,
handle_local_updates: bool,
tracker: Option<Arc<RequestTracker>>, tracker: Option<Arc<RequestTracker>>,
request_metrics: Arc<RouterRequestMetrics>, request_metrics: Arc<RouterRequestMetrics>,
cumulative_osl: usize, cumulative_osl: usize,
...@@ -59,9 +58,7 @@ struct RequestGuard { ...@@ -59,9 +58,7 @@ struct RequestGuard {
impl RequestGuard { impl RequestGuard {
async fn finish(&mut self) { async fn finish(&mut self) {
self.record_metrics(); self.record_metrics();
if self.handle_local_updates if let Err(e) = self.chooser.free(&self.context_id).await {
&& let Err(e) = self.chooser.free(&self.context_id).await
{
tracing::warn!("Failed to free request {}: {e}", self.context_id); tracing::warn!("Failed to free request {}: {e}", self.context_id);
} }
self.freed = true; self.freed = true;
...@@ -86,7 +83,7 @@ impl RequestGuard { ...@@ -86,7 +83,7 @@ impl RequestGuard {
impl Drop for RequestGuard { impl Drop for RequestGuard {
fn drop(&mut self) { fn drop(&mut self) {
self.record_metrics(); self.record_metrics();
if !self.freed && self.handle_local_updates { if !self.freed {
let chooser = self.chooser.clone(); let chooser = self.chooser.clone();
let context_id = self.context_id.clone(); let context_id = self.context_id.clone();
let Ok(handle) = tokio::runtime::Handle::try_current() else { let Ok(handle) = tokio::runtime::Handle::try_current() else {
...@@ -112,15 +109,13 @@ impl KvPushRouter { ...@@ -112,15 +109,13 @@ impl KvPushRouter {
/// Select a worker for the request, either using a preselected worker or finding the best match. /// Select a worker for the request, either using a preselected worker or finding the best match.
/// ///
/// When `is_query_only` is false and `handle_local_updates` is true, this also registers /// When `is_query_only` is false, this also registers the request with the scheduler via `add_request`.
/// the request with the scheduler via `add_request`.
async fn select_worker( async fn select_worker(
&self, &self,
context_id: &str, context_id: &str,
request: &PreprocessedRequest, request: &PreprocessedRequest,
phase: RequestPhase, phase: RequestPhase,
is_query_only: bool, is_query_only: bool,
handle_local_updates: bool,
) -> Result<WorkerSelection, Error> { ) -> Result<WorkerSelection, Error> {
let routing = request.routing.as_ref(); let routing = request.routing.as_ref();
let lora_name = routing.and_then(|r| r.lora_name.clone()); let lora_name = routing.and_then(|r| r.lora_name.clone());
...@@ -172,7 +167,7 @@ impl KvPushRouter { ...@@ -172,7 +167,7 @@ impl KvPushRouter {
.get_overlap_blocks(&request.token_ids, worker) .get_overlap_blocks(&request.token_ids, worker)
.await?; .await?;
if !is_query_only && handle_local_updates { if !is_query_only {
self.chooser self.chooser
.add_request( .add_request(
context_id.to_string(), context_id.to_string(),
...@@ -234,15 +229,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -234,15 +229,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Simple query-only detection: presence of query_instance_id annotation means query-only mode // Simple query-only detection: presence of query_instance_id annotation means query-only mode
let is_query_only = request.get_annotation_value("query_instance_id").is_some(); let is_query_only = request.get_annotation_value("query_instance_id").is_some();
// Determine if this router should handle local state updates (add_request, free, etc.)
// Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
// an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
let handle_local_updates = request
.routing
.as_ref()
.and_then(|r| r.enable_local_updates)
.unwrap_or(true);
// Get phase from tracker (defaults to Aggregated if no tracker or phase not set) // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
let phase = request let phase = request
.tracker .tracker
...@@ -252,13 +238,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -252,13 +238,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let block_size = self.chooser.block_size() as usize; let block_size = self.chooser.block_size() as usize;
let selection = self let selection = self
.select_worker( .select_worker(&context_id, &request, phase, is_query_only)
&context_id,
&request,
phase,
is_query_only,
handle_local_updates,
)
.await?; .await?;
let WorkerSelection { let WorkerSelection {
instance_id, instance_id,
...@@ -335,8 +315,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -335,8 +315,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
.routing .routing
.as_ref() .as_ref()
.and_then(|r| r.expected_output_tokens); .and_then(|r| r.expected_output_tokens);
let track_output_blocks = let track_output_blocks = self.chooser.kv_router_config().router_track_output_blocks;
self.chooser.kv_router_config().router_track_output_blocks && handle_local_updates;
let tracker = request.tracker.clone(); let tracker = request.tracker.clone();
let (mut backend_input, context) = request.into_parts(); let (mut backend_input, context) = request.into_parts();
...@@ -360,7 +339,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -360,7 +339,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let mut guard = RequestGuard { let mut guard = RequestGuard {
chooser: chooser.clone(), chooser: chooser.clone(),
context_id: context_id.clone(), context_id: context_id.clone(),
handle_local_updates,
tracker: tracker.clone(), tracker: tracker.clone(),
request_metrics: request_metrics.clone(), request_metrics: request_metrics.clone(),
cumulative_osl: 0, cumulative_osl: 0,
...@@ -385,7 +363,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -385,7 +363,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
break; break;
}; };
if handle_local_updates && !prefill_marked { if !prefill_marked {
// Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill
let has_tokens = item.data.as_ref() let has_tokens = item.data.as_ref()
.map(|d| !d.token_ids.is_empty()) .map(|d| !d.token_ids.is_empty())
.unwrap_or(false); .unwrap_or(false);
...@@ -451,3 +431,48 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -451,3 +431,48 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
Ok(ResponseStream::new(wrapped_stream, stream_context)) Ok(ResponseStream::new(wrapped_stream, stream_context))
} }
} }
/// A direct routing wrapper for `RouterMode::Direct`.
///
/// This wraps a `PushRouter` and reads worker IDs from each request's routing hints,
/// then routes directly to the specified worker. Used when an external router
/// (e.g., EPP) handles worker selection.
pub struct DirectRoutingRouter {
inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
}
impl DirectRoutingRouter {
pub fn new(inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>) -> Self {
DirectRoutingRouter { inner }
}
/// Extract worker ID from request routing hints.
/// Returns an error if no worker ID is found (required in direct routing mode).
fn get_worker_id(request: &PreprocessedRequest) -> Result<u64, Error> {
let routing = request.routing.as_ref();
let worker_id = routing.and_then(|r| r.decode_worker_id.or(r.backend_instance_id));
worker_id.ok_or_else(|| {
anyhow::anyhow!(
"Worker ID required (--direct-route) but none found in request. \
Expected decode_worker_id or backend_instance_id to be set by external router (e.g., EPP)."
)
})
}
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for DirectRoutingRouter
{
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let worker_id = Self::get_worker_id(&request)?;
tracing::debug!(worker_id = worker_id, "Direct routing to specified worker");
self.inner.direct(request, worker_id).await
}
}
...@@ -280,7 +280,6 @@ impl OpenAIPreprocessor { ...@@ -280,7 +280,6 @@ impl OpenAIPreprocessor {
prefill_worker_id: nvext.prefill_worker_id, prefill_worker_id: nvext.prefill_worker_id,
decode_worker_id: nvext.decode_worker_id, decode_worker_id: nvext.decode_worker_id,
dp_rank: None, // dp_rank is set later in the pipeline dp_rank: None, // dp_rank is set later in the pipeline
enable_local_updates: nvext.enable_local_updates,
expected_output_tokens: hints.and_then(|h| h.osl), expected_output_tokens: hints.and_then(|h| h.osl),
priority_jump: hints.and_then(|h| h.latency_sensitivity), priority_jump: hints.and_then(|h| h.latency_sensitivity),
lora_name, lora_name,
......
...@@ -34,14 +34,6 @@ pub struct RoutingHints { ...@@ -34,14 +34,6 @@ pub struct RoutingHints {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>, pub dp_rank: Option<u32>,
/// Controls whether the router should manage local bookkeeping (add_request,
/// mark_prefill_completed, free) for this request.
///
/// - `None` or `Some(true)`: Router handles bookkeeping locally (default behavior)
/// - `Some(false)`: External caller (e.g., GAIE sidecar) handles bookkeeping via C FFI
#[serde(default, skip_serializing_if = "Option::is_none")]
pub enable_local_updates: Option<bool>,
/// Expected number of output tokens for this request. /// Expected number of output tokens for this request.
/// Used as a hint for routing decisions to estimate resource requirements. /// Used as a hint for routing decisions to estimate resource requirements.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
......
...@@ -11,16 +11,12 @@ pub use crate::protocols::common::timing::TimingInfo; ...@@ -11,16 +11,12 @@ pub use crate::protocols::common::timing::TimingInfo;
pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id"; pub const HEADER_WORKER_INSTANCE_ID: &str = "x-worker-instance-id";
pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id"; pub const HEADER_PREFILL_INSTANCE_ID: &str = "x-prefill-instance-id";
/// Header to disable local bookkeeping updates (for GAIE Stage 2)
/// When set to "false", the router skips add_request, mark_prefill_completed, and free calls.
pub const HEADER_ENABLE_LOCAL_UPDATES: &str = "x-enable-local-updates";
/// Apply routing overrides from HTTP headers to nvext. /// Apply routing overrides from HTTP headers to nvext.
/// ///
/// Header mappings: /// Header mappings:
/// - `x-worker-instance-id` -> `backend_instance_id` and `decode_worker_id` /// - `x-worker-instance-id` -> `backend_instance_id` and `decode_worker_id`
/// - `x-prefill-instance-id` -> `prefill_worker_id` /// - `x-prefill-instance-id` -> `prefill_worker_id`
/// - `x-enable-local-updates` -> `enable_local_updates` (set to false to disable router bookkeeping)
/// ///
/// Headers take priority over existing nvext values when present. /// Headers take priority over existing nvext values when present.
/// If no headers are present, returns the original nvext unchanged. /// If no headers are present, returns the original nvext unchanged.
...@@ -35,17 +31,7 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) ...@@ -35,17 +31,7 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok()); .and_then(|s| s.parse::<u64>().ok());
// Parse enable_local_updates header: "true" or "false" if worker_id.is_none() && prefill_id.is_none() {
let enable_local_updates = headers
.get(HEADER_ENABLE_LOCAL_UPDATES)
.and_then(|v| v.to_str().ok())
.and_then(|s| match s.to_lowercase().as_str() {
"true" | "1" => Some(true),
"false" | "0" => Some(false),
_ => None,
});
if worker_id.is_none() && prefill_id.is_none() && enable_local_updates.is_none() {
return nvext; return nvext;
} }
...@@ -57,9 +43,6 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap) ...@@ -57,9 +43,6 @@ pub fn apply_header_routing_overrides(nvext: Option<NvExt>, headers: &HeaderMap)
if let Some(id) = prefill_id { if let Some(id) = prefill_id {
ext.prefill_worker_id = Some(id); ext.prefill_worker_id = Some(id);
} }
if let Some(enabled) = enable_local_updates {
ext.enable_local_updates = Some(enabled);
}
Some(ext) Some(ext)
} }
...@@ -169,17 +152,6 @@ pub struct NvExt { ...@@ -169,17 +152,6 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>, pub decode_worker_id: Option<u64>,
/// Controls whether the router should manage local bookkeeping (add_request,
/// mark_prefill_completed, free) for this request.
///
/// - `None` or `true`: Router handles bookkeeping locally (default behavior)
/// - `false`: External caller (e.g., GAIE sidecar) handles bookkeeping via C FFI
///
/// Set to `false` for GAIE Stage 2 when the EPP/sidecar manages request lifecycle.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub enable_local_updates: Option<bool>,
/// Agent-provided hints for request handling. /// Agent-provided hints for request handling.
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
...@@ -187,7 +159,7 @@ pub struct NvExt { ...@@ -187,7 +159,7 @@ pub struct NvExt {
} }
/// Hints from the agent/caller about request characteristics. /// Hints from the agent/caller about request characteristics.
#[derive(ToSchema, Serialize, Deserialize, Builder, Debug, Clone, Default)] #[derive(ToSchema, Serialize, Deserialize, Builder, Debug, Clone, Default, PartialEq)]
pub struct AgentHints { pub struct AgentHints {
/// Latency sensitivity in seconds for queue ordering. /// Latency sensitivity in seconds for queue ordering.
/// Higher values cause the request to be scheduled sooner when the router queue is enabled. /// Higher values cause the request to be scheduled sooner when the router queue is enabled.
...@@ -249,7 +221,7 @@ mod tests { ...@@ -249,7 +221,7 @@ mod tests {
assert_eq!(nv_ext.extra_fields, None); assert_eq!(nv_ext.extra_fields, None);
assert_eq!(nv_ext.prefill_worker_id, None); assert_eq!(nv_ext.prefill_worker_id, None);
assert_eq!(nv_ext.decode_worker_id, None); assert_eq!(nv_ext.decode_worker_id, None);
assert_eq!(nv_ext.enable_local_updates, None); assert_eq!(nv_ext.agent_hints, None);
} }
// Test valid builder configurations // Test valid builder configurations
......
...@@ -83,15 +83,18 @@ pub enum RouterMode { ...@@ -83,15 +83,18 @@ pub enum RouterMode {
#[default] #[default]
RoundRobin, RoundRobin,
Random, Random,
Direct(u64),
// Marker value, KV routing itself is in dynamo-llm
KV, KV,
Direct,
} }
impl RouterMode { impl RouterMode {
pub fn is_kv_routing(&self) -> bool { pub fn is_kv_routing(&self) -> bool {
*self == RouterMode::KV *self == RouterMode::KV
} }
pub fn is_direct_routing(&self) -> bool {
*self == RouterMode::Direct
}
} }
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> { async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
...@@ -415,14 +418,17 @@ where ...@@ -415,14 +418,17 @@ where
U: Data + for<'de> Deserialize<'de> + MaybeError, U: Data + for<'de> Deserialize<'de> + MaybeError,
{ {
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> { async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
//InstanceSource::Static => self.r#static(request).await,
match self.router_mode { match self.router_mode {
RouterMode::Random => self.random(request).await, RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await, RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(instance_id) => self.direct(request, instance_id).await,
RouterMode::KV => { RouterMode::KV => {
anyhow::bail!("KV routing should not call generate on PushRouter"); anyhow::bail!("KV routing should not call generate on PushRouter");
} }
RouterMode::Direct => {
anyhow::bail!(
"Direct routing should not call generate on PushRouter directly; use DirectRoutingRouter wrapper"
);
}
} }
} }
} }
...@@ -16,12 +16,7 @@ spec: ...@@ -16,12 +16,7 @@ spec:
replicas: 1 replicas: 1
extraPodSpec: extraPodSpec:
mainContainer: mainContainer:
image: nvcr.io/nvidia/ai-dynamo/frontend:0.8.0 image: nvcr.io/nvidia/ai-dynamo/frontend:my-tag
env:
- name: DYN_KV_BLOCK_SIZE
value: "128"
- name: DYN_MODEL
value: "RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic"
eppConfig: eppConfig:
# This configuration uses Dynamo's KV-aware scorer for intelligent routing # This configuration uses Dynamo's KV-aware scorer for intelligent routing
config: config:
...@@ -49,8 +44,15 @@ spec: ...@@ -49,8 +44,15 @@ spec:
mountPoint: /opt/models mountPoint: /opt/models
extraPodSpec: extraPodSpec:
mainContainer: mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.8.0 image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
workingDir: /workspace/examples/backends/vllm workingDir: /workspace/examples/backends/vllm
command:
- python3
args:
- -m
- dynamo.frontend
- --router-mode
- direct
envs: envs:
- name: HF_HOME - name: HF_HOME
value: /opt/models value: /opt/models
...@@ -79,7 +81,7 @@ spec: ...@@ -79,7 +81,7 @@ spec:
command: command:
- /bin/sh - /bin/sh
- -c - -c
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.8.0 image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
workingDir: /workspace/examples/backends/vllm workingDir: /workspace/examples/backends/vllm
replicas: 1 replicas: 1
resources: resources:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment