Unverified commit c916cd42, authored by atchernych and committed by GitHub
Browse files

feat: Support epp's "pods" interface in Dynamo fixes [DEP-424] (#6302)


Signed-off-by: Anna Tchernych <atchernych@nvidia.com>
parent 5a4c96db
......@@ -964,8 +964,7 @@ spec:
Parameters are the set of parameters to be passed to the plugin's
factory function. The factory function is responsible
to parse the parameters.
format: byte
type: string
x-kubernetes-preserve-unknown-fields: true
type:
description: Type specifies the plugin type to be instantiated.
type: string
......
......@@ -1173,8 +1173,7 @@ spec:
Parameters are the set of parameters to be passed to the plugin's
factory function. The factory function is responsible
to parse the parameters.
format: byte
type: string
x-kubernetes-preserve-unknown-fields: true
type:
description: Type specifies the plugin type to be instantiated.
type: string
......
......@@ -41,16 +41,16 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
// Dynamo plugins
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
"github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/disagg"
labelfilter "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/label_filter"
)
func main() {
// Register Dynamo custom plugins:
// - kv-aware-scorer: Implements Scorer, PreRequest, and ResponseStreaming interfaces
// - Score: Calls Dynamo router to select workers based on KV cache, sets routing headers
// - PreRequest: Registers request with router bookkeeping after scheduling is finalized
// - ResponseComplete: Cleans up router bookkeeping when response completes
plugins.Register("kv-aware-scorer", dynscorer.KVAwareScorerFactory)
plugins.Register("label-filter", labelfilter.LabelFilterFactory)
plugins.Register(disagg.DisaggProfileHandlerType, disagg.DisaggProfileHandlerFactory)
plugins.Register(disagg.DynPrefillScorerType, disagg.DynPrefillScorerFactory)
plugins.Register(disagg.DynDecodeScorerType, disagg.DynDecodeScorerFactory)
// Run using standard GAIE runner (it registers built-in plugins automatically)
if err := runner.NewRunner().Run(ctrl.SetupSignalHandler()); err != nil {
......
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package disagg
import (
"context"
"encoding/json"
"fmt"
"sync"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
rc "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)
const (
// DynDecodeScorerType is the plugin type registered in the plugin registry.
// It is also the default instance name assigned by NewDynDecodeScorer.
DynDecodeScorerType = "dyn-decode-scorer"
// WorkerIDHeader is the request header that Score sets to the decimal ID of
// the decode worker selected by the Dynamo router.
WorkerIDHeader = "x-worker-instance-id"
// PrefillWorkerIDHeader names the header for the selected prefill worker.
// NOTE(review): nothing in this file writes this header; confirm where it is set.
PrefillWorkerIDHeader = "x-prefill-instance-id"
// RoutingModeHeader is the request header that Score sets to either
// "disaggregated" or "aggregated".
RoutingModeHeader = "x-dynamo-routing-mode"
// decodeStateKey is the key used to store routing state in PluginState.
// Score writes under this key and PreRequest reads it back.
decodeStateKey = "dynamo-decode-routing-state"
)
// Compile-time checks that *DynDecodeScorer satisfies every interface it is
// used through: the scheduling Scorer, the generic Plugin, and the three
// request-control lifecycle hooks.
var (
	_ framework.Scorer     = (*DynDecodeScorer)(nil)
	_ plugins.Plugin       = (*DynDecodeScorer)(nil)
	_ rc.PreRequest        = (*DynDecodeScorer)(nil)
	_ rc.ResponseStreaming = (*DynDecodeScorer)(nil)
	_ rc.ResponseComplete  = (*DynDecodeScorer)(nil)
)
// DecodeRoutingState holds routing information passed from Score() to PreRequest().
// Score stores one instance per request ID in the scorer's PluginState;
// PreRequest reads it back and then deletes the request's state.
type DecodeRoutingState struct {
// WorkerID is the decode worker ID chosen by the router, formatted as a
// decimal string. PreRequest parses it back into a uint64.
WorkerID string
// PrefillWorkerID is copied by Clone but is not populated by Score in this
// file. NOTE(review): confirm whether anything else sets it.
PrefillWorkerID string
// TokenData is the token data returned by the decode routing call; PreRequest
// forwards it to the router when registering the request.
TokenData []int64
}
// Clone implements plugins.StateData by returning a deep copy of the state.
//
// A nil receiver yields a nil StateData. TokenData is duplicated into a fresh
// backing array so the copy never aliases the original; a nil TokenData stays nil.
func (s *DecodeRoutingState) Clone() plugins.StateData {
	if s == nil {
		return nil
	}

	var tokens []int64
	if s.TokenData != nil {
		tokens = make([]int64, len(s.TokenData))
		copy(tokens, s.TokenData)
	}

	return &DecodeRoutingState{
		WorkerID:        s.WorkerID,
		PrefillWorkerID: s.PrefillWorkerID,
		TokenData:       tokens,
	}
}
// DynDecodeScorerConfig holds the configuration for the DynDecodeScorer plugin.
// It currently has no fields; the factory still unmarshals the raw plugin
// parameters into it, so malformed JSON is rejected at plugin creation.
type DynDecodeScorerConfig struct{}
// DynDecodeScorerFactory is the plugin-registry factory for DynDecodeScorer.
//
// It validates rawParameters (when non-nil) by decoding them into
// DynDecodeScorerConfig, initializes the shared Dynamo FFI, and returns a
// scorer named name whose plugin state is bound to handle.Context().
// An error is returned if the parameters do not parse or FFI init fails.
func DynDecodeScorerFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
	var parsed DynDecodeScorerConfig
	if rawParameters != nil {
		parseErr := json.Unmarshal(rawParameters, &parsed)
		if parseErr != nil {
			return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", DynDecodeScorerType, parseErr)
		}
	}

	// The FFI is shared with the other Dynamo plugins; InitFFI is expected to
	// be safe to call more than once.
	ffiErr := dynscorer.InitFFI()
	if ffiErr != nil {
		return nil, fmt.Errorf("Dynamo FFI init for decode scorer failed: %w", ffiErr)
	}

	scorer := NewDynDecodeScorer(handle.Context())
	return scorer.WithName(name), nil
}
// NewDynDecodeScorer builds a DynDecodeScorer whose type and default name are
// both DynDecodeScorerType and whose per-request plugin state is created
// from ctx. Use WithName to override the instance name.
func NewDynDecodeScorer(ctx context.Context) *DynDecodeScorer {
	scorer := new(DynDecodeScorer)
	scorer.typedName = plugins.TypedName{
		Type: DynDecodeScorerType,
		Name: DynDecodeScorerType,
	}
	scorer.pluginState = plugins.NewPluginState(ctx)
	return scorer
}
// DynDecodeScorer is a scorer plugin for the decode scheduling profile.
//
// When Score() is called, it:
//  1. Reads PrefillEnabledState from CycleState (written by DisaggProfileHandler).
//  2. Calls the Dynamo FFI decode router with is_disaggregated flag.
//  3. Sets the worker-ID and routing-mode headers on the request.
//  4. Stores routing state for PreRequest to register with router bookkeeping.
//
// It also implements PreRequest, ResponseStreaming, and ResponseComplete lifecycle hooks
// for router bookkeeping (add_request, mark_prefill_complete, free_request).
type DynDecodeScorer struct {
// typedName is the plugin type/name pair reported by TypedName.
typedName plugins.TypedName
// pluginState carries DecodeRoutingState from Score to PreRequest, keyed by request ID.
pluginState *plugins.PluginState
// firstTokenSeen records request IDs for which ResponseStreaming has already
// run once; entries are removed in ResponseComplete.
firstTokenSeen sync.Map
}
// TypedName returns the type and name tuple of this plugin instance.
// The type is DynDecodeScorerType; the name defaults to the same value
// unless it was changed through WithName.
func (s *DynDecodeScorer) TypedName() plugins.TypedName {
return s.typedName
}
// WithName sets the name of the scorer.
// It mutates the receiver in place and returns it so the call can be chained
// onto the constructor.
func (s *DynDecodeScorer) WithName(name string) *DynDecodeScorer {
s.typedName.Name = name
return s
}
// Score asks the Dynamo decode router to pick a worker and records that
// choice on the request; the returned scores themselves are always uniform.
//
// Steps:
//   - read the disaggregation flag from cycleState;
//   - serialize the request and candidate pods and call the FFI decode router;
//   - set WorkerIDHeader and RoutingModeHeader on req.Headers (creating the map
//     if needed); PrefillWorkerIDHeader is not set here;
//   - when req.RequestId is non-empty, stash a DecodeRoutingState for PreRequest.
//
// Every pod receives 1.0 on every path, including failures to build the
// request or to route: the router's selection travels in the headers, not in
// the score map. On those failure paths no headers or state are written.
func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.CycleState, req *schedtypes.LLMRequest, pods []schedtypes.Pod) map[schedtypes.Pod]float64 {
	logger := log.FromContext(ctx)
	disagg := readPrefillEnabled(cycleState)

	payload, buildErr := buildRequestJSON(req)
	if buildErr != nil {
		logger.V(logutil.DEFAULT).Error(buildErr, "DynDecodeScorer: failed to build request")
		return uniformScores(pods, 1.0)
	}

	candidates := serializePods(pods)
	logger.V(logutil.DEFAULT).Info("DynDecodeScorer: pods received for scoring",
		"podCount", len(pods),
		"podsJSON", candidates)

	routed, routeErr := dynscorer.CallRouteDecodeRequest(payload, candidates, disagg)
	if routeErr != nil {
		logger.V(logutil.DEFAULT).Error(routeErr, "DynDecodeScorer: FFI decode routing failed")
		return uniformScores(pods, 1.0)
	}

	decodeWorker := fmt.Sprintf("%d", routed.WorkerID)
	logger.V(logutil.DEFAULT).Info("DynDecodeScorer: decode worker selected",
		"decodeWorkerID", decodeWorker,
		"isDisaggregated", disagg,
		"tokenCount", len(routed.TokenData))

	// Publish the routing decision through request headers.
	mode := "aggregated"
	if disagg {
		mode = "disaggregated"
	}
	if req.Headers == nil {
		req.Headers = make(map[string]string)
	}
	req.Headers[WorkerIDHeader] = decodeWorker
	req.Headers[RoutingModeHeader] = mode

	// Without a request ID there is no key to hand the state to PreRequest under.
	if req.RequestId != "" {
		s.pluginState.Write(req.RequestId, plugins.StateKey(decodeStateKey), &DecodeRoutingState{
			WorkerID:  decodeWorker,
			TokenData: routed.TokenData,
		})
	}

	return uniformScores(pods, 1.0)
}
// PreRequest registers the scheduled request with the Dynamo router's
// bookkeeping, using the routing state that Score stored for this request ID.
//
// The stored plugin state for the request is deleted as soon as it has been
// read, whether or not the read succeeded. The hook returns without
// registering when the request or its ID is missing, when no routing state is
// found, when the stored worker ID does not parse as an unsigned integer, or
// when the router rejects the registration; each case is logged.
func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LLMRequest, _ *schedtypes.SchedulingResult) {
	logger := log.FromContext(ctx)

	if request == nil || request.RequestId == "" {
		logger.V(logutil.DEBUG).Info("DynDecodeScorer PreRequest: no request ID, skipping")
		return
	}
	reqID := request.RequestId

	routing, readErr := plugins.ReadPluginStateKey[*DecodeRoutingState](
		s.pluginState, reqID, plugins.StateKey(decodeStateKey),
	)
	// One-shot hand-off: drop the state now so it cannot outlive this hook.
	s.pluginState.Delete(reqID)
	if readErr != nil {
		logger.V(logutil.DEBUG).Info("DynDecodeScorer PreRequest: no routing state found",
			"requestID", reqID)
		return
	}

	var workerID uint64
	_, scanErr := fmt.Sscanf(routing.WorkerID, "%d", &workerID)
	if scanErr != nil {
		logger.V(logutil.DEFAULT).Error(scanErr, "DynDecodeScorer PreRequest: invalid worker ID",
			"requestID", reqID, "workerID", routing.WorkerID)
		return
	}

	registerErr := dynscorer.CallAddRequest(reqID, routing.TokenData, workerID, 0)
	if registerErr != nil {
		logger.V(logutil.DEFAULT).Error(registerErr, "DynDecodeScorer PreRequest: failed to add request",
			"requestID", reqID)
		return
	}

	logger.V(logutil.VERBOSE).Info("DynDecodeScorer PreRequest: registered request",
		"requestID", reqID,
		"workerID", routing.WorkerID,
		"tokenCount", len(routing.TokenData))
}
// ResponseStreaming runs for each chunk of a streaming response. The first
// time it sees a given request ID it tells the Dynamo router that prefill has
// completed; later chunks for the same ID are ignored.
//
// The request ID is recorded before the router call, so a failed call is
// logged but not retried on subsequent chunks.
func (s *DynDecodeScorer) ResponseStreaming(ctx context.Context, request *schedtypes.LLMRequest, _ *rc.Response, _ *backend.Pod) {
	if request == nil || request.RequestId == "" {
		return
	}
	reqID := request.RequestId

	// LoadOrStore reports whether the ID was already present; only the first
	// caller for an ID proceeds.
	if _, seen := s.firstTokenSeen.LoadOrStore(reqID, true); seen {
		return
	}

	logger := log.FromContext(ctx)
	markErr := dynscorer.CallMarkPrefillComplete(reqID)
	if markErr != nil {
		logger.V(logutil.DEFAULT).Error(markErr, "DynDecodeScorer ResponseStreaming: failed to mark prefill complete",
			"requestID", reqID)
		return
	}
	logger.V(logutil.VERBOSE).Info("DynDecodeScorer ResponseStreaming: marked prefill complete",
		"requestID", reqID)
}
// ResponseComplete runs once the full response has been sent. It forgets the
// request's first-token marker and asks the Dynamo router to free its
// bookkeeping for the request. Requests without an ID are ignored; a failure
// to free is logged and otherwise swallowed.
func (s *DynDecodeScorer) ResponseComplete(ctx context.Context, request *schedtypes.LLMRequest, _ *rc.Response, _ *backend.Pod) {
	logger := log.FromContext(ctx)
	if request == nil || request.RequestId == "" {
		return
	}
	reqID := request.RequestId

	s.firstTokenSeen.Delete(reqID)

	freeErr := dynscorer.CallFreeRequest(reqID)
	if freeErr != nil {
		logger.V(logutil.DEFAULT).Error(freeErr, "DynDecodeScorer ResponseComplete: failed to free request",
			"requestID", reqID)
		return
	}
	logger.V(logutil.VERBOSE).Info("DynDecodeScorer ResponseComplete: freed request",
		"requestID", reqID)
}
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package disagg
import (
"context"
"encoding/json"
"fmt"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)
const (
// DynPrefillScorerType is the plugin type registered in the plugin registry.
// It is also the default instance name assigned by NewDynPrefillScorer.
DynPrefillScorerType = "dyn-prefill-scorer"
)
// Compile-time check that *DynPrefillScorer satisfies framework.Scorer.
var _ framework.Scorer = (*DynPrefillScorer)(nil)
// DynPrefillScorerConfig holds the configuration for the DynPrefillScorer plugin.
// It currently has no fields; the factory still unmarshals the raw plugin
// parameters into it, so malformed JSON is rejected at plugin creation.
type DynPrefillScorerConfig struct{}
// DynPrefillScorerFactory is the plugin-registry factory for DynPrefillScorer.
//
// It validates rawParameters (when non-nil) by decoding them into
// DynPrefillScorerConfig, initializes the shared Dynamo FFI, and returns a
// scorer named name. The plugin handle is unused. An error is returned if the
// parameters do not parse or FFI init fails.
func DynPrefillScorerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	var parsed DynPrefillScorerConfig
	if rawParameters != nil {
		parseErr := json.Unmarshal(rawParameters, &parsed)
		if parseErr != nil {
			return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", DynPrefillScorerType, parseErr)
		}
	}

	// The FFI is shared with the other Dynamo plugins; InitFFI is expected to
	// be safe to call more than once.
	ffiErr := dynscorer.InitFFI()
	if ffiErr != nil {
		return nil, fmt.Errorf("Dynamo FFI init for prefill scorer failed: %w", ffiErr)
	}

	scorer := NewDynPrefillScorer()
	return scorer.WithName(name), nil
}
// NewDynPrefillScorer builds a DynPrefillScorer whose type and default name
// are both DynPrefillScorerType. Use WithName to override the instance name.
func NewDynPrefillScorer() *DynPrefillScorer {
	scorer := new(DynPrefillScorer)
	scorer.typedName = plugins.TypedName{
		Type: DynPrefillScorerType,
		Name: DynPrefillScorerType,
	}
	return scorer
}
// DynPrefillScorer is a scorer plugin for the prefill scheduling profile.
//
// When Score() is called, it:
//  1. Reads PrefillEnabledState from CycleState (written by DisaggProfileHandler).
//  2. If prefill is NOT enabled, returns zero scores.
//  3. If prefill IS enabled, calls the Dynamo FFI prefill router to select the best prefill worker.
//  4. Assigns score 1.0 to all pods on success, or 0 to all pods if building
//     the request or the router call fails.
//
// NOTE(review): the selected prefill worker ID is only logged by Score in this
// file; it is not written to request headers or plugin state here. Confirm
// where the selection is propagated.
type DynPrefillScorer struct {
// typedName is the plugin type/name pair reported by TypedName.
typedName plugins.TypedName
}
// TypedName returns the type and name tuple of this plugin instance.
// The type is DynPrefillScorerType; the name defaults to the same value
// unless it was changed through WithName.
func (s *DynPrefillScorer) TypedName() plugins.TypedName {
return s.typedName
}
// WithName sets the name of the scorer.
// It mutates the receiver in place and returns it so the call can be chained
// onto the constructor.
func (s *DynPrefillScorer) WithName(name string) *DynPrefillScorer {
s.typedName.Name = name
return s
}
// Score asks the Dynamo prefill router to select a prefill worker and returns
// a uniform score for every candidate pod.
//
// All pods score 0 when prefill is disabled for this cycle, when the request
// cannot be serialized, or when the router call fails. On success all pods
// score 1.0 and the selected worker ID is logged; this method does not modify
// the request.
func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.CycleState, req *schedtypes.LLMRequest, pods []schedtypes.Pod) map[schedtypes.Pod]float64 {
	logger := log.FromContext(ctx)

	if enabled := readPrefillEnabled(cycleState); !enabled {
		logger.V(logutil.VERBOSE).Info("DynPrefillScorer: prefill not enabled, returning zero scores")
		return uniformScores(pods, 0)
	}

	payload, buildErr := buildRequestJSON(req)
	if buildErr != nil {
		logger.V(logutil.DEFAULT).Error(buildErr, "DynPrefillScorer: failed to build request")
		return uniformScores(pods, 0)
	}

	candidates := serializePods(pods)
	logger.V(logutil.DEFAULT).Info("DynPrefillScorer: pods received for scoring",
		"podCount", len(pods),
		"podsJSON", candidates)

	routed, routeErr := dynscorer.CallRoutePrefillRequest(payload, candidates)
	if routeErr != nil {
		logger.V(logutil.DEFAULT).Error(routeErr, "DynPrefillScorer: FFI prefill routing failed")
		return uniformScores(pods, 0)
	}

	prefillWorker := fmt.Sprintf("%d", routed.WorkerID)
	logger.V(logutil.DEFAULT).Info("DynPrefillScorer: prefill worker selected",
		"prefillWorkerID", prefillWorker,
		"tokenCount", len(routed.TokenData))

	// The router's own selection is authoritative, so pods are not ranked
	// against each other here.
	return uniformScores(pods, 1.0)
}
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package disagg
import (
"context"
"encoding/json"
"errors"
"fmt"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)
const (
// DisaggProfileHandlerType is the plugin type of DisaggProfileHandler.
// It is also the default instance name assigned by NewDisaggProfileHandler.
DisaggProfileHandlerType = "disagg-profile-handler"
)
// Compile-time check that *DisaggProfileHandler satisfies framework.ProfileHandler.
var _ framework.ProfileHandler = (*DisaggProfileHandler)(nil)
// DisaggProfileHandlerConfig holds the configuration for the DisaggProfileHandler.
// It currently has no fields; the factory still unmarshals the raw plugin
// parameters into it, so malformed JSON is rejected at plugin creation.
type DisaggProfileHandlerConfig struct{}
// DisaggProfileHandlerFactory is the plugin-registry factory for
// DisaggProfileHandler.
//
// It validates rawParameters (when non-nil) by decoding them into
// DisaggProfileHandlerConfig and returns a handler named name. The plugin
// handle is unused. An error is returned only if the parameters do not parse.
func DisaggProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	if rawParameters != nil {
		var parsed DisaggProfileHandlerConfig
		parseErr := json.Unmarshal(rawParameters, &parsed)
		if parseErr != nil {
			return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", DisaggProfileHandlerType, parseErr)
		}
	}

	handler := NewDisaggProfileHandler()
	return handler.WithName(name), nil
}
// NewDisaggProfileHandler builds a DisaggProfileHandler whose type and default
// name are both DisaggProfileHandlerType. Use WithName to override the name.
func NewDisaggProfileHandler() *DisaggProfileHandler {
	handler := new(DisaggProfileHandler)
	handler.typedName = plugins.TypedName{
		Type: DisaggProfileHandlerType,
		Name: DisaggProfileHandlerType,
	}
	return handler
}
// DisaggProfileHandler is a ProfileHandler that orchestrates prefill/decode disaggregated serving.
//
// # Disaggregated mode detection
//
// In Dynamo's native architecture, disaggregated mode is determined by whether prefill workers
// actually exist at runtime (the is_disaggregated flag in the Rust KV router). However, the
// GAIE EPP framework determines profile availability at configuration time, not at runtime.
// To bridge this gap, DisaggProfileHandler uses the EPP profile mechanism as a proxy:
// it checks whether a "prefill" scheduling profile is registered in the config. If prefill
// workers are configured but none are actually running, the prefill profile's label-filter
// will find zero pods, causing the profile to fail — and the handler gracefully degrades
// to aggregated mode (see below).
//
// # Scheduling flow
//
// On each scheduling cycle it:
//  1. Checks whether a "prefill" profile is registered in the config.
//  2. Writes PrefillEnabledState into CycleState so scorer plugins can read it.
//  3. If a prefill profile exists: runs the "prefill" profile first, then the "decode" profile.
//     The "decode" profile is the primary (the pod the request is ultimately sent to).
//  4. If no prefill profile exists: runs only the "decode" profile (pure aggregated mode).
//     If neither profile name is configured, Pick returns every configured profile.
//
// # Graceful degradation
//
// When a prefill profile is configured but no prefill workers are available at runtime,
// the handler degrades gracefully to aggregated mode on a per-request basis:
//
//  1. Pick (iteration 1): prefill profile exists → writes PrefillEnabled=true → runs prefill profile.
//  2. Prefill profile runs: label-filter finds 0 prefill pods → profile fails → result is nil.
//  3. Pick (iteration 2): sees prefill result is nil → overwrites PrefillEnabled=false → runs decode profile.
//  4. Decode scorer runs: reads PrefillEnabled=false → passes isDisaggregated=false to the Rust
//     decode router → full KV cache overlap scoring is used (overlap_score_weight=1.0).
//
// This means the same YAML config works transparently for both aggregated and disaggregated
// deployments. If prefill workers come up later, subsequent requests automatically use
// disaggregated routing. If they go down, requests fall back to aggregated mode.
//
// NOTE(review): the statements about the Rust router and overlap_score_weight describe
// behavior outside this package; verify them against the router implementation.
type DisaggProfileHandler struct {
// typedName is the plugin type/name pair reported by TypedName.
typedName plugins.TypedName
}
// TypedName returns the type and name tuple of this plugin instance.
// The type is DisaggProfileHandlerType; the name defaults to the same value
// unless it was changed through WithName.
func (h *DisaggProfileHandler) TypedName() plugins.TypedName {
return h.typedName
}
// WithName sets the name of the profile handler.
// It mutates the receiver in place and returns it so the call can be chained
// onto the constructor.
func (h *DisaggProfileHandler) WithName(name string) *DisaggProfileHandler {
h.typedName.Name = name
return h
}
// Pick decides which scheduler profiles run in the current iteration, based
// on which results have been collected so far.
//
// No results yet (first call):
//   - records in cycleState whether a "prefill" profile is configured;
//   - returns the "prefill" profile alone if it exists;
//   - otherwise returns the "decode" profile alone if it exists;
//   - otherwise returns all configured profiles unchanged.
//
// Prefill result present, decode result absent:
//   - if the prefill result is nil, rewrites the cycle state to
//     PrefillEnabled=false so the decode scorer routes in aggregated mode;
//   - returns the "decode" profile if it is configured.
//
// Any other situation: returns an empty map, which ends the scheduling loop.
func (h *DisaggProfileHandler) Pick(ctx context.Context, cycleState *schedtypes.CycleState, _ *schedtypes.LLMRequest,
	profiles map[string]*framework.SchedulerProfile, profileResults map[string]*schedtypes.ProfileRunResult) map[string]*framework.SchedulerProfile {
	trace := log.FromContext(ctx).V(logutil.VERBOSE)

	if len(profileResults) == 0 {
		prefillProfile, havePrefill := profiles[PrefillProfileName]
		cycleState.Write(PrefillEnabledStateKey, &PrefillEnabledState{Enabled: havePrefill})
		trace.Info("DisaggProfileHandler: prefill enabled state determined", "prefillEnabled", havePrefill)

		decodeProfile, haveDecode := profiles[DecodeProfileName]
		switch {
		case havePrefill:
			// Prefill runs on its own first; decode follows in the next iteration.
			return map[string]*framework.SchedulerProfile{PrefillProfileName: prefillProfile}
		case haveDecode:
			return map[string]*framework.SchedulerProfile{DecodeProfileName: decodeProfile}
		default:
			// Neither well-known profile is configured: run whatever is there.
			return profiles
		}
	}

	prefillResult, prefillRan := profileResults[PrefillProfileName]
	_, decodeRan := profileResults[DecodeProfileName]
	if !prefillRan || decodeRan {
		// Nothing left to schedule.
		return map[string]*framework.SchedulerProfile{}
	}

	// A nil prefill result means the prefill profile failed; downgrade this
	// cycle to aggregated mode before the decode scorer reads the flag.
	if prefillResult == nil {
		trace.Info("DisaggProfileHandler: prefill profile failed (no workers?), falling back to aggregated decode")
		cycleState.Write(PrefillEnabledStateKey, &PrefillEnabledState{Enabled: false})
	}

	decodeProfile, haveDecode := profiles[DecodeProfileName]
	if !haveDecode {
		return map[string]*framework.SchedulerProfile{}
	}
	return map[string]*framework.SchedulerProfile{DecodeProfileName: decodeProfile}
}
// ProcessResults aggregates the profile run results and designates the
// primary profile, i.e. the profile whose pod ultimately handles the request.
//
// The "decode" profile is the primary whenever it has an entry in
// profileResults. If it has none (for example when only non-standard profiles
// are configured), a fallback is chosen deterministically: profiles that
// produced a non-nil result are preferred, and ties are broken by the
// lexicographically smallest profile name. Relying on map iteration order
// here would make the primary vary between otherwise identical requests and
// could select a failed profile while a successful one exists.
//
// It returns an error when profileResults is empty or when the chosen primary
// profile has a nil result.
func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.CycleState, _ *schedtypes.LLMRequest,
	profileResults map[string]*schedtypes.ProfileRunResult) (*schedtypes.SchedulingResult, error) {
	if len(profileResults) == 0 {
		return nil, errors.New("disagg profile handler received no profile results")
	}

	// Determine primary profile name.
	primaryProfile := DecodeProfileName
	if _, ok := profileResults[DecodeProfileName]; !ok {
		picked, pickedSucceeded := false, false
		for name, result := range profileResults {
			succeeded := result != nil
			switch {
			case !picked,
				succeeded && !pickedSucceeded,
				succeeded == pickedSucceeded && name < primaryProfile:
				primaryProfile, pickedSucceeded, picked = name, succeeded, true
			}
		}
	}

	if profileResults[primaryProfile] == nil {
		return nil, fmt.Errorf("primary profile '%s' failed to produce a result", primaryProfile)
	}

	return &schedtypes.SchedulingResult{
		ProfileResults:     profileResults,
		PrimaryProfileName: primaryProfile,
	}, nil
}
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package disagg implements disaggregated prefill/decode serving plugins for Dynamo EPP.
//
// The disaggregated architecture splits inference into two phases:
// - Prefill: processes the input prompt (compute-heavy, parallelizable)
// - Decode: generates tokens autoregressively (memory-bound, sequential)
//
// This package provides three plugins:
// - DisaggProfileHandler: orchestrates prefill→decode profile execution
// - DynPrefillScorer: selects prefill workers via Dynamo FFI
// - DynDecodeScorer: selects decode workers via Dynamo FFI
package disagg
import (
"encoding/json"
"fmt"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)
const (
// PrefillProfileName is the scheduling-profile name that DisaggProfileHandler
// looks up to decide whether prefill is enabled.
PrefillProfileName = "prefill"
// DecodeProfileName is the scheduling-profile name that DisaggProfileHandler
// treats as the primary profile.
DecodeProfileName = "decode"
// PrefillEnabledStateKey is used to communicate prefill-enabled status
// from the DisaggProfileHandler to the scorer plugins via CycleState.
PrefillEnabledStateKey = plugins.StateKey("disagg-prefill-enabled")
)
// PrefillEnabledState stores whether prefill is enabled for the current scheduling cycle.
// Written by DisaggProfileHandler, read by PrefillScorer and DecodeScorer
// through readPrefillEnabled.
type PrefillEnabledState struct {
// Enabled is true when the cycle should be treated as disaggregated.
Enabled bool
}
// Clone implements plugins.StateData by returning an independent copy of the
// state.
//
// A nil receiver yields a nil StateData instead of panicking on the field
// access, matching the nil handling of DecodeRoutingState.Clone. The literal
// nil is returned (rather than a typed nil pointer) so callers comparing the
// result against nil see a nil interface.
func (s *PrefillEnabledState) Clone() plugins.StateData {
	if s == nil {
		return nil
	}
	return &PrefillEnabledState{Enabled: s.Enabled}
}
// readPrefillEnabled reports the Enabled flag of the PrefillEnabledState
// stored in cycleState under PrefillEnabledStateKey. It returns false when
// the key cannot be read or the stored state is nil.
func readPrefillEnabled(cycleState *schedtypes.CycleState) bool {
	stored, readErr := schedtypes.ReadCycleStateKey[*PrefillEnabledState](cycleState, PrefillEnabledStateKey)
	if readErr != nil || stored == nil {
		return false
	}
	return stored.Enabled
}
// buildRequestJSON converts a GAIE LLMRequest into the JSON string handed to
// the Dynamo router. The request body is produced by
// dynscorer.BuildOpenAIRequest and then marshalled; a failure in either step
// is wrapped and returned with an empty string.
func buildRequestJSON(req *schedtypes.LLMRequest) (string, error) {
	body, buildErr := dynscorer.BuildOpenAIRequest(req)
	if buildErr != nil {
		return "", fmt.Errorf("failed to build OpenAI request: %w", buildErr)
	}

	encoded, marshalErr := json.Marshal(body)
	if marshalErr != nil {
		return "", fmt.Errorf("failed to marshal request JSON: %w", marshalErr)
	}

	return string(encoded), nil
}
// serializePods renders the candidate pods as a JSON string for the FFI
// router calls. This is deliberately best-effort: an empty pod list or a
// serialization error both produce the empty string rather than an error.
func serializePods(pods []schedtypes.Pod) string {
	if len(pods) > 0 {
		if encoded, err := dynscorer.SerializePodsToJSON(pods); err == nil {
			return encoded
		}
	}
	return ""
}
// uniformScores builds a score map that assigns score to each pod in pods.
// The result is always a non-nil map, empty when pods is empty.
func uniformScores(pods []schedtypes.Pod, score float64) map[schedtypes.Pod]float64 {
	scores := make(map[schedtypes.Pod]float64, len(pods))
	for i := range pods {
		scores[pods[i]] = score
	}
	return scores
}
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package label_filter
import (
"context"
"encoding/json"
"fmt"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)
const (
// LabelFilterType is the plugin type of LabelFilter.
// It is also the default instance name assigned by NewLabelFilter.
LabelFilterType = "label-filter"
)
// Compile-time check that *LabelFilter satisfies framework.Filter.
var _ framework.Filter = (*LabelFilter)(nil)
// LabelFilterConfig holds the configuration for the LabelFilter plugin.
// Matches the deployment manifest schema:
//
//	parameters:
//	  label: "nvidia.com/dynamo-sub-component-type"
//	  validValues:
//	    - "prefill"
//	  allowsNoLabel: false
type LabelFilterConfig struct {
// Label is the pod label key to inspect. The factory rejects an empty value.
Label string `json:"label"`
// ValidValues lists the label values that let a pod through. The factory
// requires at least one entry.
ValidValues []string `json:"validValues"`
// AllowsNoLabel controls whether pods that lack the label entirely are kept.
AllowsNoLabel bool `json:"allowsNoLabel"`
}
// LabelFilterFactory is the plugin-registry factory for LabelFilter.
//
// Unlike the other Dynamo plugins, parameters are mandatory: rawParameters
// must be non-nil JSON that decodes into LabelFilterConfig with a non-empty
// label and at least one valid value. Any violation is returned as an error.
// The plugin handle is unused.
func LabelFilterFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	if rawParameters == nil {
		return nil, fmt.Errorf("%s plugin requires parameters with 'label' and 'validValues' fields", LabelFilterType)
	}

	var parsed LabelFilterConfig
	if parseErr := json.Unmarshal(rawParameters, &parsed); parseErr != nil {
		return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", LabelFilterType, parseErr)
	}

	switch {
	case parsed.Label == "":
		return nil, fmt.Errorf("%s plugin parameter 'label' must not be empty", LabelFilterType)
	case len(parsed.ValidValues) == 0:
		return nil, fmt.Errorf("%s plugin parameter 'validValues' must contain at least one value", LabelFilterType)
	}

	filter := NewLabelFilter(parsed.Label, parsed.ValidValues, parsed.AllowsNoLabel)
	return filter.WithName(name), nil
}
// NewLabelFilter builds a LabelFilter that keeps pods whose label value is
// one of validValues. Pods without the label are kept only when allowsNoLabel
// is true. The type and default name are both LabelFilterType.
//
// validValues is copied into a set, so later changes to the caller's slice do
// not affect the filter and duplicate values are harmless.
func NewLabelFilter(label string, validValues []string, allowsNoLabel bool) *LabelFilter {
	allowed := make(map[string]struct{}, len(validValues))
	for i := range validValues {
		allowed[validValues[i]] = struct{}{}
	}

	filter := &LabelFilter{
		label:         label,
		validValues:   allowed,
		allowsNoLabel: allowsNoLabel,
	}
	filter.typedName = plugins.TypedName{Type: LabelFilterType, Name: LabelFilterType}
	return filter
}
// LabelFilter is a scheduling filter plugin that keeps or drops candidate
// pods based on the value of a single pod label.
type LabelFilter struct {
// typedName is the plugin type/name pair reported by TypedName.
typedName plugins.TypedName
// label is the pod label key that Filter inspects.
label string
// validValues is the set of label values that let a pod through.
validValues map[string]struct{}
// allowsNoLabel decides whether pods without the label are kept.
allowsNoLabel bool
}
// TypedName returns the type and name tuple of this plugin instance.
// The type is LabelFilterType; the name defaults to the same value unless it
// was changed through WithName.
func (f *LabelFilter) TypedName() plugins.TypedName {
return f.typedName
}
// WithName sets the name of the filter.
// It mutates the receiver in place and returns it so the call can be chained
// onto the constructor.
func (f *LabelFilter) WithName(name string) *LabelFilter {
f.typedName.Name = name
return f
}
// Filter returns, in their original order, the pods that pass the label rule:
//   - a pod carrying the configured label is kept when its value is one of
//     the valid values;
//   - a pod without the label is kept only when allowsNoLabel is true;
//   - nil pods, and pods whose GetPod() is nil, are always dropped.
//
// The result is a new, non-nil slice; the input slice is not modified.
func (f *LabelFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
	kept := make([]types.Pod, 0, len(pods))
	for _, candidate := range pods {
		if candidate == nil {
			continue
		}
		meta := candidate.GetPod()
		if meta == nil {
			continue
		}

		keep := f.allowsNoLabel
		if value, present := meta.Labels[f.label]; present {
			_, keep = f.validValues[value]
		}
		if keep {
			kept = append(kept, candidate)
		}
	}
	return kept
}
......@@ -17,25 +17,35 @@
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
# This config uses the same disagg-profile-handler as disaggregated deployments.
# With no "prefill" profile defined, the handler runs only the "decode" profile
# and the decode scorer receives isDisaggregated=false, enabling full KV cache
# overlap scoring for aggregated mode. Adding prefill workers and a prefill
# profile would automatically switch to disaggregated routing.
plugins:
# Required: tells EPP which profile to use (even if you only have one)
- type: single-profile-handler
- type: disagg-profile-handler
# Label filter: restricts pods to decode workers
- name: decode-filter
type: label-filter
parameters:
label: "nvidia.com/dynamo-sub-component-type"
validValues:
- "decode"
allowsNoLabel: true
# Picker: chooses the final endpoint after scoring
- name: picker
type: max-score-picker
# Dynamo KV-aware Scorer: calls Dynamo router FFI for worker selection
# Implements Scorer, PreRequest, and ResponseComplete:
# - Score: Selects workers based on KV cache, sets routing headers
# - PreRequest: Registers request with router bookkeeping
# - ResponseComplete: Frees router bookkeeping when response completes
- name: dyn-kv
type: kv-aware-scorer
# Dynamo decode scorer: calls Dynamo FFI decode router, handles request lifecycle
- name: dyn-decode
type: dyn-decode-scorer
schedulingProfiles:
- name: default
- name: decode
plugins:
- pluginRef: dyn-kv
- pluginRef: decode-filter
- pluginRef: dyn-decode
weight: 1
- pluginRef: picker
......@@ -83,6 +83,10 @@ manifests: controller-gen ensure-yq ## Generate WebhookConfiguration, ClusterRol
for file in config/crd/bases/*.yaml; do \
yq eval '(.. | select(has("extraPodSpec")) | .extraPodSpec.required) |= (. - ["containers"])' -i --indent 2 $$file || exit 1; \
done
echo "Fixing PluginSpec parameters field: json.RawMessage needs x-kubernetes-preserve-unknown-fields instead of type: string"
for file in config/crd/bases/*.yaml; do \
yq eval '(.. | select(has("parameters")) | .parameters | select(has("format") and .format == "byte")) |= (del(.format) | del(.type) | .x-kubernetes-preserve-unknown-fields = true)' -i --indent 2 $$file || exit 1; \
done
echo "Adding NVIDIA header to CRD files"
for file in config/crd/bases/*.yaml; do \
if ! head -20 "$$file" | grep -q "NVIDIA CORPORATION"; then \
......
......@@ -964,8 +964,7 @@ spec:
Parameters are the set of parameters to be passed to the plugin's
factory function. The factory function is responsible
to parse the parameters.
format: byte
type: string
x-kubernetes-preserve-unknown-fields: true
type:
description: Type specifies the plugin type to be instantiated.
type: string
......
......@@ -1173,8 +1173,7 @@ spec:
Parameters are the set of parameters to be passed to the plugin's
factory function. The factory function is responsible
to parse the parameters.
format: byte
type: string
x-kubernetes-preserve-unknown-fields: true
type:
description: Type specifies the plugin type to be instantiated.
type: string
......
......@@ -58,7 +58,7 @@ func GenerateInferencePool(
},
Selector: gaiev1.LabelSelector{
MatchLabels: map[gaiev1.LabelKey]gaiev1.LabelValue{
consts.KubeLabelDynamoComponentType: consts.ComponentTypeFrontend,
consts.KubeLabelDynamoComponentType: consts.ComponentTypeWorker,
consts.KubeLabelDynamoNamespace: gaiev1.LabelValue(dynamoNamespace),
},
},
......
......@@ -6,11 +6,14 @@ title: Inference Gateway (GAIE)
## Inference Gateway Setup with Dynamo
When integrating Dynamo with the Inference Gateway you must use the custom Dynamo EPP image.
# Inference Gateway (GAIE)
The custom Dynamo EPP image integrates the Dynamo router directly into the gateway's endpoint picker. Using the `dyn-kv` plugin, it selects the optimal worker based on KV cache state and tokenized prompt before routing the request. The integration moves intelligent routing upstream to the gateway layer.
Integrate Dynamo with the Gateway API Inference Extension for intelligent KV-aware request routing at the gateway layer.
EPP's default kv-routing approach is not token-aware because the prompt is not tokenized. But the Dynamo plugin uses a token-aware KV algorithm. It employs the dynamo router which implements kv routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](../../../deploy/inference-gateway/helm/dynamo-gaie/epp-config-dynamo.yaml) per EPP [convention](https://gateway-api-inference-extension.sigs.k8s.io/guides/epp-configuration/config-text/).
EPP's default kv-routing approach is not token-aware because the prompt is not tokenized. But the Dynamo plugin uses a token-aware KV algorithm. It employs the dynamo router which implements kv routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](helm/dynamo-gaie/epp-config-dynamo.yaml) per EPP [convention](https://gateway-api-inference-extension.sigs.k8s.io/guides/epp-configuration/config-text/).
Dynamo integration with the Inference Gateway supports both aggregated and disaggregated serving. The EPP config is the same for both. If no prefill workers are found, the service degrades gracefully to aggregated serving.
If you want to use LoRA, deploy Dynamo without the Inference Gateway.
Currently, these setups are only supported with the kGateway based Inference Gateway.
......@@ -86,7 +89,7 @@ kubectl create secret generic hf-token-secret \
```
Create a model configuration file similar to the vllm_agg_qwen.yaml for your model.
This file demonstrates the values needed for the Vllm Agg setup in [agg.yaml](../../../examples/backends/vllm/deploy/agg.yaml)
This file demonstrates the values needed for the Vllm Agg setup in [agg.yaml](../../examples/backends/vllm/deploy/agg.yaml)
Take note of the model's block size provided in the model card.
### 4. Build EPP image (Optional)
......@@ -124,16 +127,34 @@ you could deploy it as a standalone pod
#### 5.a. Deploy as a DGD component (recommended)
We provide an example for llama-3-70b vLLM below.
We provide an example for the Qwen vLLM below.
```bash
cd <dynamo-source-root>
kubectl apply -f examples/backends/vllm/deploy/gaie/agg.yaml -n my-model
kubectl apply -f examples/backends/vllm/deploy/gaie/http-route.yaml -n my-model
```
Examples for other models can be found in the recipes folder.
```bash
# Deploy PVC, having first Update `storageClassName` in recipes/llama-3-70b/model-cache/model-cache.yaml to match your cluster before deploying
kubectl apply -f recipes/llama-3-70b/model-cache/model-cache.yaml -n ${NAMESPACE}
kubectl apply -f recipes/llama-3-70b/model-cache/model-download.yaml -n ${NAMESPACE}
```
We provide examples for llama-3-70b vLLM under `recipes/llama-3-70b/vllm/agg/gaie/` for aggregated serving and `recipes/llama-3-70b/vllm/disagg-single-node/gaie/` for disaggregated serving.
Use the proper folder in commands below.
```bash
# Deploy the PVC. Before deploying, update `storageClassName` in recipes/llama-3-70b/model-cache/model-cache.yaml to match your cluster.
kubectl apply -f recipes/llama-3-70b/model-cache/model-cache.yaml
kubectl apply -f recipes/llama-3-70b/model-cache/model-download.yaml
# Deploy your model
# Deploy your Dynamo Graph.
# agg
kubectl apply -f recipes/llama-3-70b/vllm/agg/gaie/deploy.yaml -n ${NAMESPACE}
# Deploy the GAIE http-route CR.
kubectl apply -f recipes/llama-3-70b/vllm/agg/gaie/http-route.yaml -n ${NAMESPACE}
# or disagg
kubectl apply -f recipes/llama-3-70b/vllm/disagg-single-node/gaie/deploy.yaml -n ${NAMESPACE}
kubectl apply -f recipes/llama-3-70b/vllm/disagg-single-node/gaie/http-route.yaml -n ${NAMESPACE}
```
- When using GAIE the FrontEnd does not choose the workers. The routing is determined in the EPP.
......@@ -168,29 +189,9 @@ If you installed it into a different namespace, you need to adjust the HttpRoute
#### 5.b. Deploy as a standalone pod
##### 5.b.1 Deploy Your Model ###
We do not recommend this method but there are hints on how to do this here.
We provide an example for Qwen vLLM below.
Before deploying you must enable the `--direct-route` flag in the FrontEnd cli in your Dynamo Graph.
```bash
command:
- python3
args:
- -m
- dynamo.frontend
- --router-mode
- direct
```
Follow the steps in [model deployment](../../../examples/backends/vllm/deploy/README.md) to deploy `Qwen/Qwen3-0.6B` model in aggregate mode using [agg.yaml](../../../examples/backends/vllm/deploy/agg.yaml) in `my-model` kubernetes namespace.
Sample commands to deploy model:
```bash
cd <dynamo-source-root>
cd examples/backends/vllm/deploy
kubectl apply -f agg.yaml -n my-model
```
##### 5.b.1 Deploy Your Model ###
##### 5.b.2 Install Dynamo GIE helm chart ###
......@@ -284,7 +285,7 @@ b. use port-forward to expose the gateway to the host
```bash
# in first terminal
kubectl port-forward svc/inference-gateway 8000:80 -n kgateway-system
kubectl port-forward svc/inference-gateway 8000:80 -n {NAMESPACE} # for NAMESPACE, use the namespace where you installed the gateway, e.g. kgateway-system
# in second terminal where you want to send inference requests
GATEWAY_URL=http://localhost:8000
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
name: qwen-agg
spec:
backendFramework: vllm
services:
Epp:
envFromSecret: hf-token-secret
componentType: epp
replicas: 1
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/epp-image:my-tag
eppConfig:
config:
plugins:
- type: disagg-profile-handler
- name: decode-filter
type: label-filter
parameters:
label: "nvidia.com/dynamo-sub-component-type"
validValues:
- "decode"
allowsNoLabel: true
- name: picker
type: max-score-picker
- name: dyn-decode
type: dyn-decode-scorer
schedulingProfiles:
- name: decode
plugins:
- pluginRef: decode-filter
- pluginRef: dyn-decode
weight: 1
- pluginRef: picker
VllmDecodeWorker:
componentType: worker
envFromSecret: hf-token-secret
sharedMemory:
size: 2Gi
extraPodSpec:
mainContainer:
env:
- name: SERVED_MODEL_NAME
value: "Qwen/Qwen3-0.6B"
- name: MODEL_PATH
value: "Qwen/Qwen3-0.6B"
- name: DYN_STORE_KV
value: "mem"
args:
- "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 1 --data-parallel-size 1 --gpu-memory-utilization 0.90 --no-enable-prefix-caching --block-size 128"
command:
- /bin/sh
- -c
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
workingDir: /workspace/examples/backends/vllm
containers:
- name: frontend
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
command:
- python3
args:
- -m
- dynamo.frontend
- --router-mode
- direct
ports:
- containerPort: 8000
name: http
protocol: TCP
envFrom:
- secretRef:
name: hf-token-secret
env:
- name: DYNAMO_PORT
value: "8000"
- name: DYN_HTTP_PORT
value: "8000"
- name: DYN_NAMESPACE
value: my-model-qwen-agg
- name: DYN_COMPONENT
value: frontend
- name: DYN_DISCOVERY_BACKEND
value: kubernetes
- name: DYN_PARENT_DGD_K8S_NAME
value: qwen-agg
- name: DYN_PARENT_DGD_K8S_NAMESPACE
value: my-model
- name: POD_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: POD_NAMESPACE
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: POD_UID
valueFrom:
fieldRef:
fieldPath: metadata.uid
livenessProbe:
httpGet:
path: /live
port: http
initialDelaySeconds: 15
periodSeconds: 10
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 10
replicas: 1
resources:
limits:
gpu: "1"
requests:
gpu: "1"
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: You can remove metadata.namespace if using kubectl apply -n
# The backendRefs.namespace field should match where your InferencePool is deployed
apiVersion: gateway.networking.k8s.io/v1
kind: HTTPRoute
metadata:
name: qwen-agg-route
spec:
parentRefs:
- group: gateway.networking.k8s.io
kind: Gateway
name: inference-gateway
namespace: my-model
rules:
- backendRefs:
- group: inference.networking.k8s.io
kind: InferencePool
name: qwen-agg-pool
port: 8000
weight: 1
matches:
- path:
type: PathPrefix
value: /
timeouts:
request: 300s
......@@ -13,7 +13,7 @@ use std::time::Duration;
use dynamo_llm::kv_router::{protocols::*, publisher::KvEventPublisher};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_runtime::discovery::DiscoveryQuery;
use dynamo_runtime::discovery::{DiscoveryQuery, hash_pod_name};
use dynamo_runtime::{DistributedRuntime, Worker};
use dynamo_runtime::Runtime;
......@@ -24,6 +24,8 @@ use dynamo_llm::kv_router::protocols::WorkerWithDpRank;
use dynamo_llm::kv_router::{KvRouter, PrefillRouter, RouterConfigOverride};
use dynamo_runtime::pipeline::RouterMode;
use std::collections::HashSet;
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be instance passing between API calls?
......@@ -425,6 +427,8 @@ pub struct RouterHandles {
impl RouterHandles {
/// Query optimal prefill worker for a request.
///
/// When `allowed_worker_ids` is Some, only workers in that set are considered.
/// Returns worker_id on success.
async fn query_prefill_worker(
&self,
......@@ -433,6 +437,7 @@ impl RouterHandles {
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<u64, QueryRouterResult> {
self.prefill_router
.query_prefill_worker(
......@@ -441,6 +446,7 @@ impl RouterHandles {
update_states,
lora_name,
priority_jump,
allowed_worker_ids,
)
.await
.map(|(worker_id, _dp_rank)| worker_id)
......@@ -454,6 +460,9 @@ impl RouterHandles {
/// For disaggregated mode, set `is_disaggregated` to true to use overlap_score_weight=0
/// (since KV cache is being transferred from prefill, not reused).
///
/// When `allowed_worker_ids` is Some, only workers in that set are considered.
/// This does NOT overwrite the router's internal worker state — it only filters this decision.
///
/// Note: The C bindings are query-only and must not mutate router state during worker
/// selection. State updates require a `context_id` (request id) and are managed via the
/// explicit bookkeeping APIs (`add_request`, `mark_prefill_complete`, `free_request`).
......@@ -462,6 +471,7 @@ impl RouterHandles {
&self,
tokens: &[u32],
is_disaggregated: bool,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<(WorkerWithDpRank, u32), QueryRouterResult> {
// For decode phase in disaggregated mode, use overlap_score_weight=0
// This matches prefill_router.rs
......@@ -483,6 +493,7 @@ impl RouterHandles {
false,
None,
0.0,
allowed_worker_ids,
)
.await
.map_err(|e| {
......@@ -508,14 +519,6 @@ pub enum QueryRouterResult {
}
/// Build a `KvRouterConfig` from defaults, overridden by optional `DYN_*` environment variables.
///
/// Supported env vars (all optional — unset or empty values are ignored):
/// - `DYN_OVERLAP_SCORE_WEIGHT` — Weight for overlap score in worker selection (default: 1.0)
/// - `DYN_ROUTER_TEMPERATURE` — Temperature for worker sampling via softmax (default: 0.0)
/// - `DYN_USE_KV_EVENTS` — Use KV events for cache tracking (default: true)
/// - `DYN_ROUTER_REPLICA_SYNC` — Enable replica synchronization (default: false)
/// - `DYN_ROUTER_TRACK_ACTIVE_BLOCKS` — Track active blocks (default: true)
/// - `DYN_ROUTER_TRACK_OUTPUT_BLOCKS` — Track output blocks during generation (default: false)
fn kv_router_config_from_env() -> KvRouterConfig {
let mut cfg = KvRouterConfig::default();
......@@ -976,157 +979,261 @@ pub unsafe extern "C" fn destroy(handle: RouterHandlesPtr) {
}
}
/// Route a chat completion request in a single call.
///
/// This is the main function for EPP to route a `/v1/chat/completions` request.
/// It combines tokenization and worker selection in one call:
/// 1. Applies the chat template to the request JSON
/// 2. Tokenizes the formatted prompt
/// 3. Queries the prefill router (if disaggregated mode)
/// 4. Queries the decode router
/// 5. Returns worker IDs and token_ids
///
/// After this call, EPP should:
/// - Call `add_request()` to register the request for bookkeeping
/// - Set worker ID headers and forward to backend
/// - Call `mark_prefill_complete()` on first token
/// - Call `free_request()` when the stream ends
/// - Call `free_routing_result()` to free the result
/// Free a routing result.
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `out_result` must be a valid pointer
/// - `result` must be a valid pointer to a CRoutingResult previously returned by route functions
#[unsafe(no_mangle)]
pub unsafe extern "C" fn route_request(
handle: RouterHandlesPtr,
request_json: *const c_char,
out_result: *mut CRoutingResult,
) -> QueryRouterResult {
if handle.is_null() || request_json.is_null() || out_result.is_null() {
return QueryRouterResult::ErrInvalidParam;
pub unsafe extern "C" fn free_routing_result(result: *mut CRoutingResult) {
if result.is_null() {
return;
}
let handles = unsafe { &*handle };
let res = unsafe { &mut *result };
// Free token IDs
if !res.token_ids.is_null() && res.token_count > 0 {
drop(unsafe {
Box::from_raw(std::slice::from_raw_parts_mut(
res.token_ids,
res.token_count,
))
});
res.token_ids = ptr::null_mut();
res.token_count = 0;
}
}
// Get preprocessor
/// Parse a JSON request string, apply the chat template, and tokenize.
/// Returns the token IDs on success, or a `QueryRouterResult` error code.
unsafe fn preprocess_request(
handles: &RouterHandles,
request_json: *const c_char,
) -> Result<Vec<u32>, QueryRouterResult> {
let preprocessor = match &handles.preprocessor {
Some(p) => p,
None => {
tracing::error!("Preprocessor not available");
return QueryRouterResult::ErrInitFailed;
return Err(QueryRouterResult::ErrInitFailed);
}
};
let json_str = match unsafe { CStr::from_ptr(request_json) }.to_str() {
Ok(s) => s,
Err(_) => return QueryRouterResult::ErrInvalidParam,
Err(_) => return Err(QueryRouterResult::ErrInvalidParam),
};
// Parse JSON
let request: dynamo_llm::types::openai::chat_completions::NvCreateChatCompletionRequest =
match serde_json::from_str(json_str) {
Ok(req) => req,
Err(e) => {
tracing::error!(error = ?e, "Failed to parse request JSON");
return QueryRouterResult::ErrInvalidParam;
return Err(QueryRouterResult::ErrInvalidParam);
}
};
// Apply chat template
let formatted_prompt = match preprocessor.apply_template(&request) {
Ok(Some(prompt)) => prompt,
Ok(None) => String::new(),
Err(e) => {
tracing::error!(error = ?e, "Failed to apply chat template");
return QueryRouterResult::ErrQueryFailed;
return Err(QueryRouterResult::ErrQueryFailed);
}
};
// Tokenize
let encoding = match preprocessor.tokenize(&formatted_prompt) {
Ok(enc) => enc,
Err(e) => {
tracing::error!(error = ?e, "Failed to tokenize");
return QueryRouterResult::ErrQueryFailed;
return Err(QueryRouterResult::ErrQueryFailed);
}
};
let tokens = encoding.token_ids();
let token_count = tokens.len();
let is_disaggregated = handles.prefill_router.is_activated();
Ok(encoding.token_ids().to_vec())
}
// Query workers
let result = handles.runtime.secondary().block_on(async {
let prefill_worker_id = if is_disaggregated {
handles
.query_prefill_worker(tokens, None, false, None, 0.0)
.await?
/// Parse pods JSON into an optional set of allowed worker IDs.
unsafe fn parse_pods_filter(pods_json: *const c_char) -> Option<HashSet<WorkerId>> {
if pods_json.is_null() {
return None;
}
match unsafe { CStr::from_ptr(pods_json) }.to_str() {
Ok(s) if !s.is_empty() => match serde_json::from_str::<Vec<serde_json::Value>>(s) {
Ok(pods) => {
let mut worker_ids = HashSet::new();
for pod in &pods {
let pod_name = pod
.get("pod")
.and_then(|p| p.get("podName"))
.or_else(|| pod.get("podName"))
.and_then(|v| v.as_str());
if let Some(name) = pod_name {
let worker_id = hash_pod_name(name);
tracing::debug!(
pod_name = name,
worker_id = format!("{:x}", worker_id),
"Mapped EPP pod to worker_id"
);
worker_ids.insert(worker_id);
}
}
tracing::info!(
pod_count = pods.len(),
unique_worker_ids = worker_ids.len(),
"Parsed EPP pods into allowed_worker_ids filter"
);
if worker_ids.is_empty() {
None
} else {
0
Some(worker_ids)
}
}
Err(e) => {
tracing::error!(error = ?e, "Failed to parse pods JSON");
None
}
},
_ => None,
}
}
/// Write token IDs into a `CRoutingResult`, transferring ownership to the caller.
///
/// The tokens are copied into a freshly allocated boxed slice whose raw pointer
/// and length are stored in `out.token_ids` / `out.token_count`. The allocation
/// is intentionally not dropped here; it is reclaimed when the caller passes the
/// result to `free_routing_result`, which rebuilds the `Box<[u32]>` from exactly
/// this pointer/length pair.
fn write_tokens_to_result(tokens: &[u32], out: &mut CRoutingResult) {
    let boxed: Box<[u32]> = Box::from(tokens);
    out.token_count = boxed.len();
    // `Box::into_raw` hands ownership over in one step, instead of taking a raw
    // pointer from the box and then moving the box into `mem::forget`.
    out.token_ids = Box::into_raw(boxed) as *mut u32;
}
/// Route a request to select the best **prefill** worker only.
///
/// This is used in disaggregated mode where the EPP runs separate prefill and decode
/// scoring profiles. It tokenizes the request and queries only the prefill router.
///
/// The returned `CRoutingResult` contains:
/// - `prefill_worker_id`: the selected prefill worker
/// - `decode_worker_id`: 0 (unused — decode is handled by `route_decode_request`)
/// - `is_disaggregated`: always true (this function is only called in disagg mode)
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
#[unsafe(no_mangle)]
pub unsafe extern "C" fn route_prefill_request(
handle: RouterHandlesPtr,
request_json: *const c_char,
pods_json: *const c_char,
out_result: *mut CRoutingResult,
) -> QueryRouterResult {
if handle.is_null() || request_json.is_null() || out_result.is_null() {
return QueryRouterResult::ErrInvalidParam;
}
let handles = unsafe { &*handle };
let tokens = match unsafe { preprocess_request(handles, request_json) } {
Ok(t) => t,
Err(code) => return code,
};
let (decode_worker, _overlap_blocks) = handles
.query_decode_worker(tokens, is_disaggregated)
let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };
let result = handles.runtime.secondary().block_on(async {
let prefill_worker_id = handles
.query_prefill_worker(&tokens, None, false, None, 0.0, allowed_worker_ids)
.await?;
tracing::info!(
is_disaggregated = is_disaggregated,
prefill_worker_id = prefill_worker_id,
decode_worker_id = decode_worker.worker_id,
decode_dp_rank = decode_worker.dp_rank,
token_count = token_count,
"Routed chat request"
token_count = tokens.len(),
"Routed prefill request"
);
Ok((prefill_worker_id, decode_worker))
Ok(prefill_worker_id)
});
match result {
Ok((prefill_worker_id, decode_worker)) => {
// Allocate and copy token IDs for caller (needed for add_request bookkeeping)
let token_vec: Vec<u32> = tokens.to_vec();
let mut tokens_boxed = token_vec.into_boxed_slice();
let token_ptr = tokens_boxed.as_mut_ptr();
std::mem::forget(tokens_boxed);
unsafe {
*out_result = CRoutingResult {
is_disaggregated,
prefill_worker_id,
decode_worker_id: decode_worker.worker_id,
token_ids: token_ptr,
token_count,
};
}
Ok(prefill_worker_id) => {
let out = unsafe { &mut *out_result };
*out = CRoutingResult::default();
out.is_disaggregated = true;
out.prefill_worker_id = prefill_worker_id;
write_tokens_to_result(&tokens, out);
QueryRouterResult::Ok
}
Err(code) => code,
}
}
/// Free a routing result.
/// Route a request to select the best **decode** worker only.
///
/// This is used in both aggregated and disaggregated modes.
/// - When `is_disaggregated` is true, the decode router uses `overlap_score_weight=0`
/// (KV cache is being transferred from prefill, not reused locally).
/// - When `is_disaggregated` is false, normal KV-aware scoring is used.
///
/// The returned `CRoutingResult` contains:
/// - `decode_worker_id`: the selected decode worker
/// - `prefill_worker_id`: 0 (unused — prefill is handled by `route_prefill_request`)
/// - `is_disaggregated`: mirrors the input parameter
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
///
/// # Safety
/// - `result` must be a valid pointer to a CRoutingResult previously returned by route functions
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
#[unsafe(no_mangle)]
pub unsafe extern "C" fn free_routing_result(result: *mut CRoutingResult) {
if result.is_null() {
return;
pub unsafe extern "C" fn route_decode_request(
handle: RouterHandlesPtr,
request_json: *const c_char,
pods_json: *const c_char,
is_disaggregated: bool,
out_result: *mut CRoutingResult,
) -> QueryRouterResult {
if handle.is_null() || request_json.is_null() || out_result.is_null() {
return QueryRouterResult::ErrInvalidParam;
}
let res = unsafe { &mut *result };
let handles = unsafe { &*handle };
// Free token IDs
if !res.token_ids.is_null() && res.token_count > 0 {
drop(unsafe {
Box::from_raw(std::slice::from_raw_parts_mut(
res.token_ids,
res.token_count,
))
let tokens = match unsafe { preprocess_request(handles, request_json) } {
Ok(t) => t,
Err(code) => return code,
};
let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };
let result = handles.runtime.secondary().block_on(async {
let (decode_worker, _overlap_blocks) = handles
.query_decode_worker(&tokens, is_disaggregated, allowed_worker_ids)
.await?;
tracing::info!(
is_disaggregated = is_disaggregated,
decode_worker_id = decode_worker.worker_id,
decode_dp_rank = decode_worker.dp_rank,
token_count = tokens.len(),
"Routed decode request"
);
Ok(decode_worker)
});
res.token_ids = ptr::null_mut();
res.token_count = 0;
match result {
Ok(decode_worker) => {
let out = unsafe { &mut *out_result };
*out = CRoutingResult::default();
out.is_disaggregated = is_disaggregated;
out.decode_worker_id = decode_worker.worker_id;
write_tokens_to_result(&tokens, out);
QueryRouterResult::Ok
}
Err(code) => code,
}
}
......
......@@ -919,6 +919,7 @@ impl KvRouter {
update_states,
lora_name,
0.0,
None, // allowed_worker_ids not exposed in Python API yet
)
.await
.map_err(to_pyerr)?;
......
......@@ -63,6 +63,8 @@ use crate::{
local_model::runtime_config::ModelRuntimeConfig,
};
use std::collections::HashSet;
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
......@@ -351,7 +353,9 @@ impl KvRouter {
/// Given these tokens, find the worker with the best match in its KV cache.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking
/// Now also takes optional context_id for request tracking.
///
/// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
#[allow(clippy::too_many_arguments)]
pub async fn find_best_match(
&self,
......@@ -362,6 +366,7 @@ impl KvRouter {
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
let start = Instant::now();
......@@ -381,13 +386,20 @@ impl KvRouter {
});
let hash_elapsed = start.elapsed();
let overlap_scores = self
let mut overlap_scores = self
.indexer
.find_matches(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches"))
.await?;
let find_matches_elapsed = start.elapsed();
if let Some(ref allowed_ids) = allowed_worker_ids {
overlap_scores
.scores
.retain(|worker, _| allowed_ids.contains(&worker.worker_id));
}
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
......@@ -398,17 +410,18 @@ impl KvRouter {
});
let seq_hash_elapsed = start.elapsed();
let best_worker = self
let response = self
.scheduler
.schedule(
context_id.map(|s| s.to_string()),
isl_tokens,
maybe_seq_hashes,
overlap_scores.clone(),
overlap_scores,
router_config_override,
update_states,
lora_name,
priority_jump,
allowed_worker_ids,
)
.instrument(tracing::info_span!("kv_router.schedule"))
.await?;
......@@ -434,15 +447,7 @@ impl KvRouter {
"find_best_match completed"
);
// Note: Routing decision recording (for approximate mode) is now handled
// by KvPushRouter::generate after select_worker returns.
let overlap_amount = overlap_scores
.scores
.get(&best_worker)
.copied()
.unwrap_or(0);
Ok((best_worker, overlap_amount))
Ok((response.best_worker, response.overlap_blocks))
}
#[allow(clippy::too_many_arguments)]
......@@ -575,6 +580,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
true,
None,
0.0,
None,
)
.await?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment