Unverified Commit c26e6fb8 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: Update GAIE recipe (#4761)


Signed-off-by: default avatarAnna Tchernych <atchernych@nvidia.com>
parent 6da23fcf
......@@ -83,7 +83,7 @@ echo " Docker: ${GAIE_DIR}/Dockerfile.epp"
echo "Applying Dynamo patch..."
cd "${GAIE_DIR}"
PATCH_FILE="${DYNAMO_DIR}/deploy/inference-gateway/epp-patches/v0.5.1-2/epp-v0.5.1-dyn2.patch"
PATCH_FILE="${DYNAMO_DIR}/deploy/inference-gateway/epp-patches/v0.8.0/gaie.patch"
if [[ -f "${PATCH_FILE}" ]]; then
if git apply --check "${PATCH_FILE}" 2>/dev/null; then
git apply "${PATCH_FILE}"
......
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index fb73765..0000000
--- a/Dockerfile
+++ /dev/null
@@ -1,33 +0,0 @@
-# Dockerfile has specific requirement to put this ARG at the beginning:
-# https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
-ARG BUILDER_IMAGE=golang:1.24
-ARG BASE_IMAGE=gcr.io/distroless/static:nonroot
-
-## Multistage build
-FROM ${BUILDER_IMAGE} AS builder
-ENV CGO_ENABLED=0
-ENV GOOS=linux
-ENV GOARCH=amd64
-ARG COMMIT_SHA=unknown
-ARG BUILD_REF
-
-# Dependencies
-WORKDIR /src
-COPY go.mod go.sum ./
-RUN go mod download
-
-# Sources
-COPY cmd/epp ./cmd/epp
-COPY pkg/epp ./pkg/epp
-COPY internal ./internal
-COPY api ./api
-WORKDIR /src/cmd/epp
-RUN go build -ldflags="-X sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics.CommitSHA=${COMMIT_SHA} -X sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics.BuildRef=${BUILD_REF}" -o /epp
-
-## Multistage deploy
-FROM ${BASE_IMAGE}
-
-WORKDIR /
-COPY --from=builder /epp /epp
-
-ENTRYPOINT ["/epp"]
diff --git a/Makefile b/Makefile
index dee7e99..d3f9ec7 100644
--- a/Makefile
+++ b/Makefile
@@ -170,6 +170,48 @@ verify-all:
##@ Build
+##@ Dynamo EPP with FFI
+
+# Build the Dynamo EPP image with CGO static library support
+.PHONY: dynamo-image-local-build
+dynamo-image-local-build: ## Build the Dynamo EPP image using Docker Buildx for local development.
+ BUILDER=$(shell $(DOCKER_BUILDX_CMD) create --use)
+ $(MAKE) dynamo-image-build PUSH=$(PUSH)
+ $(MAKE) dynamo-image-build LOAD=$(LOAD)
+ $(DOCKER_BUILDX_CMD) rm $$BUILDER
+
+.PHONY: dynamo-image-local-push
+dynamo-image-local-push: PUSH=--push ## Build the Dynamo EPP image for local development and push it to $IMAGE_REPO.
+dynamo-image-local-push: dynamo-image-local-build
+
+.PHONY: dynamo-image-local-load
+dynamo-image-local-load: LOAD=--load ## Build the Dynamo EPP image for local development and load it in the local Docker registry.
+dynamo-image-local-load: dynamo-image-local-build
+
+.PHONY: dynamo-image-build
+dynamo-image-build: ## Build the Dynamo EPP image using Docker Buildx with CGO support.
+ $(IMAGE_BUILD_CMD) -f Dockerfile.dynamo -t $(IMAGE_TAG) \
+ --platform=$(PLATFORMS) \
+ --build-arg BASE_IMAGE=ubuntu:24.04 \
+ --build-arg BUILDER_IMAGE=$(BUILDER_IMAGE) \
+ --build-arg COMMIT_SHA=${GIT_COMMIT_SHA} \
+ --build-arg BUILD_REF=${BUILD_REF} \
+ $(PUSH) \
+ $(LOAD) \
+ $(IMAGE_BUILD_EXTRA_OPTS) ./
+
+.PHONY: dynamo-image-push
+dynamo-image-push: PUSH=--push ## Build the Dynamo EPP image and push it to $IMAGE_REPO.
+dynamo-image-push: dynamo-image-build
+
+.PHONY: dynamo-image-load
+dynamo-image-load: LOAD=--load ## Build the Dynamo EPP image and load it in the local Docker registry.
+dynamo-image-load: dynamo-image-build
+
+.PHONY: dynamo-image-kind
+dynamo-image-kind: dynamo-image-build ## Build the Dynamo EPP image and load it to kind cluster $KIND_CLUSTER ("kind" by default).
+ kind load docker-image $(IMAGE_TAG) --name $(KIND_CLUSTER)
+
# Build the container image
.PHONY: image-local-build
image-local-build: ## Build the EPP image using Docker Buildx for local development.
diff --git a/cmd/epp/main.go b/cmd/epp/main.go
index b5e0617..8592735 100644
--- a/cmd/epp/main.go
+++ b/cmd/epp/main.go
@@ -22,6 +22,11 @@ import (
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/gateway-api-inference-extension/cmd/epp/runner"
+ eppplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+
+ // Dynamo plugins
+ dynprereq "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid"
+ dynscorer "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/dynamo_kv_scorer"
)
func main() {
@@ -30,6 +35,9 @@ func main() {
// For adding out-of-tree plugins to the plugins registry, use the following:
// plugins.Register(my-out-of-tree-plugin-name, my-out-of-tree-plugin-factory-function)
+ eppplugins.Register("dynamo-inject-workerid", dynprereq.InjectWorkerIDPreRequestFactory)
+ eppplugins.Register("kv-aware-scorer", dynscorer.KVAwareScorerFactory)
+
if err := runner.NewRunner().Run(ctrl.SetupSignalHandler()); err != nil {
os.Exit(1)
}
diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index 670d922..0cf04cb 100644
--- a/pkg/epp/requestcontrol/director.go
+++ b/pkg/epp/requestcontrol/director.go
@@ -130,6 +130,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
TargetModel: reqCtx.ResolvedTargetModel,
Prompt: prompt,
Headers: reqCtx.Request.Headers,
+ Annotations: map[string]any{},
}
logger = logger.WithValues("model", reqCtx.Model, "resolvedTargetModel", reqCtx.ResolvedTargetModel, "criticality", requestCriticality)
@@ -253,7 +254,7 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
reqCtx.TargetPod = targetPod
reqCtx.TargetEndpoint = endpoint
- d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result, targetPort)
+ d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result, targetPort, reqCtx.Request.Body)
return reqCtx, nil
}
@@ -319,13 +320,20 @@ func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed
return ""
}
-func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult,
+func (d *Director) runPreRequestPlugins(
+ ctx context.Context,
+ request *schedulingtypes.LLMRequest,
+ schedulingResult *schedulingtypes.SchedulingResult,
targetPort int,
+ body map[string]any,
) {
for _, plugin := range d.preRequestPlugins {
log.FromContext(ctx).V(logutil.DEBUG).Info("Running pre-request plugin", "plugin", plugin.TypedName().Type)
before := time.Now()
plugin.PreRequest(ctx, request, schedulingResult, targetPort)
+ if mutator, ok := plugin.(RequestBodyMutator); ok && body != nil {
+ mutator.MutateRequestBody(ctx, request, schedulingResult, targetPort, body)
+ }
metrics.RecordRequestControlPluginProcessingLatency(PreRequestPluginType, plugin.TypedName().Type, time.Since(before))
}
}
diff --git a/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
new file mode 100644
index 0000000..cd9a0b5
--- /dev/null
+++ b/pkg/epp/requestcontrol/plugins/dynamo_inject_workerid/plugin.go
@@ -0,0 +1,119 @@
+package dynamo_inject_workerid
+
+import (
+ "context"
+ "encoding/json"
+ "strconv"
+ "strings"
+
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+ rc "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
+ schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+const (
+ typeString = "dynamo-inject-workerid"
+ pluginName = "dynamo-inject-workerid"
+ WorkerIDHeader = "x-worker-instance-id"
+ tokenDataAnnotationKey = "dynamo/token-data"
+)
+
+var _ plugins.Plugin = (*InjectWorkerIDPreRequest)(nil)
+var _ rc.PreRequest = (*InjectWorkerIDPreRequest)(nil)
+var _ rc.RequestBodyMutator = (*InjectWorkerIDPreRequest)(nil)
+
+type InjectWorkerIDPreRequest struct {
+ typedName plugins.TypedName
+}
+
+func NewInjectWorkerIDPreRequest() *InjectWorkerIDPreRequest {
+ return &InjectWorkerIDPreRequest{
+ typedName: plugins.TypedName{Type: typeString, Name: pluginName},
+ }
+}
+
+func (p *InjectWorkerIDPreRequest) WithName(name string) *InjectWorkerIDPreRequest {
+ p.typedName.Name = name
+ return p
+}
+
+func InjectWorkerIDPreRequestFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+ return NewInjectWorkerIDPreRequest().WithName(name), nil
+}
+
+func (p *InjectWorkerIDPreRequest) TypedName() plugins.TypedName { return p.typedName }
+
+func (p *InjectWorkerIDPreRequest) PreRequest(
+ _ context.Context,
+ req *schedtypes.LLMRequest,
+ _ *schedtypes.SchedulingResult,
+ _ int,
+) {
+ if req == nil {
+ return
+ }
+ if req.Headers == nil {
+ req.Headers = map[string]string{}
+ }
+ wid := strings.TrimSpace(req.Headers[WorkerIDHeader])
+ if wid == "" {
+ return
+ }
+ req.Headers[WorkerIDHeader] = wid
+}
+
+func (p *InjectWorkerIDPreRequest) MutateRequestBody(
+ _ context.Context,
+ req *schedtypes.LLMRequest,
+ _ *schedtypes.SchedulingResult,
+ _ int,
+ body map[string]any,
+) {
+ if req == nil || body == nil {
+ return
+ }
+ if req.Headers == nil {
+ return
+ }
+
+ wid := strings.TrimSpace(req.Headers[WorkerIDHeader])
+ if wid == "" {
+ return
+ }
+
+ nvext, _ := body["nvext"].(map[string]any)
+ if nvext == nil {
+ nvext = map[string]any{}
+ body["nvext"] = nvext
+ }
+ if widUint, err := strconv.ParseUint(wid, 10, 64); err == nil {
+ nvext["backend_instance_id"] = widUint
+ }
+
+ if tokens, ok := req.Annotations[tokenDataAnnotationKey]; ok {
+ switch v := tokens.(type) {
+ case []int64:
+ if len(v) > 0 {
+ nvext["token_data"] = v
+ }
+ case []any:
+ var out []int64
+ for _, elem := range v {
+ switch t := elem.(type) {
+ case int64:
+ out = append(out, t)
+ case float64:
+ out = append(out, int64(t))
+ }
+ }
+ if len(out) > 0 {
+ nvext["token_data"] = out
+ }
+ case json.RawMessage:
+ var out []int64
+ if err := json.Unmarshal(v, &out); err == nil && len(out) > 0 {
+ nvext["token_data"] = out
+ }
+ }
+ }
+}
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml
new file mode 100644
index 0000000..b689c00
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/epp-config-dynamo.yaml
@@ -0,0 +1,21 @@
+# This is an example for configuring the EPP to use the dynamo token-aware kv router for scoring the pods
+apiVersion: inference.networking.x-k8s.io/v1alpha1
+kind: EndpointPickerConfig
+plugins:
+ # Required: tells EPP which profile to use (even if you only have one)
+ - type: single-profile-handler
+
+ # Picker: chooses the final endpoint after scoring
+ - name: picker
+ type: max-score-picker
+ - name: dyn-pre
+ type: dynamo-inject-workerid
+ parameters: {}
+ - name: dyn-kv
+ type: kv-aware-scorer
+schedulingProfiles:
+ - name: default
+ plugins:
+ - pluginRef: dyn-kv
+ weight: 1
+ - pluginRef: picker
diff --git a/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
new file mode 100644
index 0000000..bc29c0a
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/dynamo_kv_scorer/plugin.go
@@ -0,0 +1,424 @@
+package dynamo_kv_scorer
+
+/*
+#cgo CPPFLAGS: -I${SRCDIR}/include
+#cgo CXXFLAGS: -std=c++17
+#cgo LDFLAGS: ${SRCDIR}/lib/libdynamo_llm_capi.a -lstdc++ -ldl -lpthread -lm
+
+#include <stdint.h>
+#include <stddef.h>
+#include <stdlib.h> // for free
+#include <stdbool.h>
+
+// enum underlying type is uint32_t; matches cbindgen output
+typedef uint32_t dynamo_llm_result_t;
+enum { DYNAMO_OK = 0, DYNAMO_ERR = 1 };
+
+// opaque handle forward-decl
+struct WorkerSelectionPipeline;
+typedef struct WorkerSelectionPipeline WorkerSelectionPipeline;
+
+// Prototypes (C-compatible)
+dynamo_llm_result_t dynamo_llm_init(const char *namespace_c_str,
+ const char *component_c_str,
+ int64_t worker_id,
+ uint32_t kv_block_size);
+
+dynamo_llm_result_t dynamo_llm_shutdown(void);
+dynamo_llm_result_t dynamo_llm_load_publisher_create(void);
+
+dynamo_llm_result_t dynamo_kv_event_publish_stored(uint64_t event_id,
+ const uint32_t *token_ids,
+ const uintptr_t *num_block_tokens,
+ const uint64_t *block_ids,
+ size_t num_blocks,
+ const uint64_t *parent_hash,
+ uint64_t lora_id);
+
+dynamo_llm_result_t dynamo_kv_event_publish_removed(uint64_t event_id,
+ const uint64_t *block_ids,
+ size_t num_blocks);
+
+dynamo_llm_result_t dynamo_create_worker_selection_pipeline(const char *namespace_c_str,
+ const char *component_c_str,
+ const char *model_name_c_str,
+ bool use_kv_routing,
+ double busy_threshold,
+ double overlap_score_weight,
+ double router_temperature,
+ bool use_kv_events,
+ bool router_replica_sync,
+ WorkerSelectionPipeline **pipeline_out);
+
+dynamo_llm_result_t dynamo_destroy_worker_selection_pipeline(WorkerSelectionPipeline *pipeline);
+
+dynamo_llm_result_t dynamo_query_worker_selection_and_annotate(WorkerSelectionPipeline *pipeline,
+ const char *request_json_c_str,
+ int64_t *worker_instance_id_out,
+ uint32_t **token_ids_out,
+ size_t *token_count_out,
+ char **annotated_request_json_out);
+
+dynamo_llm_result_t dynamo_free_worker_selection_result(uint32_t *token_ids,
+ size_t token_count,
+ char *annotated_request_json);
+*/
+import "C"
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "strings"
+ "sync"
+ "unsafe"
+
+ log "sigs.k8s.io/controller-runtime/pkg/log"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
+ schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+const (
+ PluginName = "dynamo-kv-scorer"
+ KVAwareScorerType = "kv-aware-scorer"
+ StateKeyWorkerInstanceID = schedtypes.StateKey("dynamo/worker-instance-id")
+ WorkerIDHeader = "x-worker-instance-id"
+ tokenDataAnnotationKey = "dynamo/token-data"
+)
+
+// --------------------------- config / env ---------------------------
+
+var warmupOnce sync.Once
+var warmupErr error
+
+type stateString string
+type params struct {
+}
+
+func (s stateString) Clone() schedtypes.StateData { return s }
+
+type KVAwareScorer struct {
+ typedName plugins.TypedName
+}
+
+var _ plugins.Plugin = (*KVAwareScorer)(nil)
+var _ framework.Scorer = (*KVAwareScorer)(nil)
+
+func NewKVAwareScorer() *KVAwareScorer {
+ return &KVAwareScorer{
+ typedName: plugins.TypedName{Type: KVAwareScorerType, Name: PluginName},
+ }
+}
+
+func (k *KVAwareScorer) WithName(name string) *KVAwareScorer { k.typedName.Name = name; return k }
+
+func KVAwareScorerFactory(name string, raw json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+ p := params{}
+ _ = json.Unmarshal(raw, &p)
+
+ s := NewKVAwareScorer().WithName(name)
+
+ // one-time FFI init (runtime + persistent pipeline)
+ warmupOnce.Do(func() {
+ defer func() {
+ if r := recover(); r != nil {
+ warmupErr = fmt.Errorf("Dynamo configuration error: %v", r)
+ }
+ }()
+ warmupErr = initFFI()
+ })
+ if warmupErr != nil {
+ return nil, fmt.Errorf("Dynamo FFI init for the Router failed: %w", warmupErr)
+ }
+
+ return s, nil
+}
+
+func (k *KVAwareScorer) TypedName() plugins.TypedName { return k.typedName }
+
+// --------------------------- FFI integration ---------------------------
+
+var (
+ ffiOnce sync.Once
+ ffiErr error
+
+ ffiNamespace string
+ ffiComponent string
+ ffiModel string
+ ffiOverlapScoreWeight float64
+ ffiRouterTemperature float64
+ ffiKvBlockSize uint32
+ ffiWorkerID int64
+
+ runtimeInitialized bool
+
+ // Boxed pipeline handle (owned on the Rust side, opaque here)
+ pipeline *C.struct_WorkerSelectionPipeline
+ pipelineMutex sync.RWMutex
+)
+
+func loadDynamoConfig() {
+ ffiNamespace = getEnvOrDefault("DYNAMO_NAMESPACE", "vllm-agg")
+ ffiComponent = getEnvOrDefault("DYNAMO_COMPONENT", "backend")
+ ffiModel = getEnvOrDefault("DYNAMO_MODEL", "Qwen/Qwen3-0.6B")
+ ffiWorkerID = getEnvInt64OrDefault("DYNAMO_WORKER_ID", 1)
+
+ ffiOverlapScoreWeight = getEnvFloatOrDefault("DYNAMO_OVERLAP_SCORE_WEIGHT", -1.0)
+ ffiRouterTemperature = getEnvFloatOrDefault("DYNAMO_ROUTER_TEMPERATURE", -1.0)
+
+ kvBlockSizeStr := os.Getenv("DYNAMO_KV_BLOCK_SIZE")
+ if kvBlockSizeStr == "" {
+ panic("DYNAMO_KV_BLOCK_SIZE is required and must match the model card's kv_cache_block_size")
+ }
+ var tmp int64
+ if n, err := fmt.Sscanf(kvBlockSizeStr, "%d", &tmp); err != nil || n != 1 {
+ panic(fmt.Sprintf("DYNAMO_KV_BLOCK_SIZE='%s' is not a valid integer", kvBlockSizeStr))
+ }
+ ffiKvBlockSize = uint32(tmp)
+ if ffiKvBlockSize < 16 || ffiKvBlockSize > 8192 {
+ panic(fmt.Sprintf("DYNAMO_KV_BLOCK_SIZE=%d outside [16,8192]", ffiKvBlockSize))
+ }
+ if (ffiKvBlockSize & (ffiKvBlockSize - 1)) != 0 {
+ panic(fmt.Sprintf("DYNAMO_KV_BLOCK_SIZE=%d must be a power of 2", ffiKvBlockSize))
+ }
+ fmt.Printf("Dynamo KV Scorer: Loaded DYNAMO_KV_BLOCK_SIZE=%d\n", ffiKvBlockSize)
+}
+
+func getEnvOrDefault(key, def string) string {
+ if v := os.Getenv(key); v != "" {
+ return v
+ }
+ return def
+}
+func getEnvInt64OrDefault(key string, def int64) int64 {
+ if v := os.Getenv(key); v != "" {
+ var p int64
+ if n, err := fmt.Sscanf(v, "%d", &p); err == nil && n == 1 {
+ return p
+ }
+ }
+ return def
+}
+func getEnvFloatOrDefault(key string, def float64) float64 {
+ if v := os.Getenv(key); v != "" {
+ var p float64
+ if n, err := fmt.Sscanf(v, "%f", &p); err == nil && n == 1 {
+ return p
+ }
+ }
+ return def
+}
+func getEnvBoolOrDefault(key string, def bool) bool {
+ if v := os.Getenv(key); v != "" {
+ switch strings.ToLower(v) {
+ case "true", "1", "yes", "on":
+ return true
+ case "false", "0", "no", "off":
+ return false
+ }
+ }
+ return def
+}
+
+// initFFI: initialize runtime and create a persistent boxed pipeline.
+func initFFI() error {
+ ffiOnce.Do(func() {
+ loadDynamoConfig()
+
+ ns := C.CString(ffiNamespace)
+ cm := C.CString(ffiComponent)
+ model := C.CString(ffiModel)
+ defer C.free(unsafe.Pointer(ns))
+ defer C.free(unsafe.Pointer(cm))
+ defer C.free(unsafe.Pointer(model))
+
+ // Init Dynamo runtime
+ if rc := C.dynamo_llm_init(ns, cm, C.int64_t(ffiWorkerID), C.uint32_t(ffiKvBlockSize)); rc != C.DYNAMO_OK {
+ ffiErr = fmt.Errorf("dynamo_llm_init failed")
+ return
+ }
+ runtimeInitialized = true
+
+ // Create persistent pipeline
+ pipelineMutex.Lock()
+ defer pipelineMutex.Unlock()
+
+ rc := C.dynamo_create_worker_selection_pipeline(
+ ns,
+ cm,
+ model,
+ C.bool(getEnvBoolOrDefault("DYNAMO_USE_KV_ROUTING", true)),
+ C.double(getEnvFloatOrDefault("DYNAMO_BUSY_THRESHOLD", -1.0)),
+ C.double(ffiOverlapScoreWeight),
+ C.double(ffiRouterTemperature),
+ C.bool(getEnvBoolOrDefault("DYNAMO_USE_KV_EVENTS", true)),
+ C.bool(getEnvBoolOrDefault("DYNAMO_ROUTER_REPLICA_SYNC", true)),
+ &pipeline,
+ )
+ if rc != C.DYNAMO_OK {
+ ffiErr = fmt.Errorf("dynamo_create_worker_selection_pipeline failed")
+ return
+ }
+ })
+ return ffiErr
+}
+
+// --------------------------- scoring ---------------------------
+
+func (k *KVAwareScorer) Score(
+ ctx context.Context,
+ cycle *schedtypes.CycleState,
+ req *schedtypes.LLMRequest,
+ pods []schedtypes.Pod,
+) map[schedtypes.Pod]float64 {
+ logger := log.FromContext(ctx)
+
+ workerID, tokenData, err := k.callDynamoRouter(ctx, req)
+ if err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "Dynamo call failed; proceeding without worker id")
+ } else if workerID != "" {
+ logger.V(logutil.DEFAULT).Info(
+ "Dynamo router selected worker",
+ "workerID", workerID,
+ "tokenDataCount", len(tokenData),
+ "tokenData", tokenData,
+ )
+ cycle.Write(StateKeyWorkerInstanceID, stateString(workerID))
+ if req.Headers == nil {
+ req.Headers = map[string]string{}
+ }
+ req.Headers[WorkerIDHeader] = workerID
+ if len(tokenData) > 0 {
+ if req.Annotations == nil {
+ req.Annotations = map[string]any{}
+ }
+ copied := make([]int64, len(tokenData))
+ copy(copied, tokenData)
+ req.Annotations[tokenDataAnnotationKey] = copied
+ }
+ }
+
+ out := make(map[schedtypes.Pod]float64, len(pods))
+ for _, p := range pods {
+ out[p] = 1.0
+ }
+ return out
+}
+
+// --------------------------- router call (persistent only) ---------------------------
+
+func (k *KVAwareScorer) callDynamoRouter(
+ ctx context.Context,
+ req *schedtypes.LLMRequest,
+) (string, []int64, error) {
+ logger := log.FromContext(ctx)
+
+ if err := initFFI(); err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "FFI init failed")
+ return "", nil, err
+ }
+ if !runtimeInitialized {
+ return "", nil, fmt.Errorf("dynamo runtime not initialized")
+ }
+
+ pipelineMutex.RLock()
+ currentPipeline := pipeline
+ pipelineMutex.RUnlock()
+
+ if currentPipeline == nil {
+ return "", nil, fmt.Errorf("dynamo worker selection pipeline not created")
+ }
+
+ // Build OpenAI-compatible JSON request
+ requestBody := buildOpenAIRequest(req)
+ requestJSON, err := json.Marshal(requestBody)
+ if err != nil {
+ logger.V(logutil.DEFAULT).Error(err, "Failed to marshal OpenAI request")
+ return "", nil, fmt.Errorf("marshal OpenAI request: %w", err)
+ }
+ cRequestJSON := C.CString(string(requestJSON))
+ defer C.free(unsafe.Pointer(cRequestJSON))
+
+ // Output variables
+ var cWorkerID C.int64_t
+ var cTokens *C.uint32_t
+ var cTokenCount C.size_t
+ var cAnnotatedJSON *C.char
+
+ // Call the worker selection pipeline
+ rc := C.dynamo_query_worker_selection_and_annotate(
+ currentPipeline,
+ cRequestJSON,
+ &cWorkerID,
+ &cTokens,
+ &cTokenCount,
+ &cAnnotatedJSON,
+ )
+ if rc != C.DYNAMO_OK {
+ return "", nil, fmt.Errorf("dynamo_query_worker_selection_and_annotate failed")
+ }
+
+ // Copy tokens into Go memory and free C memory
+ count := int(uintptr(cTokenCount))
+ var tokens64 []int64
+ if count > 0 && cTokens != nil {
+ src := unsafe.Slice((*uint32)(unsafe.Pointer(cTokens)), count)
+ tokens64 = make([]int64, count)
+ for i := 0; i < count; i++ {
+ tokens64[i] = int64(src[i])
+ }
+ }
+ C.dynamo_free_worker_selection_result(cTokens, cTokenCount, cAnnotatedJSON)
+
+ workerID := fmt.Sprintf("%d", int64(cWorkerID))
+ logger.V(logutil.DEFAULT).Info("Worker selection completed",
+ "workerID", workerID, "tokenCount", count)
+
+ return workerID, tokens64, nil
+}
+
+func buildOpenAIRequest(req *schedtypes.LLMRequest) map[string]any {
+ requestBody := make(map[string]any)
+ userText := "default prompt"
+ if req != nil && strings.TrimSpace(req.Prompt) != "" {
+ userText = req.Prompt
+ }
+ requestBody["messages"] = []map[string]any{{"role": "user", "content": userText}}
+ if req != nil && strings.TrimSpace(req.TargetModel) != "" {
+ requestBody["model"] = req.TargetModel
+ } else {
+ requestBody["model"] = ffiModel
+ }
+ requestBody["max_tokens"] = 1
+ requestBody["temperature"] = 0.0
+ requestBody["stream"] = true
+ requestBody["nvext"] = map[string]any{
+ "annotations": []string{"query_instance_id"},
+ }
+ return requestBody
+}
+
+// --------------------------- shutdown ---------------------------
+
+func cleanupDynamo() error {
+ pipelineMutex.Lock()
+ defer pipelineMutex.Unlock()
+
+ if pipeline != nil {
+ if rc := C.dynamo_destroy_worker_selection_pipeline(pipeline); rc != C.DYNAMO_OK {
+ fmt.Printf("Warning: dynamo_destroy_worker_selection_pipeline failed\n")
+ }
+ pipeline = nil
+ }
+
+ if runtimeInitialized {
+ if rc := C.dynamo_llm_shutdown(); rc != C.DYNAMO_OK {
+ return fmt.Errorf("dynamo_llm_shutdown failed")
+ }
+ runtimeInitialized = false
+ }
+ return nil
+}
diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go
index 2962117..7da1d43 100644
--- a/pkg/epp/scheduling/types/types.go
+++ b/pkg/epp/scheduling/types/types.go
@@ -33,10 +33,12 @@ type LLMRequest struct {
Prompt string
// Headers is a map of the request headers.
Headers map[string]string
+ // Annotations provides plugin-specific data that should travel alongside the request.
+ Annotations map[string]any
}
func (r *LLMRequest) String() string {
- return fmt.Sprintf("RequestID: %s, TargetModel: %s, PromptLength: %d, Headers: %v", r.RequestId, r.TargetModel, len(r.Prompt), r.Headers)
+ return fmt.Sprintf("RequestID: %s, TargetModel: %s, PromptLength: %d, Headers: %v, Annotations: %v", r.RequestId, r.TargetModel, len(r.Prompt), r.Headers, r.Annotations)
}
type Pod interface {
......@@ -24,6 +24,7 @@ spec:
- group: gateway.networking.k8s.io
kind: Gateway
name: inference-gateway
namespace: kgateway-system
rules:
- backendRefs:
- group: inference.networking.x-k8s.io
......
......@@ -21,7 +21,7 @@ metadata:
spec:
targetPortNumber: 8000
selector:
nvidia.com/dynamo-component: Frontend
nvidia.com/dynamo-component-type: frontend
extensionRef:
failureMode: FailOpen
group: ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment