Unverified Commit a2a2753d authored by atchernych's avatar atchernych Committed by GitHub
Browse files

fix: Remove double tokenization in EPP Integration fixes [DYN-2076] (#8093)


Signed-off-by: default avatarAnna Tchernych <atchernych@nvidia.com>
parent 4d5db80a
......@@ -16,7 +16,7 @@ dynamo:
cuda13.0:
base_image: nvcr.io/nvidia/cuda-dl-base
base_image_tag: 25.11-cuda13.0-devel-ubuntu24.04
epp_image: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/epp:v0.5.1
epp_image: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/epp:v1.5.0-rc.2
frontend_image: nvcr.io/nvidia/base/ubuntu:noble-20250619
planner_build_image: python
planner_build_image_tag: 3.12-slim
......
......@@ -31,7 +31,7 @@
# -t dynamo/dynamo-epp:dev .
ARG RUST_IMAGE=rust:1.93.1
ARG BUILDER_IMAGE=golang:1.24
ARG BUILDER_IMAGE=golang:1.25
ARG BASE_IMAGE=ubuntu:24.04
# =============================================================================
......
......@@ -30,7 +30,7 @@ EXTRA_BUILD_ARGS ?=
DOCKER_BUILDX_CMD ?= docker buildx
IMAGE_BUILD_CMD ?= $(DOCKER_BUILDX_CMD) build
RUST_IMAGE ?= $(DOCKER_PROXY)rust:1.93.1
BUILDER_IMAGE ?= $(DOCKER_PROXY)golang:1.24
BUILDER_IMAGE ?= $(DOCKER_PROXY)golang:1.25
BASE_IMAGE ?= $(DOCKER_PROXY)ubuntu:24.04
# Container tool
......
......@@ -38,7 +38,7 @@ import (
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/gateway-api-inference-extension/cmd/epp/runner"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
plugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
// Dynamo plugins
"github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/disagg"
......
module github.com/nvidia/dynamo/deploy/inference-gateway
go 1.24.0
go 1.25.0
require (
sigs.k8s.io/controller-runtime v0.22.4
sigs.k8s.io/gateway-api-inference-extension v1.2.1
github.com/go-logr/logr v1.4.3
sigs.k8s.io/controller-runtime v0.23.3
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260416154104-206d40dca75d
)
require (
cel.dev/expr v0.25.1 // indirect
github.com/antlr4-go/antlr/v4 v4.13.0 // indirect
github.com/antlr4-go/antlr/v4 v4.13.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
......@@ -18,31 +19,39 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dennwc/varint v1.0.0 // indirect
github.com/emicklei/go-restful/v3 v3.13.0 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
github.com/evanphx/json-patch/v5 v5.9.11 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-logr/zapr v1.3.0 // indirect
github.com/go-openapi/jsonpointer v0.21.2 // indirect
github.com/go-openapi/jsonreference v0.21.0 // indirect
github.com/go-openapi/swag v0.23.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/go-openapi/jsonpointer v0.22.4 // indirect
github.com/go-openapi/jsonreference v0.21.4 // indirect
github.com/go-openapi/swag v0.25.4 // indirect
github.com/go-openapi/swag/cmdutils v0.25.4 // indirect
github.com/go-openapi/swag/conv v0.25.4 // indirect
github.com/go-openapi/swag/fileutils v0.25.4 // indirect
github.com/go-openapi/swag/jsonname v0.25.4 // indirect
github.com/go-openapi/swag/jsonutils v0.25.4 // indirect
github.com/go-openapi/swag/loading v0.25.4 // indirect
github.com/go-openapi/swag/mangling v0.25.4 // indirect
github.com/go-openapi/swag/netutils v0.25.4 // indirect
github.com/go-openapi/swag/stringutils v0.25.4 // indirect
github.com/go-openapi/swag/typeutils v0.25.4 // indirect
github.com/go-openapi/swag/yamlutils v0.25.4 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/cel-go v0.26.0 // indirect
github.com/google/cel-go v0.28.0 // indirect
github.com/google/gnostic-models v0.7.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grafana/regexp v0.0.0-20250905093917-f7b3be9d1853 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/jellydator/ttlcache/v3 v3.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/mailru/easyjson v0.9.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
......@@ -50,60 +59,58 @@ require (
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/client_golang v1.23.2 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.4 // indirect
github.com/prometheus/procfs v0.17.0 // indirect
github.com/prometheus/prometheus v0.308.1 // indirect
github.com/spf13/cobra v1.9.1 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/prometheus/prometheus v0.310.0 // indirect
github.com/spf13/cobra v1.10.2 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/stoewer/go-strcase v1.3.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
go.opentelemetry.io/otel v1.40.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 // indirect
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.39.0 // indirect
go.opentelemetry.io/otel/metric v1.40.0 // indirect
go.opentelemetry.io/otel/sdk v1.40.0 // indirect
go.opentelemetry.io/otel/trace v1.40.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect
go.opentelemetry.io/otel v1.43.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 // indirect
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 // indirect
go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.opentelemetry.io/proto/otlp v1.10.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.1 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/exp v0.0.0-20250808145144-a408d31f581a // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/term v0.38.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/time v0.13.0 // indirect
golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/oauth2 v0.35.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/term v0.41.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/time v0.15.0 // indirect
gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
google.golang.org/grpc v1.79.3 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/grpc v1.80.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/api v0.34.3 // indirect
k8s.io/apiextensions-apiserver v0.34.3 // indirect
k8s.io/apimachinery v0.34.3 // indirect
k8s.io/apiserver v0.34.3 // indirect
k8s.io/client-go v0.34.3 // indirect
k8s.io/component-base v0.34.3 // indirect
k8s.io/api v0.35.3 // indirect
k8s.io/apiextensions-apiserver v0.35.3 // indirect
k8s.io/apimachinery v0.35.3 // indirect
k8s.io/apiserver v0.35.3 // indirect
k8s.io/client-go v0.35.3 // indirect
k8s.io/component-base v0.35.3 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3 // indirect
k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d // indirect
k8s.io/kube-openapi v0.0.0-20260127142750-a19766b6e2d4 // indirect
k8s.io/utils v0.0.0-20260108192941-914a6e750570 // indirect
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2 // indirect
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect
sigs.k8s.io/randfill v1.0.0 // indirect
sigs.k8s.io/structured-merge-diff/v6 v6.3.1 // indirect
sigs.k8s.io/structured-merge-diff/v6 v6.3.2 // indirect
sigs.k8s.io/yaml v1.6.0 // indirect
)
// NOTE: For local development, uncomment the replace directive below.
// For Docker builds, keep it commented out to use the published v1.2.1 release.
// replace sigs.k8s.io/gateway-api-inference-extension => ../../../gaie_latest/gateway-api-inference-extension
// For Docker builds, keep it commented out to use the published release.
// replace sigs.k8s.io/gateway-api-inference-extension => ../../../gaie/gateway-api-inference-extension
This diff is collapsed.
......@@ -24,12 +24,11 @@ import (
"sync"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
rc "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
plugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
rc "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)
......@@ -44,16 +43,14 @@ const (
PrefillDpRankHeader = "x-prefill-dp-rank"
RoutingModeHeader = "x-dynamo-routing-mode"
// decodeStateKey is the key used to store routing state in PluginState
decodeStateKey = "dynamo-decode-routing-state"
)
// compile-time type assertions
var _ framework.Scorer = &DynDecodeScorer{}
var _ schedtypes.Scorer = &DynDecodeScorer{}
var _ plugins.Plugin = &DynDecodeScorer{}
var _ rc.PreRequest = &DynDecodeScorer{}
var _ rc.ResponseStreaming = &DynDecodeScorer{}
var _ rc.ResponseComplete = &DynDecodeScorer{}
var _ rc.ResponseBodyProcessor = &DynDecodeScorer{}
// DecodeRoutingState holds routing information passed from Score() to PreRequest().
type DecodeRoutingState struct {
......@@ -92,7 +89,6 @@ func DynDecodeScorerFactory(name string, rawParameters json.RawMessage, handle p
}
}
// Initialize the shared FFI (idempotent)
if err := dynscorer.InitFFI(); err != nil {
return nil, fmt.Errorf("Dynamo FFI init for decode scorer failed: %w", err)
}
......@@ -111,15 +107,6 @@ func NewDynDecodeScorer(ctx context.Context, enforceDisagg bool) *DynDecodeScore
}
// DynDecodeScorer is a scorer plugin for the decode scheduling profile.
//
// When Score() is called, it:
// 1. Reads PrefillEnabledState from CycleState (written by DisaggProfileHandler).
// 2. Calls the Dynamo FFI decode router with is_disaggregated flag.
// 3. Sets routing headers on the request.
// 4. Stores routing state for PreRequest to register with router bookkeeping.
//
// It also implements PreRequest, ResponseStreaming, and ResponseComplete lifecycle hooks
// for router bookkeeping (add_request, mark_prefill_complete, free_request).
type DynDecodeScorer struct {
typedName plugins.TypedName
pluginState *plugins.PluginState
......@@ -138,8 +125,13 @@ func (s *DynDecodeScorer) WithName(name string) *DynDecodeScorer {
return s
}
// Score scores pods for decode suitability.
func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.CycleState, req *schedtypes.LLMRequest, pods []schedtypes.Pod) map[schedtypes.Pod]float64 {
// Category returns the scorer category.
func (s *DynDecodeScorer) Category() schedtypes.ScorerCategory {
return schedtypes.Affinity
}
// Score scores endpoints for decode suitability.
func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.CycleState, req *schedtypes.InferenceRequest, endpoints []schedtypes.Endpoint) map[schedtypes.Endpoint]float64 {
logger := log.FromContext(ctx)
isDisaggregated := readPrefillEnabled(cycleState)
......@@ -147,29 +139,28 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl
requestJSON, err := buildRequestJSON(req)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "DynDecodeScorer: failed to build request")
return uniformScores(pods, 1.0)
return uniformScores(endpoints, 1.0)
}
podsJSON := serializePods(pods)
logger.V(logutil.DEFAULT).Info("DynDecodeScorer: pods received for scoring",
"podCount", len(pods),
"podsJSON", string(podsJSON))
endpointsJSON := serializeEndpoints(endpoints)
logger.V(logutil.DEFAULT).Info("DynDecodeScorer: endpoints received for scoring",
"endpointCount", len(endpoints),
"endpointsJSON", string(endpointsJSON))
result, err := dynscorer.CallRouteDecodeRequest(requestJSON, podsJSON, isDisaggregated)
result, err := dynscorer.CallRouteDecodeRequest(requestJSON, endpointsJSON, isDisaggregated)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "DynDecodeScorer: FFI decode routing failed")
return uniformScores(pods, 1.0)
return uniformScores(endpoints, 1.0)
}
workerIDStr := fmt.Sprintf("%d", result.WorkerID)
dpRankStr := strconv.FormatUint(uint64(result.DpRank), 10)
logger.V(logutil.DEFAULT).Info("DynDecodeScorer: decode worker selected",
logger.V(logutil.DEFAULT).Info("[EPP-SCORER] FFI returned tokens from C bindings tokenization",
"decodeWorkerID", workerIDStr,
"decodeDpRank", result.DpRank,
"isDisaggregated", isDisaggregated,
"tokenCount", len(result.TokenData))
// Set routing headers
if req.Headers == nil {
req.Headers = map[string]string{}
}
......@@ -202,14 +193,15 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl
s.pluginState.Write(req.RequestId, plugins.StateKey(decodeStateKey), routingState)
}
// Score: all decode pods get 1.0 since the router's internal selection is authoritative
// and the worker ID is communicated via headers.
return uniformScores(pods, 1.0)
// Inject pre-computed tokens into the request body so the frontend
// sidecar can skip redundant tokenization.
setTokenizedPrompt(req, result.TokenData, logger)
return uniformScores(endpoints, 1.0)
}
// PreRequest is called after scheduling is finalized and before the request is sent to the worker.
// This registers the request with the Dynamo router's bookkeeping.
func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LLMRequest, _ *schedtypes.SchedulingResult) {
// PreRequest registers the request with the Dynamo router's bookkeeping.
func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.InferenceRequest, _ *schedtypes.SchedulingResult) {
logger := log.FromContext(ctx)
if request == nil || request.RequestId == "" {
......@@ -248,42 +240,37 @@ func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LL
"tokenCount", len(state.TokenData))
}
// ResponseStreaming is called for each chunk of a streaming response.
// On the first token, it marks prefill as complete in the Dynamo router's bookkeeping.
func (s *DynDecodeScorer) ResponseStreaming(ctx context.Context, request *schedtypes.LLMRequest, _ *rc.Response, _ *backend.Pod) {
// ResponseBody handles streaming chunks and end-of-stream cleanup.
// On the first token it marks prefill as complete; on EndOfStream it frees the request.
func (s *DynDecodeScorer) ResponseBody(ctx context.Context, request *schedtypes.InferenceRequest, response *rc.Response, _ *fwkdl.EndpointMetadata) {
if request == nil || request.RequestId == "" {
return
}
logger := log.FromContext(ctx)
// Mark prefill complete on first token
if _, alreadySeen := s.firstTokenSeen.LoadOrStore(request.RequestId, true); !alreadySeen {
logger := log.FromContext(ctx)
if err := dynscorer.CallMarkPrefillComplete(request.RequestId); err != nil {
logger.V(logutil.DEFAULT).Error(err, "DynDecodeScorer ResponseStreaming: failed to mark prefill complete",
logger.V(logutil.DEFAULT).Error(err, "DynDecodeScorer ResponseBody: failed to mark prefill complete",
"requestID", request.RequestId)
} else {
logger.V(logutil.VERBOSE).Info("DynDecodeScorer ResponseBody: marked prefill complete",
"requestID", request.RequestId)
return
}
logger.V(logutil.VERBOSE).Info("DynDecodeScorer ResponseStreaming: marked prefill complete",
"requestID", request.RequestId)
}
}
// ResponseComplete is called after the complete response is sent to the client.
// It cleans up the router bookkeeping state for the completed request.
func (s *DynDecodeScorer) ResponseComplete(ctx context.Context, request *schedtypes.LLMRequest, _ *rc.Response, _ *backend.Pod) {
logger := log.FromContext(ctx)
if request == nil || request.RequestId == "" {
return
}
s.firstTokenSeen.Delete(request.RequestId)
// Free request on end of stream — must always run regardless of
// earlier errors to avoid leaking router bookkeeping state.
if response != nil && response.EndOfStream {
s.firstTokenSeen.Delete(request.RequestId)
if err := dynscorer.CallFreeRequest(request.RequestId); err != nil {
logger.V(logutil.DEFAULT).Error(err, "DynDecodeScorer ResponseComplete: failed to free request",
"requestID", request.RequestId)
return
if err := dynscorer.CallFreeRequest(request.RequestId); err != nil {
logger.V(logutil.DEFAULT).Error(err, "DynDecodeScorer ResponseBody: failed to free request",
"requestID", request.RequestId)
} else {
logger.V(logutil.VERBOSE).Info("DynDecodeScorer ResponseBody: freed request",
"requestID", request.RequestId)
}
}
logger.V(logutil.VERBOSE).Info("DynDecodeScorer ResponseComplete: freed request",
"requestID", request.RequestId)
}
......@@ -23,10 +23,9 @@ import (
"strconv"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
plugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)
......@@ -37,7 +36,7 @@ const (
)
// compile-time type assertion
var _ framework.Scorer = &DynPrefillScorer{}
var _ schedtypes.Scorer = &DynPrefillScorer{}
// DynPrefillScorerConfig holds the configuration for the DynPrefillScorer plugin.
type DynPrefillScorerConfig struct{}
......@@ -51,7 +50,6 @@ func DynPrefillScorerFactory(name string, rawParameters json.RawMessage, _ plugi
}
}
// Initialize the shared FFI (idempotent)
if err := dynscorer.InitFFI(); err != nil {
return nil, fmt.Errorf("Dynamo FFI init for prefill scorer failed: %w", err)
}
......@@ -67,12 +65,6 @@ func NewDynPrefillScorer() *DynPrefillScorer {
}
// DynPrefillScorer is a scorer plugin for the prefill scheduling profile.
//
// When Score() is called, it:
// 1. Reads PrefillEnabledState from CycleState (written by DisaggProfileHandler).
// 2. If prefill is NOT enabled, returns zero scores.
// 3. If prefill IS enabled, calls the Dynamo FFI prefill router to select the best prefill worker.
// 4. Assigns score 1.0 to all pods (the router's selection is authoritative, communicated via headers).
type DynPrefillScorer struct {
typedName plugins.TypedName
}
......@@ -88,35 +80,36 @@ func (s *DynPrefillScorer) WithName(name string) *DynPrefillScorer {
return s
}
// Score scores pods for prefill suitability.
func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.CycleState, req *schedtypes.LLMRequest, pods []schedtypes.Pod) map[schedtypes.Pod]float64 {
// Category returns the scorer category.
func (s *DynPrefillScorer) Category() schedtypes.ScorerCategory {
return schedtypes.Affinity
}
// Score scores endpoints for prefill suitability.
func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.CycleState, req *schedtypes.InferenceRequest, endpoints []schedtypes.Endpoint) map[schedtypes.Endpoint]float64 {
logger := log.FromContext(ctx)
if !readPrefillEnabled(cycleState) {
logger.V(logutil.VERBOSE).Info("DynPrefillScorer: prefill not enabled, returning zero scores")
return uniformScores(pods, 0)
return uniformScores(endpoints, 0)
}
requestJSON, err := buildRequestJSON(req)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "DynPrefillScorer: failed to build request")
return uniformScores(pods, 0)
return uniformScores(endpoints, 0)
}
podsJSON := serializePods(pods)
logger.V(logutil.DEFAULT).Info("DynPrefillScorer: pods received for scoring",
"podCount", len(pods),
"podsJSON", string(podsJSON))
endpointsJSON := serializeEndpoints(endpoints)
logger.V(logutil.DEFAULT).Info("DynPrefillScorer: endpoints received for scoring",
"endpointCount", len(endpoints),
"endpointsJSON", string(endpointsJSON))
result, err := dynscorer.CallRoutePrefillRequest(requestJSON, podsJSON)
result, err := dynscorer.CallRoutePrefillRequest(requestJSON, endpointsJSON)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "DynPrefillScorer: FFI prefill routing failed")
// Overwrite PrefillEnabled to false so the decode scorer falls back
// to aggregated routing. Without this, the prefill profile "succeeds"
// (picker picks a pod) but the prefill header is not set, causing
// the sidecar to reject the request in direct routing mode.
cycleState.Write(PrefillEnabledStateKey, &PrefillEnabledState{Enabled: false})
return uniformScores(pods, 0)
return uniformScores(endpoints, 0)
}
prefillWorkerID := strconv.FormatUint(result.WorkerID, 10)
......@@ -125,10 +118,6 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
"prefillDpRank", result.DpRank,
"tokenCount", len(result.TokenData))
// Set the prefill worker ID and DP rank headers directly on the request.
// The request object is shared across all profile runs in the scheduling
// cycle, so the decode scorer (which runs in the next profile) will see it.
// This is more reliable than CycleState which may be scoped per profile.
if req.Headers == nil {
req.Headers = map[string]string{}
}
......@@ -139,8 +128,5 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
delete(req.Headers, PrefillDpRankHeader)
}
// Score: 1.0 for all pods. The label-filter has already restricted to prefill workers,
// and the FFI router's internal selection is authoritative.
// In the future, we could match worker IDs to pod names for precise scoring.
return uniformScores(pods, 1.0)
return uniformScores(endpoints, 1.0)
}
......@@ -21,37 +21,19 @@ import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
plugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
)
func getEnvBoolOrDefault(key string, defaultVal bool) bool {
val, ok := os.LookupEnv(key)
if !ok {
return defaultVal
}
switch strings.ToLower(val) {
case "true", "1", "yes":
return true
case "false", "0", "no":
return false
default:
return defaultVal
}
}
const (
DisaggProfileHandlerType = "disagg-profile-handler"
)
// compile-time type assertion
var _ framework.ProfileHandler = &DisaggProfileHandler{}
var _ schedtypes.ProfileHandler = &DisaggProfileHandler{}
// DisaggProfileHandlerConfig holds the configuration for the DisaggProfileHandler.
type DisaggProfileHandlerConfig struct{}
......@@ -77,41 +59,6 @@ func NewDisaggProfileHandler(enforceDisagg bool) *DisaggProfileHandler {
}
// DisaggProfileHandler is a ProfileHandler that orchestrates prefill/decode disaggregated serving.
//
// # Disaggregated mode detection
//
// In Dynamo's native architecture, disaggregated mode is determined by whether prefill workers
// actually exist at runtime (the is_disaggregated flag in the Rust KV router). However, the
// GAIE EPP framework determines profile availability at configuration time, not at runtime.
// To bridge this gap, DisaggProfileHandler uses the EPP profile mechanism as a proxy:
// it checks whether a "prefill" scheduling profile is registered in the config. If prefill
// workers are configured but none are actually running, the prefill profile's label-filter
// will find zero pods, causing the profile to fail — and the handler gracefully degrades
// to aggregated mode (see below).
//
// # Scheduling flow
//
// On each scheduling cycle it:
// 1. Checks whether a "prefill" profile is registered in the config.
// 2. Writes PrefillEnabledState into CycleState so scorer plugins can read it.
// 3. If a prefill profile exists: runs the "prefill" profile first, then the "decode" profile.
// The "decode" profile is the primary (the pod the request is ultimately sent to).
// 4. If no prefill profile exists: runs only the "decode" profile (pure aggregated mode).
//
// # Graceful degradation
//
// When a prefill profile is configured but no prefill workers are available at runtime,
// the handler degrades gracefully to aggregated mode on a per-request basis:
//
// 1. Pick (iteration 1): prefill profile exists → writes PrefillEnabled=true → runs prefill profile.
// 2. Prefill profile runs: label-filter finds 0 prefill pods → profile fails → result is nil.
// 3. Pick (iteration 2): sees prefill result is nil → overwrites PrefillEnabled=false → runs decode profile.
// 4. Decode scorer runs: reads PrefillEnabled=false → passes isDisaggregated=false to the Rust
// decode router → full KV cache overlap scoring is used (overlap_score_weight=1.0).
//
// This means the same YAML config works transparently for both aggregated and disaggregated
// deployments. If prefill workers come up later, subsequent requests automatically use
// disaggregated routing. If they go down, requests fall back to aggregated mode.
type DisaggProfileHandler struct {
typedName plugins.TypedName
enforceDisagg bool
......@@ -129,23 +76,11 @@ func (h *DisaggProfileHandler) WithName(name string) *DisaggProfileHandler {
}
// Pick selects which profiles to run in the current iteration.
//
// Iteration 1 (no results yet):
// - Writes PrefillEnabledState into CycleState.
// - If a "prefill" profile exists → returns it alone (run prefill first).
// - Otherwise → returns the "decode" profile.
//
// Iteration 2 (prefill result exists, decode not yet):
// - Returns the "decode" profile.
//
// Iteration 3+ (all results collected):
// - Returns empty map to stop the loop.
func (h *DisaggProfileHandler) Pick(ctx context.Context, cycleState *schedtypes.CycleState, _ *schedtypes.LLMRequest,
profiles map[string]*framework.SchedulerProfile, profileResults map[string]*schedtypes.ProfileRunResult) map[string]*framework.SchedulerProfile {
func (h *DisaggProfileHandler) Pick(ctx context.Context, cycleState *schedtypes.CycleState, _ *schedtypes.InferenceRequest,
profiles map[string]schedtypes.SchedulerProfile, profileResults map[string]*schedtypes.ProfileRunResult) map[string]schedtypes.SchedulerProfile {
logger := log.FromContext(ctx).V(logutil.VERBOSE)
// First call: determine if prefill is enabled and write state.
if len(profileResults) == 0 {
_, prefillExists := profiles[PrefillProfileName]
state := &PrefillEnabledState{Enabled: prefillExists}
......@@ -153,64 +88,49 @@ func (h *DisaggProfileHandler) Pick(ctx context.Context, cycleState *schedtypes.
logger.Info("DisaggProfileHandler: prefill enabled state determined", "prefillEnabled", prefillExists)
if prefillExists {
// Run prefill profile first.
return map[string]*framework.SchedulerProfile{
return map[string]schedtypes.SchedulerProfile{
PrefillProfileName: profiles[PrefillProfileName],
}
}
// No prefill profile — run decode only.
if decodeProfile, ok := profiles[DecodeProfileName]; ok {
return map[string]*framework.SchedulerProfile{
return map[string]schedtypes.SchedulerProfile{
DecodeProfileName: decodeProfile,
}
}
// Fallback: return all profiles.
return profiles
}
// Second call: prefill has run, now run decode.
if prefillResult, prefillDone := profileResults[PrefillProfileName]; prefillDone {
if _, decodeDone := profileResults[DecodeProfileName]; !decodeDone {
if prefillResult == nil {
if h.enforceDisagg {
// enforce_disagg=true: do not fall back to aggregated mode.
// Stop the scheduling loop — ProcessResults will reject the request.
logger.Info("DisaggProfileHandler: prefill profile failed and enforce_disagg=true, rejecting request")
return map[string]*framework.SchedulerProfile{}
return map[string]schedtypes.SchedulerProfile{}
}
// enforce_disagg=false: fall back to aggregated decode.
logger.Info("DisaggProfileHandler: prefill profile failed (no workers?), falling back to aggregated decode")
cycleState.Write(PrefillEnabledStateKey, &PrefillEnabledState{Enabled: false})
}
if decodeProfile, ok := profiles[DecodeProfileName]; ok {
return map[string]*framework.SchedulerProfile{
return map[string]schedtypes.SchedulerProfile{
DecodeProfileName: decodeProfile,
}
}
}
}
// All profiles have been executed.
return map[string]*framework.SchedulerProfile{}
return map[string]schedtypes.SchedulerProfile{}
}
// ProcessResults aggregates the profile run results and designates the primary profile.
// The "decode" profile is always the primary (the pod that handles the request).
func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.CycleState, req *schedtypes.LLMRequest,
func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.CycleState, req *schedtypes.InferenceRequest,
profileResults map[string]*schedtypes.ProfileRunResult) (*schedtypes.SchedulingResult, error) {
// When enforce_disagg=true and the prefill worker ID header was not set
// (prefill router not activated or scorer failed), reject the request
// at the EPP level instead of forwarding it to the sidecar without
// routing headers.
if h.enforceDisagg && (req.Headers == nil || req.Headers[PrefillWorkerIDHeader] == "") {
// Only enforce if a prefill profile was configured and ran.
if _, prefillRan := profileResults[PrefillProfileName]; prefillRan {
return nil, errors.New(
"disaggregated mode enforced (DYN_ENFORCE_DISAGG=true) but prefill workers " +
"are not available; request rejected. Either wait for prefill workers " +
"to register or set DYN_ENFORCE_DISAGG=false to allow aggregated fallback")
"are not available; request rejected")
}
}
......@@ -218,10 +138,8 @@ func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.C
return nil, errors.New("disagg profile handler received no profile results")
}
// Determine primary profile name.
primaryProfile := DecodeProfileName
if _, ok := profileResults[DecodeProfileName]; !ok {
// If there's no decode result, pick whichever profile ran.
for name := range profileResults {
primaryProfile = name
break
......@@ -232,8 +150,7 @@ func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.C
if h.enforceDisagg {
return nil, errors.New(
"disaggregated mode enforced (DYN_ENFORCE_DISAGG=true) but prefill workers " +
"are not available; request rejected. Either wait for prefill workers " +
"to register or set DYN_ENFORCE_DISAGG=false to allow aggregated fallback")
"are not available; request rejected")
}
return nil, fmt.Errorf("primary profile '%s' failed to produce a result", primaryProfile)
}
......
......@@ -29,9 +29,14 @@ package disagg
import (
"encoding/json"
"fmt"
"os"
"strings"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
"github.com/go-logr/logr"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/observability/logging"
plugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
fwkrh "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requesthandling"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)
......@@ -41,19 +46,10 @@ const (
DecodeProfileName = "decode"
// PrefillEnabledStateKey tracks whether this request should use disaggregated routing.
// Initially set to true by DisaggProfileHandler.Pick() if a "prefill" scheduling
// profile exists in the EPP config. Overwritten to false per-request in two cases:
// - DisaggProfileHandler.Pick(): prefill profile result is nil (no prefill pods
// passed the label-filter).
// - DynPrefillScorer.Score(): prefill FFI routing failed (prefill router not yet
// activated, e.g., worker registered in K8s but not yet in Dynamo discovery).
// The decode scorer reads this to decide whether to use overlap_score_weight=0
// (disaggregated) or normal KV cache overlap scoring (aggregated).
PrefillEnabledStateKey = plugins.StateKey("disagg-prefill-enabled")
)
// PrefillEnabledState stores whether prefill is enabled for the current scheduling cycle.
// Written by DisaggProfileHandler, read by PrefillScorer and DecodeScorer.
type PrefillEnabledState struct {
Enabled bool
}
......@@ -64,7 +60,6 @@ func (s *PrefillEnabledState) Clone() plugins.StateData {
}
// readPrefillEnabled reads the PrefillEnabledState from CycleState.
// Returns false if the state is not found or not set.
func readPrefillEnabled(cycleState *schedtypes.CycleState) bool {
state, err := schedtypes.ReadCycleStateKey[*PrefillEnabledState](cycleState, PrefillEnabledStateKey)
if err == nil && state != nil {
......@@ -74,7 +69,7 @@ func readPrefillEnabled(cycleState *schedtypes.CycleState) bool {
}
// buildRequestJSON builds an OpenAI-compatible JSON string from a GAIE LLMRequest.
func buildRequestJSON(req *schedtypes.LLMRequest) (string, error) {
func buildRequestJSON(req *schedtypes.InferenceRequest) (string, error) {
requestBody, err := dynscorer.BuildOpenAIRequest(req)
if err != nil {
return "", fmt.Errorf("failed to build OpenAI request: %w", err)
......@@ -86,24 +81,81 @@ func buildRequestJSON(req *schedtypes.LLMRequest) (string, error) {
return string(data), nil
}
// serializePods converts pods to a JSON string for the FFI filter.
// Returns an empty string if serialization fails or pods is empty.
func serializePods(pods []schedtypes.Pod) string {
if len(pods) == 0 {
// serializeEndpoints converts endpoints to a JSON string for the FFI filter.
func serializeEndpoints(endpoints []schedtypes.Endpoint) string {
if len(endpoints) == 0 {
return ""
}
pj, err := dynscorer.SerializePodsToJSON(pods)
pj, err := dynscorer.SerializeEndpointsToJSON(endpoints)
if err != nil {
return ""
}
return pj
}
// uniformScores returns a score map with the same score for every pod.
func uniformScores(pods []schedtypes.Pod, score float64) map[schedtypes.Pod]float64 {
out := make(map[schedtypes.Pod]float64, len(pods))
for _, p := range pods {
out[p] = score
// uniformScores returns a score map with the same score for every endpoint.
func uniformScores(endpoints []schedtypes.Endpoint, score float64) map[schedtypes.Endpoint]float64 {
out := make(map[schedtypes.Endpoint]float64, len(endpoints))
for _, ep := range endpoints {
out[ep] = score
}
return out
}
// setTokenizedPrompt stores pre-computed token IDs on the LLMRequest and
// injects nvext.token_data into the PayloadMap so it is forwarded to the
// worker in the request body.
//
// The GAIE framework re-serializes the PayloadMap after scheduling/PreRequest
// plugins run (PR #2854), so this mutation is included in the forwarded body.
func setTokenizedPrompt(req *schedtypes.InferenceRequest, tokens []int64, logger logr.Logger) {
if req == nil || len(tokens) == 0 {
logger.V(logutil.DEFAULT).Info("[EPP-INJECT] No tokens to inject (empty token list)")
return
}
tokenIDs := make([]uint32, len(tokens))
for i, t := range tokens {
tokenIDs[i] = uint32(t)
}
req.TokenizedPrompt = &schedtypes.TokenizedPrompt{
TokenIDs: tokenIDs,
}
// Inject into the PayloadMap so the body includes nvext.token_data.
payloadInjected := false
if req.Body != nil {
if pm, ok := req.Body.Payload.(fwkrh.PayloadMap); ok {
nvext, _ := pm["nvext"].(map[string]any)
if nvext == nil {
nvext = map[string]any{}
}
nvext["token_data"] = tokenIDs
pm["nvext"] = nvext
payloadInjected = true
}
}
if payloadInjected {
logger.V(logutil.DEFAULT).Info("[EPP-INJECT] Injected pre-computed tokens into request body nvext.token_data",
"tokenCount", len(tokenIDs),
"requestId", req.RequestId)
} else {
logger.V(logutil.DEFAULT).Error(nil, "[EPP-INJECT] Failed to inject nvext.token_data: Payload is not a PayloadMap — sidecar will re-tokenize",
"tokenCount", len(tokenIDs),
"requestId", req.RequestId)
}
}
func getEnvBoolOrDefault(key string, def bool) bool {
if v := os.Getenv(key); v != "" {
switch strings.ToLower(v) {
case "true", "1", "yes", "on":
return true
case "false", "0", "no", "off":
return false
}
}
return def
}
......@@ -104,7 +104,7 @@ import (
"unsafe"
ctrl "sigs.k8s.io/controller-runtime"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
)
var logger = ctrl.Log.WithName("dynamo-kv-scorer")
......@@ -132,7 +132,6 @@ func loadDynamoConfig() {
ffiNamespace = getEnvOrDefault("DYN_NAMESPACE_PREFIX", getEnvOrDefault("DYN_NAMESPACE", "vllm-agg"))
ffiComponent = "backend" // This is not the same as DYN_COMPONENT=epp (in this case)
ffiEnforceDisagg = getEnvBoolOrDefault("DYN_ENFORCE_DISAGG", false)
// Note: model name and kv_cache_block_size are now auto-discovered from the model card
logger.Info("Dynamo KV Scorer config loaded",
"namespace", ffiNamespace,
"component", ffiComponent,
......@@ -170,7 +169,6 @@ func initFFI() error {
defer C.free(unsafe.Pointer(ns))
defer C.free(unsafe.Pointer(cm))
// Create router handles
routerHandlesMutex.Lock()
defer routerHandlesMutex.Unlock()
......@@ -204,8 +202,8 @@ func InitFFI() error {
return initFFI()
}
// podInfoJSON is the JSON-serializable representation of a backend.Pod (datalayer.PodInfo).
type podInfoJSON struct {
// endpointInfoJSON is the JSON-serializable representation of an endpoint's metadata.
type endpointInfoJSON struct {
Name string `json:"name"`
Namespace string `json:"namespace"`
PodName string `json:"podName"`
......@@ -215,7 +213,7 @@ type podInfoJSON struct {
Labels map[string]string `json:"labels"`
}
// metricsJSON is the JSON-serializable representation of backendmetrics.MetricsState (datalayer.Metrics).
// metricsJSON is the JSON-serializable representation of endpoint metrics.
type metricsJSON struct {
ActiveModels map[string]int `json:"activeModels"`
WaitingModels map[string]int `json:"waitingModels"`
......@@ -229,42 +227,42 @@ type metricsJSON struct {
UpdateTime time.Time `json:"updateTime"`
}
// podJSON is the JSON-serializable representation of a schedtypes.Pod passed across the FFI boundary.
type podJSON struct {
Pod *podInfoJSON `json:"pod"`
Metrics *metricsJSON `json:"metrics"`
// endpointJSON is the JSON-serializable representation of a schedtypes.Endpoint passed across the FFI boundary.
type endpointJSON struct {
Pod *endpointInfoJSON `json:"pod"`
Metrics *metricsJSON `json:"metrics"`
}
// SerializePodsToJSON converts a slice of schedtypes.Pod into a JSON string
// SerializeEndpointsToJSON converts a slice of schedtypes.Endpoint into a JSON string
// suitable for passing across the C FFI boundary to the Rust router.
func SerializePodsToJSON(pods []schedtypes.Pod) (string, error) {
out := make([]podJSON, 0, len(pods))
for _, p := range pods {
entry := podJSON{}
if podInfo := p.GetPod(); podInfo != nil {
entry.Pod = &podInfoJSON{
Name: podInfo.NamespacedName.Name,
Namespace: podInfo.NamespacedName.Namespace,
PodName: podInfo.PodName,
Address: podInfo.Address,
Port: podInfo.Port,
MetricsHost: podInfo.MetricsHost,
Labels: podInfo.Labels,
func SerializeEndpointsToJSON(endpoints []schedtypes.Endpoint) (string, error) {
out := make([]endpointJSON, 0, len(endpoints))
for _, ep := range endpoints {
entry := endpointJSON{}
if meta := ep.GetMetadata(); meta != nil {
entry.Pod = &endpointInfoJSON{
Name: meta.NamespacedName.Name,
Namespace: meta.NamespacedName.Namespace,
PodName: meta.PodName,
Address: meta.Address,
Port: meta.Port,
MetricsHost: meta.MetricsHost,
Labels: meta.Labels,
}
}
if m := p.GetMetrics(); m != nil {
if m := ep.GetMetrics(); m != nil {
entry.Metrics = &metricsJSON{
ActiveModels: m.ActiveModels,
WaitingModels: m.WaitingModels,
MaxActiveModels: m.MaxActiveModels,
RunningQueueSize: m.RunningQueueSize,
RunningQueueSize: m.RunningRequestsSize,
WaitingQueueSize: m.WaitingQueueSize,
KVCacheUsagePercent: m.KVCacheUsagePercent,
KvCacheMaxTokenCapacity: m.KvCacheMaxTokenCapacity,
CacheBlockSize: m.CacheBlockSize,
CacheNumGPUBlocks: m.CacheNumGPUBlocks,
CacheNumGPUBlocks: m.CacheNumBlocks,
UpdateTime: m.UpdateTime,
}
}
......@@ -274,15 +272,14 @@ func SerializePodsToJSON(pods []schedtypes.Pod) (string, error) {
data, err := json.Marshal(out)
if err != nil {
return "", fmt.Errorf("failed to serialize pods: %w", err)
return "", fmt.Errorf("failed to serialize endpoints: %w", err)
}
return string(data), nil
}
func BuildOpenAIRequest(req *schedtypes.LLMRequest) (map[string]any, error) {
func BuildOpenAIRequest(req *schedtypes.InferenceRequest) (map[string]any, error) {
requestBody := make(map[string]any)
// Preserve the original message structure for correct chat template application
if req == nil || req.Body == nil {
return nil, fmt.Errorf("missing request body")
}
......@@ -304,17 +301,14 @@ func BuildOpenAIRequest(req *schedtypes.LLMRequest) (map[string]any, error) {
return nil, fmt.Errorf("empty chat messages")
}
requestBody["messages"] = messages
} else if req.Body.Completions != nil && strings.TrimSpace(req.Body.Completions.Prompt) != "" {
// Legacy completions format - wrap as single user message
} else if req.Body.Completions != nil && !req.Body.Completions.Prompt.IsEmpty() {
requestBody["messages"] = []map[string]any{
{"role": "user", "content": req.Body.Completions.Prompt},
{"role": "user", "content": req.Body.Completions.Prompt.PlainText()},
}
} else {
return nil, fmt.Errorf("no messages or prompt provided")
}
// Model field is required by OpenAI spec but not used by the router's tokenizer
// (tokenizer is determined by the discovered model card)
if req != nil && strings.TrimSpace(req.TargetModel) != "" {
requestBody["model"] = req.TargetModel
} else {
......@@ -337,7 +331,6 @@ func CallAddRequest(requestID string, tokenData []int64, workerID uint64, dpRank
return fmt.Errorf("dynamo router handles not created")
}
// Convert token data from int64 to uint32
tokens := make([]uint32, len(tokenData))
for i, t := range tokenData {
tokens[i] = uint32(t)
......@@ -421,6 +414,20 @@ type RoutingResult struct {
TokenData []int64
}
// extractTokenData copies token IDs from a C result into Go memory.
func extractTokenData(result *C.CRoutingResult) []int64 {
count := int(result.token_count)
if count > 0 && result.token_ids != nil {
src := unsafe.Slice((*uint32)(unsafe.Pointer(result.token_ids)), count)
tokens := make([]int64, count)
for i := 0; i < count; i++ {
tokens[i] = int64(src[i])
}
return tokens
}
return nil
}
// CallRoutePrefillRequest routes a request to the best prefill worker.
// It tokenizes the request and queries only the prefill router.
func CallRoutePrefillRequest(requestJSON string, podsJSON string) (*RoutingResult, error) {
......@@ -450,22 +457,12 @@ func CallRoutePrefillRequest(requestJSON string, podsJSON string) (*RoutingResul
return nil, fmt.Errorf("route_prefill_request failed with code %d", rc)
}
// Copy token IDs into Go memory
count := int(result.token_count)
var tokens64 []int64
if count > 0 && result.token_ids != nil {
src := unsafe.Slice((*uint32)(unsafe.Pointer(result.token_ids)), count)
tokens64 = make([]int64, count)
for i := 0; i < count; i++ {
tokens64[i] = int64(src[i])
}
}
tokens := extractTokenData(&result)
workerID := uint64(result.prefill_worker_id)
dpRank := uint32(result.prefill_dp_rank)
C.free_routing_result(&result)
return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens64}, nil
return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens}, nil
}
// CallRouteDecodeRequest routes a request to the best decode worker.
......@@ -497,20 +494,10 @@ func CallRouteDecodeRequest(requestJSON string, podsJSON string, isDisaggregated
return nil, fmt.Errorf("route_decode_request failed with code %d", rc)
}
// Copy token IDs into Go memory
count := int(result.token_count)
var tokens64 []int64
if count > 0 && result.token_ids != nil {
src := unsafe.Slice((*uint32)(unsafe.Pointer(result.token_ids)), count)
tokens64 = make([]int64, count)
for i := 0; i < count; i++ {
tokens64[i] = int64(src[i])
}
}
tokens := extractTokenData(&result)
workerID := uint64(result.decode_worker_id)
dpRank := uint32(result.decode_dp_rank)
C.free_routing_result(&result)
return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens64}, nil
return &RoutingResult{WorkerID: workerID, DpRank: dpRank, TokenData: tokens}, nil
}
......@@ -21,9 +21,8 @@ import (
"encoding/json"
"fmt"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
plugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
)
const (
......@@ -31,16 +30,9 @@ const (
)
// compile-time type assertion
var _ framework.Filter = &LabelFilter{}
var _ schedtypes.Filter = &LabelFilter{}
// LabelFilterConfig holds the configuration for the LabelFilter plugin.
// Matches the deployment manifest schema:
//
// parameters:
// label: "nvidia.com/dynamo-sub-component-type"
// validValues:
// - "prefill"
// allowsNoLabel: false
type LabelFilterConfig struct {
Label string `json:"label"`
ValidValues []string `json:"validValues"`
......@@ -65,7 +57,6 @@ func LabelFilterFactory(name string, rawParameters json.RawMessage, _ plugins.Ha
}
func NewLabelFilter(label string, validValues []string, allowsNoLabel bool) *LabelFilter {
// Build a set for O(1) lookups
valuesSet := make(map[string]struct{}, len(validValues))
for _, v := range validValues {
valuesSet[v] = struct{}{}
......@@ -94,23 +85,22 @@ func (f *LabelFilter) WithName(name string) *LabelFilter {
return f
}
// Filter returns only the pods whose label matches one of the configured valid values.
// Pods without the label are kept only if allowsNoLabel is true.
func (f *LabelFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
filtered := make([]types.Pod, 0, len(pods))
for _, pod := range pods {
if pod == nil || pod.GetPod() == nil {
// Filter returns only the endpoints whose label matches one of the configured valid values.
func (f *LabelFilter) Filter(_ context.Context, _ *schedtypes.CycleState, _ *schedtypes.InferenceRequest, endpoints []schedtypes.Endpoint) []schedtypes.Endpoint {
filtered := make([]schedtypes.Endpoint, 0, len(endpoints))
for _, ep := range endpoints {
if ep == nil || ep.GetMetadata() == nil {
continue
}
labelValue, hasLabel := pod.GetPod().Labels[f.label]
labelValue, hasLabel := ep.GetMetadata().Labels[f.label]
if !hasLabel {
if f.allowsNoLabel {
filtered = append(filtered, pod)
filtered = append(filtered, ep)
}
continue
}
if _, ok := f.validValues[labelValue]; ok {
filtered = append(filtered, pod)
filtered = append(filtered, ep)
}
}
return filtered
......
......@@ -29,7 +29,7 @@ kubectl apply -f https://github.com/kubernetes-sigs/gateway-api/releases/downloa
# Install the Inference Extension CRDs
IGW_LATEST_RELEASE=v1.2.1
IGW_LATEST_RELEASE=v1.5.0-rc.2
kubectl apply -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/releases/download/${IGW_LATEST_RELEASE}/manifests.yaml
......
......@@ -58,22 +58,22 @@ spec:
{{- if .Values.epp.argsOverride }}
{{- toYaml .Values.epp.argsOverride | nindent 8 }}
{{- else }}
- -pool-name
- --pool-name
- "{{ .Values.model.shortName }}-pool"
- -pool-namespace
- --pool-namespace
- "{{ .Release.Namespace }}"
- -pool-group
- --pool-group
- "inference.networking.k8s.io"
- -v
- "4"
- --zap-encoder
- "json"
- -grpc-port
- --grpc-port
- "9002"
- -grpc-health-port
- --grpc-health-port
- "9003"
{{- if $useDynamo }}
- -config-file
- --config-file
- "{{ .Values.epp.configFile }}"
{{- end }}
{{- end }}
......
......@@ -114,14 +114,14 @@ func (e *EPPDefaults) GetBaseContainer(context ComponentContext) (corev1.Contain
container.Command = []string{}
container.Args = []string{
"-pool-name", poolName,
"-pool-namespace", poolNamespace,
"-pool-group", epp.InferencePoolGroup,
"--pool-name", poolName,
"--pool-namespace", poolNamespace,
"--pool-group", epp.InferencePoolGroup,
"-v", "4",
"--zap-encoder", "json",
"-grpc-port", fmt.Sprintf("%d", commonconsts.EPPGRPCPort),
"-grpc-health-port", "9003",
"-config-file", configFilePath,
"--grpc-port", fmt.Sprintf("%d", commonconsts.EPPGRPCPort),
"--grpc-health-port", "9003",
"--config-file", configFilePath,
}
// Mount EPP config
......
......@@ -190,7 +190,10 @@ export IMAGE_TAG=latest
# Build operator image
cd deploy/operator
docker build -t $DOCKER_SERVER/kubernetes-operator:$IMAGE_TAG .
docker build -t $DOCKER_SERVER/kubernetes-operator:$IMAGE_TAG \
--build-context snapshot=../snapshot \
--build-arg DOCKER_PROXY="" \
.
docker push $DOCKER_SERVER/kubernetes-operator:$IMAGE_TAG
cd -
......
......@@ -454,7 +454,7 @@ helm uninstall kgateway --namespace kgateway-system
kubectl delete namespace kgateway-system --ignore-not-found
# 4. Delete the Inference Extension CRDs
IGW_LATEST_RELEASE=v1.2.1
IGW_LATEST_RELEASE=v1.5.0-rc.2
kubectl delete -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/releases/download/${IGW_LATEST_RELEASE}/manifests.yaml --ignore-not-found
# 5. Delete the Gateway API CRDs
......@@ -464,7 +464,7 @@ kubectl delete -f https://github.com/kubernetes-sigs/gateway-api/releases/downlo
## Gateway API Inference Extension Integration
This section documents the updated plugin implementation for Gateway API Inference Extension **v1.2.1**.
This section documents the updated plugin implementation for Gateway API Inference Extension **v1.5.0-rc.2**.
### Router bookkeeping operations
......@@ -473,8 +473,9 @@ EPP performs Dynamo router book keeping operations so the FrontEnd's Router does
### Header Routing Hints
Since v1.2.1, the EPP uses a **header-only approach** for communicating routing decisions.
The plugins set HTTP headers that are forwarded to the backend workers.
Since v1.5.0-rc.1, the EPP uses **headers and body mutations** for communicating routing decisions.
The plugins set HTTP headers for worker targeting and inject pre-computed token IDs
into the request body (`nvext.token_data`) so the frontend sidecar can skip redundant tokenization.
#### Headers Set by Dynamo Plugins
......
......@@ -68,7 +68,7 @@ spec:
- name: DYN_STORE_KV
value: "mem"
args:
- "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 1 --data-parallel-size 1 --gpu-memory-utilization 0.90 --no-enable-prefix-caching --block-size 128"
- "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 1 --data-parallel-size 1 --gpu-memory-utilization 0.90 --enable-prefix-caching --block-size 128 --kv-events-config '{\"enable_kv_cache_events\":true}'"
command:
- /bin/sh
- -c
......
......@@ -70,7 +70,7 @@ spec:
sharedMemory:
size: 2Gi
frontendSidecar:
image: docker.io/lambda108/dynamo:post-rebase
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
args:
- -m
- dynamo.frontend
......@@ -78,30 +78,34 @@ spec:
- direct
envFromSecret: hf-token-secret
extraPodSpec:
tolerations:
- key: "nvidia.com/gpu"
operator: "Exists"
effect: "NoSchedule"
affinity:
podAffinity:
preferredDuringSchedulingIgnoredDuringExecution:
- weight: 100
podAffinityTerm:
labelSelector:
matchExpressions:
- key: nvidia.com/dynamo-component-type
operator: In
values:
- worker
topologyKey: kubernetes.io/hostname
requiredDuringSchedulingIgnoredDuringExecution:
- labelSelector:
matchExpressions:
- key: nvidia.com/dynamo-component-type
operator: In
values:
- worker
topologyKey: kubernetes.io/hostname
mainContainer:
env:
- name: SERVED_MODEL_NAME
value: "Qwen/Qwen3-0.6B"
- name: MODEL_PATH
value: "Qwen/Qwen3-0.6B"
- name: UCX_TLS
value: "tcp,self"
args:
- "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 1 --data-parallel-size 1 --disaggregation-mode prefill --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}' --gpu-memory-utilization 0.90 --enable-prefix-caching --block-size 16 --kv-events-config '{\"enable_kv_cache_events\":true}'"
command:
- /bin/sh
- -c
image: docker.io/lambda108/dynamo:post-rebase
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
imagePullPolicy: IfNotPresent
workingDir: /workspace/examples/backends/vllm
replicas: 1
......@@ -117,7 +121,7 @@ spec:
sharedMemory:
size: 2Gi
frontendSidecar:
image: docker.io/lambda108/dynamo:post-rebase
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
args:
- -m
- dynamo.frontend
......@@ -125,30 +129,34 @@ spec:
- direct
envFromSecret: hf-token-secret
extraPodSpec:
tolerations:
- key: "nvidia.com/gpu"
operator: "Exists"
effect: "NoSchedule"
affinity:
podAffinity:
preferredDuringSchedulingIgnoredDuringExecution:
- weight: 100
podAffinityTerm:
labelSelector:
matchExpressions:
- key: nvidia.com/dynamo-component-type
operator: In
values:
- worker
topologyKey: kubernetes.io/hostname
requiredDuringSchedulingIgnoredDuringExecution:
- labelSelector:
matchExpressions:
- key: nvidia.com/dynamo-component-type
operator: In
values:
- worker
topologyKey: kubernetes.io/hostname
mainContainer:
env:
- name: SERVED_MODEL_NAME
value: "Qwen/Qwen3-0.6B"
- name: MODEL_PATH
value: "Qwen/Qwen3-0.6B"
- name: UCX_TLS
value: "tcp,self"
args:
- "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 1 --data-parallel-size 1 --disaggregation-mode decode --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}' --gpu-memory-utilization 0.90 --block-size 16"
command:
- /bin/sh
- -c
image: docker.io/lambda108/dynamo:post-rebase
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
imagePullPolicy: IfNotPresent
workingDir: /workspace/examples/backends/vllm
replicas: 1
......
......@@ -1118,7 +1118,14 @@ unsafe fn preprocess_request(
}
};
Ok(encoding.token_ids().to_vec())
let token_ids = encoding.token_ids().to_vec();
tracing::info!(
token_count = token_ids.len(),
first_tokens = ?&token_ids[..std::cmp::min(5, token_ids.len())],
"[EPP-TOKENIZE] Tokenized prompt in C bindings (this is the ONLY tokenization)"
);
Ok(token_ids)
}
/// Parse pods JSON into an optional set of allowed worker IDs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment