Unverified Commit c916cd42 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

feat: Support epp's "pods" interface in Dynamo fixes [DEP-424] (#6302)


Signed-off-by: default avatarAnna Tchernych <atchernych@nvidia.com>
parent 5a4c96db
......@@ -964,8 +964,7 @@ spec:
Parameters are the set of parameters to be passed to the plugin's
factory function. The factory function is responsible
to parse the parameters.
format: byte
type: string
x-kubernetes-preserve-unknown-fields: true
type:
description: Type specifies the plugin type to be instantiated.
type: string
......
......@@ -1173,8 +1173,7 @@ spec:
Parameters are the set of parameters to be passed to the plugin's
factory function. The factory function is responsible
to parse the parameters.
format: byte
type: string
x-kubernetes-preserve-unknown-fields: true
type:
description: Type specifies the plugin type to be instantiated.
type: string
......
......@@ -41,16 +41,16 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
// Dynamo plugins
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
"github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/disagg"
labelfilter "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/label_filter"
)
func main() {
// Register Dynamo custom plugins:
// - kv-aware-scorer: Implements Scorer, PreRequest, ResponseStreaming, and ResponseComplete interfaces
// - Score: Calls Dynamo router to select workers based on KV cache, sets routing headers
// - PreRequest: Registers request with router bookkeeping after scheduling is finalized
// - ResponseComplete: Cleans up router bookkeeping when response completes
plugins.Register("kv-aware-scorer", dynscorer.KVAwareScorerFactory)
plugins.Register("label-filter", labelfilter.LabelFilterFactory)
plugins.Register(disagg.DisaggProfileHandlerType, disagg.DisaggProfileHandlerFactory)
plugins.Register(disagg.DynPrefillScorerType, disagg.DynPrefillScorerFactory)
plugins.Register(disagg.DynDecodeScorerType, disagg.DynDecodeScorerFactory)
// Run using standard GAIE runner (it registers built-in plugins automatically)
if err := runner.NewRunner().Run(ctrl.SetupSignalHandler()); err != nil {
......
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package disagg
import (
"context"
"encoding/json"
"fmt"
"sync"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
rc "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)
const (
// DynDecodeScorerType is the plugin type registered in the plugin registry.
DynDecodeScorerType = "dyn-decode-scorer"
// WorkerIDHeader is the request header Score sets to the decode worker ID chosen by the router.
WorkerIDHeader = "x-worker-instance-id"
// PrefillWorkerIDHeader names the header for the prefill worker ID.
// NOTE(review): DynDecodeScorer.Score does not set this header; confirm who is expected to.
PrefillWorkerIDHeader = "x-prefill-instance-id"
// RoutingModeHeader is the request header Score sets to "disaggregated" or "aggregated".
RoutingModeHeader = "x-dynamo-routing-mode"
// decodeStateKey is the key used to store routing state in PluginState
decodeStateKey = "dynamo-decode-routing-state"
)
// compile-time type assertions: the build fails if DynDecodeScorer stops
// satisfying any of the interfaces listed here.
var _ framework.Scorer = &DynDecodeScorer{}
var _ plugins.Plugin = &DynDecodeScorer{}
var _ rc.PreRequest = &DynDecodeScorer{}
var _ rc.ResponseStreaming = &DynDecodeScorer{}
var _ rc.ResponseComplete = &DynDecodeScorer{}
// DecodeRoutingState holds routing information passed from Score() to PreRequest().
// Score writes it into the plugin state under the request ID; PreRequest reads it back.
type DecodeRoutingState struct {
// WorkerID is the decode worker ID chosen by the router, formatted as a decimal string.
WorkerID string
// PrefillWorkerID is not populated by Score.
// NOTE(review): presumably reserved for the selected prefill worker — confirm.
PrefillWorkerID string
// TokenData is the token data returned by the router; PreRequest forwards it to CallAddRequest.
TokenData []int64
}
// Clone implements plugins.StateData.
//
// It returns a deep copy of the routing state: the token slice gets its own
// backing array so the copy and the original never alias. A nil receiver
// yields a nil result, and a nil TokenData stays nil in the copy.
func (s *DecodeRoutingState) Clone() plugins.StateData {
	if s == nil {
		return nil
	}

	var tokens []int64
	if s.TokenData != nil {
		tokens = make([]int64, len(s.TokenData))
		copy(tokens, s.TokenData)
	}

	return &DecodeRoutingState{
		WorkerID:        s.WorkerID,
		PrefillWorkerID: s.PrefillWorkerID,
		TokenData:       tokens,
	}
}
// DynDecodeScorerConfig holds the configuration for the DynDecodeScorer plugin.
// It currently declares no fields; DynDecodeScorerFactory unmarshals the plugin
// parameters into it only to validate that they are well-formed JSON.
type DynDecodeScorerConfig struct{}
// DynDecodeScorerFactory is the plugin factory for DynDecodeScorer.
//
// name becomes the instance name of the returned plugin. rawParameters carries
// the plugin's JSON parameters; nil or empty input means "no parameters".
// handle supplies the context used to create the plugin state.
//
// It returns an error when the parameters are not valid JSON for
// DynDecodeScorerConfig or when the Dynamo FFI cannot be initialized.
func DynDecodeScorerFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
	cfg := DynDecodeScorerConfig{}
	// Check the length rather than nil-ness: a non-nil but empty RawMessage
	// would otherwise fail in json.Unmarshal with "unexpected end of JSON input".
	if len(rawParameters) > 0 {
		if err := json.Unmarshal(rawParameters, &cfg); err != nil {
			return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", DynDecodeScorerType, err)
		}
	}

	// Initialize the shared FFI (idempotent)
	if err := dynscorer.InitFFI(); err != nil {
		return nil, fmt.Errorf("Dynamo FFI init for decode scorer failed: %w", err)
	}

	return NewDynDecodeScorer(handle.Context()).WithName(name), nil
}
// NewDynDecodeScorer builds a DynDecodeScorer whose instance name defaults to
// the plugin type and whose per-request plugin state is created from ctx.
func NewDynDecodeScorer(ctx context.Context) *DynDecodeScorer {
	scorer := new(DynDecodeScorer)
	scorer.typedName = plugins.TypedName{Type: DynDecodeScorerType, Name: DynDecodeScorerType}
	scorer.pluginState = plugins.NewPluginState(ctx)
	return scorer
}
// DynDecodeScorer is a scorer plugin for the decode scheduling profile.
//
// When Score() is called, it:
// 1. Reads PrefillEnabledState from CycleState (written by DisaggProfileHandler).
// 2. Calls the Dynamo FFI decode router with is_disaggregated flag.
// 3. Sets routing headers on the request.
// 4. Stores routing state for PreRequest to register with router bookkeeping.
//
// It also implements PreRequest, ResponseStreaming, and ResponseComplete lifecycle hooks
// for router bookkeeping (add_request, mark_prefill_complete, free_request).
type DynDecodeScorer struct {
// typedName is the plugin's type/name pair returned by TypedName.
typedName plugins.TypedName
// pluginState carries DecodeRoutingState from Score to PreRequest, keyed by request ID.
pluginState *plugins.PluginState
// firstTokenSeen records request IDs for which ResponseStreaming has already run once;
// ResponseComplete removes the entry.
firstTokenSeen sync.Map
}
// TypedName reports the plugin type together with this instance's name.
func (s *DynDecodeScorer) TypedName() plugins.TypedName { return s.typedName }
// WithName overrides the instance name and returns the same scorer so the call
// can be chained onto the constructor.
func (s *DynDecodeScorer) WithName(name string) *DynDecodeScorer { s.typedName.Name = name; return s }
// Score asks the Dynamo decode router to pick a worker for req and records the
// decision on the request.
//
// Every pod always receives the same score (1.0): the router's choice is
// communicated through request headers, not through the score map. If the
// request cannot be serialized or the FFI call fails, the uniform scores are
// returned without touching the headers or the plugin state.
//
// On success it sets the worker-ID and routing-mode headers and, when the
// request has an ID, stores a DecodeRoutingState for PreRequest to consume.
func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.CycleState, req *schedtypes.LLMRequest, pods []schedtypes.Pod) map[schedtypes.Pod]float64 {
	const score = 1.0

	logger := log.FromContext(ctx)
	disaggregated := readPrefillEnabled(cycleState)

	reqPayload, buildErr := buildRequestJSON(req)
	if buildErr != nil {
		logger.V(logutil.DEFAULT).Error(buildErr, "DynDecodeScorer: failed to build request")
		return uniformScores(pods, score)
	}

	podsPayload := serializePods(pods)
	logger.V(logutil.DEFAULT).Info("DynDecodeScorer: pods received for scoring",
		"podCount", len(pods),
		"podsJSON", podsPayload)

	routed, routeErr := dynscorer.CallRouteDecodeRequest(reqPayload, podsPayload, disaggregated)
	if routeErr != nil {
		logger.V(logutil.DEFAULT).Error(routeErr, "DynDecodeScorer: FFI decode routing failed")
		return uniformScores(pods, score)
	}

	decodeWorker := fmt.Sprintf("%d", routed.WorkerID)
	logger.V(logutil.DEFAULT).Info("DynDecodeScorer: decode worker selected",
		"decodeWorkerID", decodeWorker,
		"isDisaggregated", disaggregated,
		"tokenCount", len(routed.TokenData))

	// Routing headers. Only the mode is advertised in disaggregated mode; the
	// prefill worker header is not written here.
	mode := "aggregated"
	if disaggregated {
		mode = "disaggregated"
	}
	if req.Headers == nil {
		req.Headers = map[string]string{}
	}
	req.Headers[WorkerIDHeader] = decodeWorker
	req.Headers[RoutingModeHeader] = mode

	// Hand the routing decision to PreRequest; without a request ID there is
	// no key to store it under.
	if req.RequestId != "" {
		s.pluginState.Write(req.RequestId, plugins.StateKey(decodeStateKey), &DecodeRoutingState{
			WorkerID:  decodeWorker,
			TokenData: routed.TokenData,
		})
	}

	return uniformScores(pods, score)
}
// PreRequest runs once scheduling is final, before the request goes to the
// worker, and registers the request with the Dynamo router's bookkeeping.
//
// It consumes the DecodeRoutingState that Score stored for this request ID
// (the stored state is deleted whether or not the read succeeds) and passes the
// worker ID and token data to CallAddRequest. Requests without an ID, without
// stored state, or with an unparsable worker ID are skipped after logging.
func (s *DynDecodeScorer) PreRequest(ctx context.Context, request *schedtypes.LLMRequest, _ *schedtypes.SchedulingResult) {
	logger := log.FromContext(ctx)
	if request == nil || request.RequestId == "" {
		logger.V(logutil.DEBUG).Info("DynDecodeScorer PreRequest: no request ID, skipping")
		return
	}
	reqID := request.RequestId

	routing, readErr := plugins.ReadPluginStateKey[*DecodeRoutingState](
		s.pluginState, reqID, plugins.StateKey(decodeStateKey),
	)
	// The state is single-use: drop it before looking at the read outcome.
	s.pluginState.Delete(reqID)
	if readErr != nil {
		logger.V(logutil.DEBUG).Info("DynDecodeScorer PreRequest: no routing state found",
			"requestID", reqID)
		return
	}

	var workerID uint64
	if _, scanErr := fmt.Sscanf(routing.WorkerID, "%d", &workerID); scanErr != nil {
		logger.V(logutil.DEFAULT).Error(scanErr, "DynDecodeScorer PreRequest: invalid worker ID",
			"requestID", reqID, "workerID", routing.WorkerID)
		return
	}

	// NOTE(review): the meaning of the trailing 0 argument is defined by
	// dynscorer.CallAddRequest and is not visible from here.
	if addErr := dynscorer.CallAddRequest(reqID, routing.TokenData, workerID, 0); addErr != nil {
		logger.V(logutil.DEFAULT).Error(addErr, "DynDecodeScorer PreRequest: failed to add request",
			"requestID", reqID)
		return
	}

	logger.V(logutil.VERBOSE).Info("DynDecodeScorer PreRequest: registered request",
		"requestID", reqID,
		"workerID", routing.WorkerID,
		"tokenCount", len(routing.TokenData))
}
// ResponseStreaming is invoked for every chunk of a streaming response.
//
// Only the first chunk seen for a given request ID does any work: it tells the
// Dynamo router that prefill is complete. Later chunks, and requests without an
// ID, return immediately.
func (s *DynDecodeScorer) ResponseStreaming(ctx context.Context, request *schedtypes.LLMRequest, _ *rc.Response, _ *backend.Pod) {
	if request == nil || request.RequestId == "" {
		return
	}
	reqID := request.RequestId

	// LoadOrStore reports whether the ID was already recorded, so exactly one
	// caller per request gets past this point.
	if _, seen := s.firstTokenSeen.LoadOrStore(reqID, true); seen {
		return
	}

	logger := log.FromContext(ctx)
	if err := dynscorer.CallMarkPrefillComplete(reqID); err != nil {
		logger.V(logutil.DEFAULT).Error(err, "DynDecodeScorer ResponseStreaming: failed to mark prefill complete",
			"requestID", reqID)
		return
	}
	logger.V(logutil.VERBOSE).Info("DynDecodeScorer ResponseStreaming: marked prefill complete",
		"requestID", reqID)
}
// ResponseComplete runs after the full response has been sent to the client.
//
// It forgets the request's first-token marker and asks the Dynamo router to
// free its bookkeeping entry. Requests without an ID are ignored.
func (s *DynDecodeScorer) ResponseComplete(ctx context.Context, request *schedtypes.LLMRequest, _ *rc.Response, _ *backend.Pod) {
	if request == nil || request.RequestId == "" {
		return
	}
	reqID := request.RequestId
	logger := log.FromContext(ctx)

	s.firstTokenSeen.Delete(reqID)

	err := dynscorer.CallFreeRequest(reqID)
	if err != nil {
		logger.V(logutil.DEFAULT).Error(err, "DynDecodeScorer ResponseComplete: failed to free request",
			"requestID", reqID)
		return
	}
	logger.V(logutil.VERBOSE).Info("DynDecodeScorer ResponseComplete: freed request",
		"requestID", reqID)
}
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package disagg
import (
"context"
"encoding/json"
"fmt"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)
const (
// DynPrefillScorerType is the plugin type registered in the plugin registry.
// It is also the default instance name set by NewDynPrefillScorer.
DynPrefillScorerType = "dyn-prefill-scorer"
)
// compile-time type assertion: the build fails if DynPrefillScorer stops satisfying framework.Scorer.
var _ framework.Scorer = &DynPrefillScorer{}
// DynPrefillScorerConfig holds the configuration for the DynPrefillScorer plugin.
// It currently declares no fields.
type DynPrefillScorerConfig struct{}
// DynPrefillScorerFactory is the plugin factory for DynPrefillScorer.
//
// name becomes the instance name of the returned plugin. rawParameters carries
// the plugin's JSON parameters; nil or empty input means "no parameters". The
// plugin handle is not used.
//
// It returns an error when the parameters are not valid JSON for
// DynPrefillScorerConfig or when the Dynamo FFI cannot be initialized.
func DynPrefillScorerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	cfg := DynPrefillScorerConfig{}
	// Check the length rather than nil-ness: a non-nil but empty RawMessage
	// would otherwise fail in json.Unmarshal with "unexpected end of JSON input".
	if len(rawParameters) > 0 {
		if err := json.Unmarshal(rawParameters, &cfg); err != nil {
			return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", DynPrefillScorerType, err)
		}
	}

	// Initialize the shared FFI (idempotent)
	if err := dynscorer.InitFFI(); err != nil {
		return nil, fmt.Errorf("Dynamo FFI init for prefill scorer failed: %w", err)
	}

	return NewDynPrefillScorer().WithName(name), nil
}
// NewDynPrefillScorer builds a DynPrefillScorer whose instance name defaults to
// the plugin type.
func NewDynPrefillScorer() *DynPrefillScorer {
	defaultName := plugins.TypedName{Type: DynPrefillScorerType, Name: DynPrefillScorerType}
	return &DynPrefillScorer{typedName: defaultName}
}
// DynPrefillScorer is a scorer plugin for the prefill scheduling profile.
//
// When Score() is called, it:
// 1. Reads PrefillEnabledState from CycleState (written by DisaggProfileHandler).
// 2. If prefill is NOT enabled, returns zero scores.
// 3. If prefill IS enabled, calls the Dynamo FFI prefill router to select the best prefill worker.
// 4. Assigns score 1.0 to all pods (the router's selection is authoritative, communicated via headers).
type DynPrefillScorer struct {
// typedName is the plugin's type/name pair returned by TypedName.
typedName plugins.TypedName
}
// TypedName reports the plugin type together with this instance's name.
func (s *DynPrefillScorer) TypedName() plugins.TypedName { return s.typedName }
// WithName overrides the instance name and returns the same scorer so the call
// can be chained onto the constructor.
func (s *DynPrefillScorer) WithName(name string) *DynPrefillScorer { s.typedName.Name = name; return s }
// Score asks the Dynamo prefill router to pick a prefill worker for req.
//
// All pods get score 0 when prefill is disabled for this cycle, when the
// request cannot be serialized, or when the FFI call fails. On success all pods
// get score 1.0; the selected worker is only logged here, not encoded in the
// scores.
func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.CycleState, req *schedtypes.LLMRequest, pods []schedtypes.Pod) map[schedtypes.Pod]float64 {
	logger := log.FromContext(ctx)

	if enabled := readPrefillEnabled(cycleState); !enabled {
		logger.V(logutil.VERBOSE).Info("DynPrefillScorer: prefill not enabled, returning zero scores")
		return uniformScores(pods, 0)
	}

	reqPayload, buildErr := buildRequestJSON(req)
	if buildErr != nil {
		logger.V(logutil.DEFAULT).Error(buildErr, "DynPrefillScorer: failed to build request")
		return uniformScores(pods, 0)
	}

	podsPayload := serializePods(pods)
	logger.V(logutil.DEFAULT).Info("DynPrefillScorer: pods received for scoring",
		"podCount", len(pods),
		"podsJSON", podsPayload)

	routed, routeErr := dynscorer.CallRoutePrefillRequest(reqPayload, podsPayload)
	if routeErr != nil {
		logger.V(logutil.DEFAULT).Error(routeErr, "DynPrefillScorer: FFI prefill routing failed")
		return uniformScores(pods, 0)
	}

	prefillWorker := fmt.Sprintf("%d", routed.WorkerID)
	logger.V(logutil.DEFAULT).Info("DynPrefillScorer: prefill worker selected",
		"prefillWorkerID", prefillWorker,
		"tokenCount", len(routed.TokenData))

	// Uniform 1.0: the router's own selection is treated as authoritative, so
	// the score map does not single out a pod. Matching worker IDs to pod names
	// would allow precise per-pod scoring in the future.
	return uniformScores(pods, 1.0)
}
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package disagg
import (
"context"
"encoding/json"
"errors"
"fmt"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)
const (
// DisaggProfileHandlerType is the plugin type of DisaggProfileHandler.
// It is also the default instance name set by NewDisaggProfileHandler.
DisaggProfileHandlerType = "disagg-profile-handler"
)
// compile-time type assertion: the build fails if DisaggProfileHandler stops
// satisfying framework.ProfileHandler.
var _ framework.ProfileHandler = &DisaggProfileHandler{}
// DisaggProfileHandlerConfig holds the configuration for the DisaggProfileHandler.
// It currently declares no fields.
type DisaggProfileHandlerConfig struct{}
// DisaggProfileHandlerFactory is the plugin factory for DisaggProfileHandler.
//
// name becomes the instance name of the returned plugin. rawParameters carries
// the plugin's JSON parameters; nil or empty input means "no parameters". The
// plugin handle is not used.
//
// It returns an error when the parameters are not valid JSON for
// DisaggProfileHandlerConfig.
func DisaggProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	cfg := DisaggProfileHandlerConfig{}
	// Check the length rather than nil-ness: a non-nil but empty RawMessage
	// would otherwise fail in json.Unmarshal with "unexpected end of JSON input".
	if len(rawParameters) > 0 {
		if err := json.Unmarshal(rawParameters, &cfg); err != nil {
			return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", DisaggProfileHandlerType, err)
		}
	}
	return NewDisaggProfileHandler().WithName(name), nil
}
// NewDisaggProfileHandler builds a DisaggProfileHandler whose instance name
// defaults to the plugin type.
func NewDisaggProfileHandler() *DisaggProfileHandler {
	defaultName := plugins.TypedName{Type: DisaggProfileHandlerType, Name: DisaggProfileHandlerType}
	return &DisaggProfileHandler{typedName: defaultName}
}
// DisaggProfileHandler is a ProfileHandler that orchestrates prefill/decode disaggregated serving.
//
// # Disaggregated mode detection
//
// In Dynamo's native architecture, disaggregated mode is determined by whether prefill workers
// actually exist at runtime (the is_disaggregated flag in the Rust KV router). However, the
// GAIE EPP framework determines profile availability at configuration time, not at runtime.
// To bridge this gap, DisaggProfileHandler uses the EPP profile mechanism as a proxy:
// it checks whether a "prefill" scheduling profile is registered in the config. If prefill
// workers are configured but none are actually running, the prefill profile's label-filter
// will find zero pods, causing the profile to fail — and the handler gracefully degrades
// to aggregated mode (see below).
//
// # Scheduling flow
//
// On each scheduling cycle it:
// 1. Checks whether a "prefill" profile is registered in the config.
// 2. Writes PrefillEnabledState into CycleState so scorer plugins can read it.
// 3. If a prefill profile exists: runs the "prefill" profile first, then the "decode" profile.
// The "decode" profile is the primary (the pod the request is ultimately sent to).
// 4. If no prefill profile exists: runs only the "decode" profile (pure aggregated mode).
//
// # Graceful degradation
//
// When a prefill profile is configured but no prefill workers are available at runtime,
// the handler degrades gracefully to aggregated mode on a per-request basis:
//
// 1. Pick (iteration 1): prefill profile exists → writes PrefillEnabled=true → runs prefill profile.
// 2. Prefill profile runs: label-filter finds 0 prefill pods → profile fails → result is nil.
// 3. Pick (iteration 2): sees prefill result is nil → overwrites PrefillEnabled=false → runs decode profile.
// 4. Decode scorer runs: reads PrefillEnabled=false → passes isDisaggregated=false to the Rust
// decode router → full KV cache overlap scoring is used (overlap_score_weight=1.0).
//
// This means the same YAML config works transparently for both aggregated and disaggregated
// deployments. If prefill workers come up later, subsequent requests automatically use
// disaggregated routing. If they go down, requests fall back to aggregated mode.
type DisaggProfileHandler struct {
// typedName is the plugin's type/name pair returned by TypedName.
typedName plugins.TypedName
}
// TypedName reports the plugin type together with this instance's name.
func (h *DisaggProfileHandler) TypedName() plugins.TypedName { return h.typedName }
// WithName overrides the instance name and returns the same handler so the call
// can be chained onto the constructor.
func (h *DisaggProfileHandler) WithName(name string) *DisaggProfileHandler { h.typedName.Name = name; return h }
// Pick decides which scheduler profiles run in the current iteration.
//
// No results yet (first call):
//   - records in CycleState whether a "prefill" profile is configured;
//   - returns the "prefill" profile alone when it exists;
//   - otherwise returns the "decode" profile, or every profile if there is no
//     "decode" profile either.
//
// Prefill has run but decode has not:
//   - if the prefill result is nil, rewrites the CycleState flag to false so
//     the decode scorer falls back to aggregated routing;
//   - returns the "decode" profile when it is configured.
//
// Any other situation returns an empty map, which ends the scheduling loop.
func (h *DisaggProfileHandler) Pick(ctx context.Context, cycleState *schedtypes.CycleState, _ *schedtypes.LLMRequest,
	profiles map[string]*framework.SchedulerProfile, profileResults map[string]*schedtypes.ProfileRunResult) map[string]*framework.SchedulerProfile {
	vlog := log.FromContext(ctx).V(logutil.VERBOSE)

	if len(profileResults) == 0 {
		prefillProfile, hasPrefill := profiles[PrefillProfileName]
		cycleState.Write(PrefillEnabledStateKey, &PrefillEnabledState{Enabled: hasPrefill})
		vlog.Info("DisaggProfileHandler: prefill enabled state determined", "prefillEnabled", hasPrefill)

		if hasPrefill {
			// Prefill goes first; decode follows on the next call.
			return map[string]*framework.SchedulerProfile{PrefillProfileName: prefillProfile}
		}

		decodeProfile, hasDecode := profiles[DecodeProfileName]
		if !hasDecode {
			// Neither well-known profile is configured: run whatever is there.
			return profiles
		}
		return map[string]*framework.SchedulerProfile{DecodeProfileName: decodeProfile}
	}

	prefillResult, prefillRan := profileResults[PrefillProfileName]
	_, decodeRan := profileResults[DecodeProfileName]
	if !prefillRan || decodeRan {
		// Nothing left to schedule.
		return map[string]*framework.SchedulerProfile{}
	}

	// A nil prefill result means the prefill profile failed (for example no
	// prefill pods); flip the flag so decode uses normal KV-overlap scoring
	// instead of disaggregated mode.
	if prefillResult == nil {
		vlog.Info("DisaggProfileHandler: prefill profile failed (no workers?), falling back to aggregated decode")
		cycleState.Write(PrefillEnabledStateKey, &PrefillEnabledState{Enabled: false})
	}

	decodeProfile, hasDecode := profiles[DecodeProfileName]
	if !hasDecode {
		return map[string]*framework.SchedulerProfile{}
	}
	return map[string]*framework.SchedulerProfile{DecodeProfileName: decodeProfile}
}
// ProcessResults wraps the per-profile results into a SchedulingResult and
// names the primary profile.
//
// The "decode" profile is primary whenever it has an entry in profileResults;
// otherwise an arbitrary profile that ran is used. An error is returned when no
// profile ran at all or when the primary profile's result is nil.
func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.CycleState, _ *schedtypes.LLMRequest,
	profileResults map[string]*schedtypes.ProfileRunResult) (*schedtypes.SchedulingResult, error) {
	if len(profileResults) == 0 {
		return nil, errors.New("disagg profile handler received no profile results")
	}

	primaryName := DecodeProfileName
	primaryResult, hasDecode := profileResults[DecodeProfileName]
	if !hasDecode {
		// No decode entry: take the first profile the map iteration yields.
		for name, res := range profileResults {
			primaryName, primaryResult = name, res
			break
		}
	}

	if primaryResult == nil {
		return nil, fmt.Errorf("primary profile '%s' failed to produce a result", primaryName)
	}

	result := &schedtypes.SchedulingResult{
		ProfileResults:     profileResults,
		PrimaryProfileName: primaryName,
	}
	return result, nil
}
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package disagg implements disaggregated prefill/decode serving plugins for Dynamo EPP.
//
// The disaggregated architecture splits inference into two phases:
// - Prefill: processes the input prompt (compute-heavy, parallelizable)
// - Decode: generates tokens autoregressively (memory-bound, sequential)
//
// This package provides three plugins:
// - DisaggProfileHandler: orchestrates prefill→decode profile execution
// - DynPrefillScorer: selects prefill workers via Dynamo FFI
// - DynDecodeScorer: selects decode workers via Dynamo FFI
package disagg
import (
"encoding/json"
"fmt"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
dynscorer "github.com/nvidia/dynamo/deploy/inference-gateway/pkg/plugins/dynamo_kv_scorer"
)
const (
// PrefillProfileName is the scheduling-profile name DisaggProfileHandler looks up for prefill.
PrefillProfileName = "prefill"
// DecodeProfileName is the scheduling-profile name DisaggProfileHandler treats as primary.
DecodeProfileName = "decode"
// PrefillEnabledStateKey is used to communicate prefill-enabled status
// from the DisaggProfileHandler to the scorer plugins via CycleState.
PrefillEnabledStateKey = plugins.StateKey("disagg-prefill-enabled")
)
// PrefillEnabledState stores whether prefill is enabled for the current scheduling cycle.
// Written by DisaggProfileHandler, read by PrefillScorer and DecodeScorer.
type PrefillEnabledState struct {
// Enabled is true when the scorers should treat this cycle as disaggregated.
Enabled bool
}
// Clone implements plugins.StateData.
//
// It returns an independent copy of the state. A nil receiver yields nil
// instead of panicking, matching DecodeRoutingState.Clone.
func (s *PrefillEnabledState) Clone() plugins.StateData {
	if s == nil {
		return nil
	}
	return &PrefillEnabledState{Enabled: s.Enabled}
}
// readPrefillEnabled looks up PrefillEnabledState in cycleState and returns its
// Enabled flag. It returns false when the lookup fails or yields a nil state.
func readPrefillEnabled(cycleState *schedtypes.CycleState) bool {
	stored, err := schedtypes.ReadCycleStateKey[*PrefillEnabledState](cycleState, PrefillEnabledStateKey)
	if err != nil || stored == nil {
		return false
	}
	return stored.Enabled
}
// buildRequestJSON converts a GAIE LLMRequest into the JSON string handed to
// the router: it builds the OpenAI-style request body via dynscorer and
// marshals it. Either step's failure is returned wrapped, with an empty string.
func buildRequestJSON(req *schedtypes.LLMRequest) (string, error) {
	body, buildErr := dynscorer.BuildOpenAIRequest(req)
	if buildErr != nil {
		return "", fmt.Errorf("failed to build OpenAI request: %w", buildErr)
	}

	encoded, marshalErr := json.Marshal(body)
	if marshalErr != nil {
		return "", fmt.Errorf("failed to marshal request JSON: %w", marshalErr)
	}

	return string(encoded), nil
}
// serializePods renders pods as the JSON string passed to the FFI routing calls.
// An empty pod list, or a serialization failure, produces the empty string.
func serializePods(pods []schedtypes.Pod) string {
	if len(pods) == 0 {
		return ""
	}
	if encoded, err := dynscorer.SerializePodsToJSON(pods); err == nil {
		return encoded
	}
	// Best effort: a serialization error is reported as "no pod information".
	return ""
}
// uniformScores builds a score map that assigns score to every pod in pods.
// An empty or nil pod list yields an empty, non-nil map.
func uniformScores(pods []schedtypes.Pod, score float64) map[schedtypes.Pod]float64 {
	scores := make(map[schedtypes.Pod]float64, len(pods))
	for i := range pods {
		scores[pods[i]] = score
	}
	return scores
}
......@@ -14,6 +14,11 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
// Package dynamo_kv_scorer provides the CGO/FFI bindings to the Dynamo Rust router.
//
// This package owns all CGO interactions with the libdynamo_llm_capi static library.
// The disagg plugin package imports the exported Go wrapper functions from here
// to call into the Rust router for prefill/decode worker selection and bookkeeping.
package dynamo_kv_scorer
/*
......@@ -42,7 +47,7 @@ enum {
struct RouterHandles;
typedef struct RouterHandles RouterHandles;
// Routing result from route_chat_request
// Routing result from route functions
typedef struct {
bool is_disaggregated;
uint64_t prefill_worker_id;
......@@ -51,15 +56,22 @@ typedef struct {
size_t token_count;
} CRoutingResult;
// Router bindings API (replaces Pipeline API)
// Router bindings API
query_router_result_t create_routers(const char *namespace_c_str,
const char *component_c_str,
bool decode_fallback,
RouterHandles **out_handle);
query_router_result_t route_request(RouterHandles *handle,
const char *request_json,
CRoutingResult *out_result);
query_router_result_t route_prefill_request(RouterHandles *handle,
const char *request_json,
const char *pods_json,
CRoutingResult *out_result);
query_router_result_t route_decode_request(RouterHandles *handle,
const char *request_json,
const char *pods_json,
bool is_disaggregated,
CRoutingResult *out_result);
query_router_result_t add_request(RouterHandles *handle,
const char *request_id,
......@@ -81,114 +93,17 @@ void destroy(RouterHandles *handle);
import "C"
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"sync"
"time"
"unsafe"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
rc "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
schedtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)
const (
// PluginName is the default instance name NewKVAwareScorer gives the plugin.
PluginName = "dynamo-kv-scorer"
// KVAwareScorerType is the plugin type reported in the scorer's TypedName.
KVAwareScorerType = "kv-aware-scorer"
// WorkerIDHeader is the request header Score sets to the worker ID chosen by the router.
WorkerIDHeader = "x-worker-instance-id"
// PrefillWorkerIDHeader is the request header Score sets to the prefill worker ID
// when prefill and decode are handled by different workers.
PrefillWorkerIDHeader = "x-prefill-instance-id"
// RoutingModeHeader is the request header Score sets to "disaggregated" or "aggregated".
RoutingModeHeader = "x-dynamo-routing-mode"
// stateKey is the key used to store routing state in PluginState
stateKey = "dynamo-routing-state"
)
// --------------------------- config / env ---------------------------
// warmupOnce guards the one-time FFI initialization triggered by KVAwareScorerFactory.
var warmupOnce sync.Once
// warmupErr records the outcome of that initialization; every factory call checks it.
var warmupErr error
// params is the parameter schema for the kv-aware-scorer plugin; it currently has no fields.
type params struct{}
// DynamoRoutingState holds routing information passed from Score() to PreRequest().
// This is stored in PluginState keyed by request ID.
type DynamoRoutingState struct {
// WorkerID is the ID of the worker selected by the router, kept as a string.
WorkerID string
// PrefillWorkerID is the ID of the prefill worker, kept as a string.
// NOTE(review): presumably empty in aggregated mode — confirm against Score.
PrefillWorkerID string
// TokenData holds the token IDs from the router, needed for add_request bookkeeping.
// These tokens are used to compute overlap blocks and track active blocks accurately.
TokenData []int64
}
// Clone implements the plugins.StateData interface.
//
// The result is a deep copy: scalar fields are copied by value and a non-nil
// TokenData slice is duplicated into fresh storage. A nil receiver yields nil.
func (s *DynamoRoutingState) Clone() plugins.StateData {
	if s == nil {
		return nil
	}

	dup := &DynamoRoutingState{WorkerID: s.WorkerID, PrefillWorkerID: s.PrefillWorkerID}
	if s.TokenData == nil {
		return dup
	}

	dup.TokenData = make([]int64, len(s.TokenData))
	copy(dup.TokenData, s.TokenData)
	return dup
}
// KVAwareScorer is the kv-aware-scorer plugin. As the assertions after the struct
// show, it is a Plugin and Scorer that also hooks PreRequest, ResponseStreaming and
// ResponseComplete.
type KVAwareScorer struct {
// typedName is the plugin's type/name pair returned by TypedName.
typedName plugins.TypedName
// pluginState is created by NewKVAwareScorer from the context it is given.
pluginState *plugins.PluginState
firstTokenSeen sync.Map // map[requestID]bool - tracks which requests have received first token
}
// Compile-time checks: the build fails if KVAwareScorer stops satisfying any of
// these interfaces.
var _ plugins.Plugin = (*KVAwareScorer)(nil)
var _ framework.Scorer = (*KVAwareScorer)(nil)
var _ rc.PreRequest = (*KVAwareScorer)(nil)
var _ rc.ResponseStreaming = (*KVAwareScorer)(nil)
var _ rc.ResponseComplete = (*KVAwareScorer)(nil)
// NewKVAwareScorer builds a KVAwareScorer named PluginName by default, with its
// per-request plugin state created from ctx.
func NewKVAwareScorer(ctx context.Context) *KVAwareScorer {
	scorer := new(KVAwareScorer)
	scorer.typedName = plugins.TypedName{Type: KVAwareScorerType, Name: PluginName}
	scorer.pluginState = plugins.NewPluginState(ctx)
	return scorer
}
// WithName overrides the instance name and returns the same scorer so the call
// can be chained onto the constructor.
func (k *KVAwareScorer) WithName(name string) *KVAwareScorer {
	k.typedName.Name = name
	return k
}
// KVAwareScorerFactory is the plugin factory for KVAwareScorer.
//
// name becomes the instance name of the returned plugin. raw carries the
// plugin's JSON parameters; nil or empty input means "no parameters". handle
// supplies the context used to create the plugin state.
//
// It returns an error when the parameters are not valid JSON for params, or
// when the one-time FFI initialization failed. The initialization outcome is
// cached, so a failure is reported by every subsequent call as well.
func KVAwareScorerFactory(name string, raw json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
	p := params{}
	// Report malformed parameters instead of silently discarding the error;
	// empty input is skipped because json.Unmarshal rejects it.
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &p); err != nil {
			return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", KVAwareScorerType, err)
		}
	}

	s := NewKVAwareScorer(handle.Context()).WithName(name)

	// one-time FFI init (runtime + persistent pipeline)
	warmupOnce.Do(func() {
		// A panic during initialization is converted into warmupErr so the
		// factory fails with an error rather than crashing the process.
		defer func() {
			if r := recover(); r != nil {
				warmupErr = fmt.Errorf("Dynamo configuration error: %v", r)
			}
		}()
		warmupErr = initFFI()
	})
	if warmupErr != nil {
		return nil, fmt.Errorf("Dynamo FFI init for the Router failed: %w", warmupErr)
	}

	return s, nil
}
// TypedName returns the plugin's type/name pair.
func (k *KVAwareScorer) TypedName() plugins.TypedName {
	return k.typedName
}
// --------------------------- FFI integration ---------------------------
var (
ffiOnce sync.Once
ffiErr error
......@@ -232,7 +147,7 @@ func getEnvBoolOrDefault(key string, def bool) bool {
return def
}
// initFFI: initialize router handles using the new Router bindings.
// initFFI initializes router handles using the Router bindings.
func initFFI() error {
ffiOnce.Do(func() {
loadDynamoConfig()
......@@ -261,257 +176,88 @@ func initFFI() error {
return ffiErr
}
// --------------------------- scoring ---------------------------
func (k *KVAwareScorer) Score(
ctx context.Context,
cycleState *schedtypes.CycleState,
req *schedtypes.LLMRequest,
pods []schedtypes.Pod,
) map[schedtypes.Pod]float64 {
logger := log.FromContext(ctx)
workerID, prefillWorkerID, tokenData, err := k.callDynamoRouter(ctx, req)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Dynamo call failed; proceeding without worker id")
} else if workerID != "" {
logger.V(logutil.DEFAULT).Info(
"Dynamo router selected worker",
"workerID", workerID,
"prefillWorkerID", prefillWorkerID,
"tokenDataCount", len(tokenData),
)
// Store in request headers
if req.Headers == nil {
req.Headers = map[string]string{}
}
req.Headers[WorkerIDHeader] = workerID
// Set routing mode and prefill worker ID based on disaggregated vs aggregated
if prefillWorkerID != "" && prefillWorkerID != workerID {
// Disaggregated mode: separate prefill and decode workers
req.Headers[RoutingModeHeader] = "disaggregated"
req.Headers[PrefillWorkerIDHeader] = prefillWorkerID
} else {
// Aggregated mode: single worker handles both prefill and decode
req.Headers[RoutingModeHeader] = "aggregated"
}
// Store routing state for PreRequest to register with router bookkeeping.
// PreRequest is called AFTER scheduling is finalized, ensuring we only
// register committed requests (avoiding phantom bookkeeping entries).
if req.RequestId != "" {
routingState := &DynamoRoutingState{
WorkerID: workerID,
PrefillWorkerID: prefillWorkerID,
TokenData: tokenData,
}
k.pluginState.Write(req.RequestId, plugins.StateKey(stateKey), routingState)
}
}
out := make(map[schedtypes.Pod]float64, len(pods))
for _, p := range pods {
out[p] = 1.0
}
return out
// InitFFI exposes the FFI initialization for use by the disagg plugin package.
// It is idempotent — safe to call multiple times.
// It forwards to initFFI, whose work is wrapped in ffiOnce, and returns the
// error initFFI reports.
func InitFFI() error {
return initFFI()
}
// PreRequest is called after scheduling is finalized and before the request is sent to the worker.
// This registers the request with the Dynamo router's bookkeeping (add_request), passing the
// token data obtained during Score(). This ensures only committed requests are tracked.
func (k *KVAwareScorer) PreRequest(
ctx context.Context,
request *schedtypes.LLMRequest,
schedulingResult *schedtypes.SchedulingResult,
) {
logger := log.FromContext(ctx)
if request == nil || request.RequestId == "" {
logger.V(logutil.DEBUG).Info("PreRequest: no request ID, skipping router bookkeeping")
return
}
// Read and delete the routing state stored by Score()
state, err := plugins.ReadPluginStateKey[*DynamoRoutingState](
k.pluginState, request.RequestId, plugins.StateKey(stateKey),
)
k.pluginState.Delete(request.RequestId) // Clean up state after reading
if err != nil {
// No state found means Score() didn't store routing info (e.g., router call failed)
logger.V(logutil.DEBUG).Info("PreRequest: no routing state found, skipping router bookkeeping",
"requestID", request.RequestId)
return
}
// Parse worker ID
var workerIDUint uint64
if _, parseErr := fmt.Sscanf(state.WorkerID, "%d", &workerIDUint); parseErr != nil {
logger.V(logutil.DEFAULT).Error(parseErr, "PreRequest: invalid worker ID",
"requestID", request.RequestId, "workerID", state.WorkerID)
return
}
// Register request with router bookkeeping now that scheduling is committed
if addErr := CallAddRequest(request.RequestId, state.TokenData, workerIDUint, 0); addErr != nil {
logger.V(logutil.DEFAULT).Error(addErr, "PreRequest: failed to add request to router bookkeeping",
"requestID", request.RequestId)
return
}
logger.V(logutil.VERBOSE).Info("PreRequest: registered request with router bookkeeping",
"requestID", request.RequestId,
"workerID", state.WorkerID,
"prefillWorkerID", state.PrefillWorkerID,
"tokenCount", len(state.TokenData),
)
// podInfoJSON is the JSON-serializable representation of a backend.Pod (datalayer.PodInfo).
// SerializePodsToJSON fills it from the value returned by Pod.GetPod(); Name and
// Namespace come from that value's NamespacedName, the remaining fields are
// copied from the identically named fields.
type podInfoJSON struct {
Name string `json:"name"`
Namespace string `json:"namespace"`
PodName string `json:"podName"`
Address string `json:"address"`
Port string `json:"port"`
MetricsHost string `json:"metricsHost"`
Labels map[string]string `json:"labels"`
}
// ResponseStreaming is called for each chunk of a streaming response.
// On the first token, it marks prefill as complete in the Dynamo router's bookkeeping.
func (k *KVAwareScorer) ResponseStreaming(
ctx context.Context,
request *schedtypes.LLMRequest,
response *rc.Response,
targetPod *backend.Pod,
) {
if request == nil || request.RequestId == "" {
return
}
// Check if we've already seen the first token for this request
// LoadOrStore returns (value, loaded) - if loaded is false, this is the first time
if _, alreadySeen := k.firstTokenSeen.LoadOrStore(request.RequestId, true); !alreadySeen {
// This is the first token - mark prefill as complete
logger := log.FromContext(ctx)
if err := CallMarkPrefillComplete(request.RequestId); err != nil {
logger.V(logutil.DEFAULT).Error(err, "ResponseStreaming: failed to mark prefill complete",
"requestID", request.RequestId)
return
}
logger.V(logutil.VERBOSE).Info("ResponseStreaming: marked prefill complete (first token received)",
"requestID", request.RequestId)
}
// metricsJSON is the JSON-serializable representation of backendmetrics.MetricsState (datalayer.Metrics).
// SerializePodsToJSON copies each field from the identically named field of the
// value returned by Pod.GetMetrics(). None of the tags use omitempty, so every
// key is always present in the encoded object.
type metricsJSON struct {
ActiveModels map[string]int `json:"activeModels"`
WaitingModels map[string]int `json:"waitingModels"`
MaxActiveModels int `json:"maxActiveModels"`
RunningQueueSize int `json:"runningQueueSize"`
WaitingQueueSize int `json:"waitingQueueSize"`
KVCacheUsagePercent float64 `json:"kvCacheUsagePercent"`
KvCacheMaxTokenCapacity int `json:"kvCacheMaxTokenCapacity"`
CacheBlockSize int `json:"cacheBlockSize"`
CacheNumGPUBlocks int `json:"cacheNumGPUBlocks"`
UpdateTime time.Time `json:"updateTime"`
}
// ResponseComplete is called after the complete response is sent to the client.
// It cleans up the router bookkeeping state for the completed request by calling
// free_request to release resources associated with the request.
func (k *KVAwareScorer) ResponseComplete(
ctx context.Context,
request *schedtypes.LLMRequest,
response *rc.Response,
targetPod *backend.Pod,
) {
logger := log.FromContext(ctx)
if request == nil {
logger.V(logutil.DEBUG).Info("ResponseComplete: request is nil, skipping cleanup")
return
}
requestID := request.RequestId
if requestID == "" {
logger.V(logutil.DEBUG).Info("ResponseComplete: no request ID, skipping cleanup")
return
}
// Clean up the first token tracking map
k.firstTokenSeen.Delete(requestID)
// Call the dynamo router to free the request bookkeeping
if err := callFreeRequestInternal(requestID); err != nil {
logger.V(logutil.DEFAULT).Error(err, "ResponseComplete: failed to free request",
"requestID", requestID)
return
}
logger.V(logutil.VERBOSE).Info("ResponseComplete: freed request from router",
"requestID", requestID)
// podJSON is the JSON-serializable representation of a schedtypes.Pod passed across the FFI boundary.
// Pod and Metrics are pointers: SerializePodsToJSON leaves either one nil when
// the corresponding getter returns nil, and a nil pointer encodes as JSON null.
type podJSON struct {
Pod *podInfoJSON `json:"pod"`
Metrics *metricsJSON `json:"metrics"`
}
// --------------------------- router call ---------------------------
func (k *KVAwareScorer) callDynamoRouter(
ctx context.Context,
req *schedtypes.LLMRequest,
) (workerID string, prefillWorkerID string, tokenData []int64, err error) {
logger := log.FromContext(ctx)
if err := initFFI(); err != nil {
logger.V(logutil.DEFAULT).Error(err, "FFI init failed")
return "", "", nil, err
}
if !routerInitialized {
return "", "", nil, fmt.Errorf("dynamo router not initialized")
}
// SerializePodsToJSON converts a slice of schedtypes.Pod into a JSON string
// suitable for passing across the C FFI boundary to the Rust router.
func SerializePodsToJSON(pods []schedtypes.Pod) (string, error) {
out := make([]podJSON, 0, len(pods))
for _, p := range pods {
entry := podJSON{}
if podInfo := p.GetPod(); podInfo != nil {
entry.Pod = &podInfoJSON{
Name: podInfo.NamespacedName.Name,
Namespace: podInfo.NamespacedName.Namespace,
PodName: podInfo.PodName,
Address: podInfo.Address,
Port: podInfo.Port,
MetricsHost: podInfo.MetricsHost,
Labels: podInfo.Labels,
}
}
routerHandlesMutex.RLock()
router := routerHandles
routerHandlesMutex.RUnlock()
if m := p.GetMetrics(); m != nil {
entry.Metrics = &metricsJSON{
ActiveModels: m.ActiveModels,
WaitingModels: m.WaitingModels,
MaxActiveModels: m.MaxActiveModels,
RunningQueueSize: m.RunningQueueSize,
WaitingQueueSize: m.WaitingQueueSize,
KVCacheUsagePercent: m.KVCacheUsagePercent,
KvCacheMaxTokenCapacity: m.KvCacheMaxTokenCapacity,
CacheBlockSize: m.CacheBlockSize,
CacheNumGPUBlocks: m.CacheNumGPUBlocks,
UpdateTime: m.UpdateTime,
}
}
if router == nil {
return "", "", nil, fmt.Errorf("dynamo router handles not created")
out = append(out, entry)
}
// Build OpenAI-compatible JSON request from the GAIE LLMRequest structure
requestBody, err := buildOpenAIRequest(req)
data, err := json.Marshal(out)
if err != nil {
logger.V(logutil.DEFAULT).Info("Invalid/empty request body for router; refusing to route",
"err", err.Error())
return "", "", nil, err
return "", fmt.Errorf("failed to serialize pods: %w", err)
}
requestJSON, jsonErr := json.Marshal(requestBody)
if jsonErr != nil {
logger.V(logutil.DEFAULT).Error(jsonErr, "Failed to marshal OpenAI request")
return "", "", nil, fmt.Errorf("marshal OpenAI request: %w", jsonErr)
}
cRequestJSON := C.CString(string(requestJSON))
defer C.free(unsafe.Pointer(cRequestJSON))
var result C.CRoutingResult
rc := C.route_request(router, cRequestJSON, &result)
if rc != C.QUERY_ROUTER_OK {
return "", "", nil, fmt.Errorf("route_request failed with code %d", rc)
}
// Copy token IDs into Go memory before freeing the Rust-allocated result.
// These tokens are needed for add_request bookkeeping (overlap + active block tracking).
count := int(result.token_count)
var tokens64 []int64
if count > 0 && result.token_ids != nil {
src := unsafe.Slice((*uint32)(unsafe.Pointer(result.token_ids)), count)
tokens64 = make([]int64, count)
for i := 0; i < count; i++ {
tokens64[i] = int64(src[i])
}
}
// Copy scalar result fields before freeing the struct
isDisaggregated := result.is_disaggregated
decodeWorkerID := uint64(result.decode_worker_id)
prefillWorkerIDVal := uint64(result.prefill_worker_id)
// Free the Rust-allocated routing result (including token_ids)
C.free_routing_result(&result)
workerIDStr := fmt.Sprintf("%d", decodeWorkerID)
prefillWorkerIDStr := ""
if isDisaggregated {
prefillWorkerIDStr = fmt.Sprintf("%d", prefillWorkerIDVal)
}
logger.V(logutil.DEFAULT).Info("Worker selection completed",
"workerID", workerIDStr, "prefillWorkerID", prefillWorkerIDStr,
"isDisaggregated", isDisaggregated, "tokenCount", count)
return workerIDStr, prefillWorkerIDStr, tokens64, nil
return string(data), nil
}
// buildOpenAIRequest constructs an OpenAI-compatible request from the GAIE LLMRequest structure.
// Preserves message roles for correct chat template application and tokenization.
func buildOpenAIRequest(req *schedtypes.LLMRequest) (map[string]any, error) {
func BuildOpenAIRequest(req *schedtypes.LLMRequest) (map[string]any, error) {
requestBody := make(map[string]any)
// Preserve the original message structure for correct chat template application
......@@ -555,8 +301,6 @@ func buildOpenAIRequest(req *schedtypes.LLMRequest) (map[string]any, error) {
return requestBody, nil
}
// --------------------------- router bookkeeping ---------------------------
// CallAddRequest registers a request with the router's bookkeeping.
func CallAddRequest(requestID string, tokenData []int64, workerID uint64, dpRank uint32) error {
if !routerInitialized {
......@@ -600,8 +344,7 @@ func CallAddRequest(requestID string, tokenData []int64, workerID uint64, dpRank
return nil
}
// CallMarkPrefillComplete marks prefill as completed for a request.
// Exported for use by response handlers.
// CallMarkPrefillComplete marks prefill as completed for a request (bookkeeping).
func CallMarkPrefillComplete(requestID string) error {
if !routerInitialized {
return fmt.Errorf("dynamo router not initialized")
......@@ -625,8 +368,8 @@ func CallMarkPrefillComplete(requestID string) error {
return nil
}
// callFreeRequestInternal cleans up router state for a completed/cancelled request.
func callFreeRequestInternal(requestID string) error {
// CallFreeRequest cleans up router state for a completed/cancelled request (bookkeeping).
func CallFreeRequest(requestID string) error {
if !routerInitialized {
return fmt.Errorf("dynamo router not initialized")
}
......@@ -649,17 +392,100 @@ func callFreeRequestInternal(requestID string) error {
return nil
}
// --------------------------- shutdown ---------------------------
// RoutingResult holds the result of a prefill or decode routing call.
type RoutingResult struct {
// WorkerID is the router's prefill_worker_id when produced by
// CallRoutePrefillRequest and its decode_worker_id when produced by
// CallRouteDecodeRequest.
WorkerID uint64
// TokenData is a Go-owned copy of the token IDs returned by the router,
// widened from uint32 to int64. It is nil when the router returned no tokens.
TokenData []int64
}
func cleanupDynamo() error {
routerHandlesMutex.Lock()
defer routerHandlesMutex.Unlock()
// CallRoutePrefillRequest routes a request to the best prefill worker.
// It tokenizes the request and queries only the prefill router.
func CallRoutePrefillRequest(requestJSON string, podsJSON string) (*RoutingResult, error) {
if !routerInitialized {
return nil, fmt.Errorf("dynamo router not initialized")
}
if routerHandles != nil {
C.destroy(routerHandles)
routerHandles = nil
routerHandlesMutex.RLock()
router := routerHandles
routerHandlesMutex.RUnlock()
if router == nil {
return nil, fmt.Errorf("dynamo router handles not created")
}
routerInitialized = false
return nil
cRequestJSON := C.CString(requestJSON)
defer C.free(unsafe.Pointer(cRequestJSON))
var cPodsJSON *C.char
if podsJSON != "" {
cPodsJSON = C.CString(podsJSON)
defer C.free(unsafe.Pointer(cPodsJSON))
}
var result C.CRoutingResult
rc := C.route_prefill_request(router, cRequestJSON, cPodsJSON, &result)
if rc != C.QUERY_ROUTER_OK {
return nil, fmt.Errorf("route_prefill_request failed with code %d", rc)
}
// Copy token IDs into Go memory
count := int(result.token_count)
var tokens64 []int64
if count > 0 && result.token_ids != nil {
src := unsafe.Slice((*uint32)(unsafe.Pointer(result.token_ids)), count)
tokens64 = make([]int64, count)
for i := 0; i < count; i++ {
tokens64[i] = int64(src[i])
}
}
workerID := uint64(result.prefill_worker_id)
C.free_routing_result(&result)
return &RoutingResult{WorkerID: workerID, TokenData: tokens64}, nil
}
// CallRouteDecodeRequest routes a request to the best decode worker.
// When isDisaggregated is true, overlap_score_weight=0 is used (KV cache transferred from prefill).
//
// requestJSON is handed to the router as a C string. podsJSON is optional: when
// it is empty a NULL pointer is passed in its place.
//
// On success it returns the decode worker ID together with a Go-owned copy of
// the token IDs. It returns an error when the router is not initialized, when
// no router handles exist, or when route_decode_request returns a code other
// than QUERY_ROUTER_OK.
func CallRouteDecodeRequest(requestJSON string, podsJSON string, isDisaggregated bool) (*RoutingResult, error) {
if !routerInitialized {
return nil, fmt.Errorf("dynamo router not initialized")
}
// Snapshot the handle under the read lock; the FFI call below runs without
// the lock held.
routerHandlesMutex.RLock()
router := routerHandles
routerHandlesMutex.RUnlock()
if router == nil {
return nil, fmt.Errorf("dynamo router handles not created")
}
// C.CString allocates C memory, so each string is released with C.free on return.
cRequestJSON := C.CString(requestJSON)
defer C.free(unsafe.Pointer(cRequestJSON))
// cPodsJSON stays nil (NULL on the C side) when no pod list was supplied.
var cPodsJSON *C.char
if podsJSON != "" {
cPodsJSON = C.CString(podsJSON)
defer C.free(unsafe.Pointer(cPodsJSON))
}
var result C.CRoutingResult
rc := C.route_decode_request(router, cRequestJSON, cPodsJSON, C.bool(isDisaggregated), &result)
if rc != C.QUERY_ROUTER_OK {
return nil, fmt.Errorf("route_decode_request failed with code %d", rc)
}
// Copy token IDs into Go memory before the result is released below; the
// source buffer is viewed as uint32 values and widened to int64.
count := int(result.token_count)
var tokens64 []int64
if count > 0 && result.token_ids != nil {
src := unsafe.Slice((*uint32)(unsafe.Pointer(result.token_ids)), count)
tokens64 = make([]int64, count)
for i := 0; i < count; i++ {
tokens64[i] = int64(src[i])
}
}
// Read the scalar field first, then release the routing result.
workerID := uint64(result.decode_worker_id)
C.free_routing_result(&result)
return &RoutingResult{WorkerID: workerID, TokenData: tokens64}, nil
}
/*
Copyright 2025 NVIDIA Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package label_filter
import (
"context"
"encoding/json"
"fmt"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)
const (
// LabelFilterType is the plugin type string of LabelFilter. NewLabelFilter
// also uses it as the default plugin name.
LabelFilterType = "label-filter"
)
// Compile-time check that *LabelFilter implements framework.Filter.
var _ framework.Filter = (*LabelFilter)(nil)
// LabelFilterConfig holds the configuration for the LabelFilter plugin.
// Matches the deployment manifest schema:
//
// parameters:
// label: "nvidia.com/dynamo-sub-component-type"
// validValues:
// - "prefill"
// allowsNoLabel: false
type LabelFilterConfig struct {
// Label is the pod label key that Filter inspects. LabelFilterFactory
// rejects an empty value.
Label string `json:"label"`
// ValidValues lists the label values that let a pod through.
// LabelFilterFactory rejects an empty list.
ValidValues []string `json:"validValues"`
// AllowsNoLabel controls whether pods that do not carry Label at all are kept.
AllowsNoLabel bool `json:"allowsNoLabel"`
}
// LabelFilterFactory builds a LabelFilter named name from the JSON plugin
// parameters in rawParameters (see LabelFilterConfig for the schema). The
// plugins.Handle argument is not used.
//
// It returns an error when rawParameters is nil, when it cannot be decoded,
// when 'label' is empty, or when 'validValues' has no entries.
func LabelFilterFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	if rawParameters == nil {
		return nil, fmt.Errorf("%s plugin requires parameters with 'label' and 'validValues' fields", LabelFilterType)
	}

	var cfg LabelFilterConfig
	if err := json.Unmarshal(rawParameters, &cfg); err != nil {
		return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", LabelFilterType, err)
	}

	// Validate the decoded configuration; the first failing rule wins.
	switch {
	case cfg.Label == "":
		return nil, fmt.Errorf("%s plugin parameter 'label' must not be empty", LabelFilterType)
	case len(cfg.ValidValues) == 0:
		return nil, fmt.Errorf("%s plugin parameter 'validValues' must contain at least one value", LabelFilterType)
	}

	filter := NewLabelFilter(cfg.Label, cfg.ValidValues, cfg.AllowsNoLabel)
	return filter.WithName(name), nil
}
// NewLabelFilter returns a LabelFilter that keeps pods whose label value is
// one of validValues. Pods lacking the label are kept only when allowsNoLabel
// is true. The plugin name defaults to LabelFilterType; use WithName to
// change it.
func NewLabelFilter(label string, validValues []string, allowsNoLabel bool) *LabelFilter {
	// Index the accepted values as a set so Filter can test membership with a
	// single map lookup.
	allowed := make(map[string]struct{}, len(validValues))
	for i := range validValues {
		allowed[validValues[i]] = struct{}{}
	}

	filter := &LabelFilter{
		label:         label,
		validValues:   allowed,
		allowsNoLabel: allowsNoLabel,
	}
	filter.typedName = plugins.TypedName{Type: LabelFilterType, Name: LabelFilterType}
	return filter
}
// LabelFilter is a scheduling filter that selects pods by the value of a
// single label.
type LabelFilter struct {
// typedName is the plugin type/name pair returned by TypedName.
typedName plugins.TypedName
// label is the pod label key looked up by Filter.
label string
// validValues is the set of label values that let a pod through.
validValues map[string]struct{}
// allowsNoLabel reports whether pods without the label are kept.
allowsNoLabel bool
}
// TypedName returns the plugin's type/name pair.
func (f *LabelFilter) TypedName() plugins.TypedName {
return f.typedName
}
// WithName overrides the plugin instance name and returns the receiver so the
// call can be chained onto NewLabelFilter.
func (f *LabelFilter) WithName(name string) *LabelFilter {
f.typedName.Name = name
return f
}
// Filter returns the subset of pods that pass the label check, preserving
// their input order.
//
// A pod carrying the configured label is kept when its value is one of the
// valid values. A pod without the label is kept only when allowsNoLabel is
// true. Nil pods, and pods whose GetPod() returns nil, are always dropped.
// The context, cycle state and request arguments are not used.
func (f *LabelFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
	kept := make([]types.Pod, 0, len(pods))
	for i := range pods {
		candidate := pods[i]
		if candidate == nil {
			continue
		}
		info := candidate.GetPod()
		if info == nil {
			continue
		}

		// Start from the "label missing" policy and override it with the
		// membership test when the label is actually present.
		keep := f.allowsNoLabel
		if value, present := info.Labels[f.label]; present {
			_, keep = f.validValues[value]
		}
		if keep {
			kept = append(kept, candidate)
		}
	}
	return kept
}
......@@ -17,25 +17,35 @@
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
# This config uses the same disagg-profile-handler as disaggregated deployments.
# With no "prefill" profile defined, the handler runs only the "decode" profile
# and the decode scorer receives isDisaggregated=false, enabling full KV cache
# overlap scoring for aggregated mode. Adding prefill workers and a prefill
# profile would automatically switch to disaggregated routing.
plugins:
# Required: tells EPP which profile to use (even if you only have one)
- type: single-profile-handler
- type: disagg-profile-handler
# Label filter: restricts pods to decode workers
- name: decode-filter
type: label-filter
parameters:
label: "nvidia.com/dynamo-sub-component-type"
validValues:
- "decode"
allowsNoLabel: true
# Picker: chooses the final endpoint after scoring
- name: picker
type: max-score-picker
# Dynamo KV-aware Scorer: calls Dynamo router FFI for worker selection
# Implements Scorer, PreRequest, and ResponseComplete:
# - Score: Selects workers based on KV cache, sets routing headers
# - PreRequest: Registers request with router bookkeeping
# - ResponseComplete: Frees router bookkeeping when response completes
- name: dyn-kv
type: kv-aware-scorer
# Dynamo decode scorer: calls Dynamo FFI decode router, handles request lifecycle
- name: dyn-decode
type: dyn-decode-scorer
schedulingProfiles:
- name: default
- name: decode
plugins:
- pluginRef: dyn-kv
- pluginRef: decode-filter
- pluginRef: dyn-decode
weight: 1
- pluginRef: picker
......@@ -83,6 +83,10 @@ manifests: controller-gen ensure-yq ## Generate WebhookConfiguration, ClusterRol
for file in config/crd/bases/*.yaml; do \
yq eval '(.. | select(has("extraPodSpec")) | .extraPodSpec.required) |= (. - ["containers"])' -i --indent 2 $$file || exit 1; \
done
echo "Fixing PluginSpec parameters field: json.RawMessage needs x-kubernetes-preserve-unknown-fields instead of type: string"
for file in config/crd/bases/*.yaml; do \
yq eval '(.. | select(has("parameters")) | .parameters | select(has("format") and .format == "byte")) |= (del(.format) | del(.type) | .x-kubernetes-preserve-unknown-fields = true)' -i --indent 2 $$file || exit 1; \
done
echo "Adding NVIDIA header to CRD files"
for file in config/crd/bases/*.yaml; do \
if ! head -20 "$$file" | grep -q "NVIDIA CORPORATION"; then \
......
......@@ -964,8 +964,7 @@ spec:
Parameters are the set of parameters to be passed to the plugin's
factory function. The factory function is responsible
to parse the parameters.
format: byte
type: string
x-kubernetes-preserve-unknown-fields: true
type:
description: Type specifies the plugin type to be instantiated.
type: string
......
......@@ -1173,8 +1173,7 @@ spec:
Parameters are the set of parameters to be passed to the plugin's
factory function. The factory function is responsible
to parse the parameters.
format: byte
type: string
x-kubernetes-preserve-unknown-fields: true
type:
description: Type specifies the plugin type to be instantiated.
type: string
......
......@@ -58,7 +58,7 @@ func GenerateInferencePool(
},
Selector: gaiev1.LabelSelector{
MatchLabels: map[gaiev1.LabelKey]gaiev1.LabelValue{
consts.KubeLabelDynamoComponentType: consts.ComponentTypeFrontend,
consts.KubeLabelDynamoComponentType: consts.ComponentTypeWorker,
consts.KubeLabelDynamoNamespace: gaiev1.LabelValue(dynamoNamespace),
},
},
......
......@@ -6,11 +6,14 @@ title: Inference Gateway (GAIE)
## Inference Gateway Setup with Dynamo
When integrating Dynamo with the Inference Gateway you must use the custom Dynamo EPP image.
# Inference Gateway (GAIE)
The custom Dynamo EPP image integrates the Dynamo router directly into the gateway's endpoint picker. Using the `dyn-kv` plugin, it selects the optimal worker based on KV cache state and tokenized prompt before routing the request. The integration moves intelligent routing upstream to the gateway layer.
Integrate Dynamo with the Gateway API Inference Extension for intelligent KV-aware request routing at the gateway layer.
EPP's default kv-routing approach is not token-aware because the prompt is not tokenized. But the Dynamo plugin uses a token-aware KV algorithm. It employs the dynamo router which implements kv routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](../../../deploy/inference-gateway/helm/dynamo-gaie/epp-config-dynamo.yaml) per EPP [convention](https://gateway-api-inference-extension.sigs.k8s.io/guides/epp-configuration/config-text/).
EPP's default kv-routing approach is not token-aware because the prompt is not tokenized. But the Dynamo plugin uses a token-aware KV algorithm. It employs the dynamo router which implements kv routing by running your model's tokenizer inline. The EPP plugin configuration lives in [`helm/dynamo-gaie/epp-config-dynamo.yaml`](helm/dynamo-gaie/epp-config-dynamo.yaml) per EPP [convention](https://gateway-api-inference-extension.sigs.k8s.io/guides/epp-configuration/config-text/).
Dynamo integration with the Inference Gateway supports both aggregated and disaggregated serving. The EPP config is the same for both. If no prefill workers are found, the service degrades gracefully to aggregated serving.
If you want to use LoRA, deploy Dynamo without the Inference Gateway.
Currently, these setups are only supported with the kGateway based Inference Gateway.
......@@ -86,7 +89,7 @@ kubectl create secret generic hf-token-secret \
```
Create a model configuration file similar to the vllm_agg_qwen.yaml for your model.
This file demonstrates the values needed for the Vllm Agg setup in [agg.yaml](../../../examples/backends/vllm/deploy/agg.yaml)
This file demonstrates the values needed for the Vllm Agg setup in [agg.yaml](../../examples/backends/vllm/deploy/agg.yaml)
Take a note of the model's block size provided in the model card.
### 4. Build EPP image (Optional)
......@@ -124,16 +127,34 @@ you could deploy it as a standalone pod
#### 5.a. Deploy as a DGD component (recommended)
We provide an example for llama-3-70b vLLM below.
We provide an example for the Qwen vLLM below.
```bash
cd <dynamo-source-root>
kubectl apply -f examples/backends/vllm/deploy/gaie/agg.yaml -n my-model
kubectl apply -f examples/backends/vllm/deploy/gaie/http-route.yaml -n my-model
```
Examples for other models can be found in the recipes folder.
```bash
# Deploy PVC, having first Update `storageClassName` in recipes/llama-3-70b/model-cache/model-cache.yaml to match your cluster before deploying
kubectl apply -f recipes/llama-3-70b/model-cache/model-cache.yaml -n ${NAMESPACE}
kubectl apply -f recipes/llama-3-70b/model-cache/model-download.yaml -n ${NAMESPACE}
```
We provide examples for llama-3-70b vLLM under the `recipes/llama-3-70b/vllm/agg/gaie/` for aggregated and `recipes/llama-3-70b/vllm/disagg-single-node/gaie/` for disaggregated serving.
Use the proper folder in commands below.
```bash
# Deploy the PVC. First update `storageClassName` in recipes/llama-3-70b/model-cache/model-cache.yaml to match your cluster.
kubectl apply -f recipes/llama-3-70b/model-cache/model-cache.yaml
kubectl apply -f recipes/llama-3-70b/model-cache/model-download.yaml
# Deploy your model
# Deploy your Dynamo Graph.
# agg
kubectl apply -f recipes/llama-3-70b/vllm/agg/gaie/deploy.yaml -n ${NAMESPACE}
# Deploy the GAIE http-route CR.
kubectl apply -f recipes/llama-3-70b/vllm/agg/gaie/http-route.yaml -n ${NAMESPACE}
# or disagg
kubectl apply -f recipes/llama-3-70b/vllm/disagg-single-node/gaie/deploy.yaml -n ${NAMESPACE}
kubectl apply -f recipes/llama-3-70b/vllm/disagg-single-node/gaie/http-route.yaml -n ${NAMESPACE}
```
- When using GAIE the FrontEnd does not choose the workers. The routing is determined in the EPP.
......@@ -168,29 +189,9 @@ If you installed it into a different namespace, you need to adjust the HttpRoute
#### 5.b. Deploy as a standalone pod
##### 5.b.1 Deploy Your Model ###
We do not recommend this method but there are hints on how to do this here.
We provide an example for Qwen vLLM below.
Before deploying you must enable the `--direct-route` flag in the FrontEnd cli in your Dynamo Graph.
```bash
command:
- python3
args:
- -m
- dynamo.frontend
- --router-mode
- direct
```
Follow the steps in [model deployment](../../../examples/backends/vllm/deploy/README.md) to deploy `Qwen/Qwen3-0.6B` model in aggregate mode using [agg.yaml](../../../examples/backends/vllm/deploy/agg.yaml) in `my-model` kubernetes namespace.
Sample commands to deploy model:
```bash
cd <dynamo-source-root>
cd examples/backends/vllm/deploy
kubectl apply -f agg.yaml -n my-model
```
##### 5.b.1 Deploy Your Model ###
##### 5.b.2 Install Dynamo GIE helm chart ###
......@@ -284,7 +285,7 @@ b. use port-forward to expose the gateway to the host
```bash
# in first terminal
kubectl port-forward svc/inference-gateway 8000:80 -n kgateway-system
kubectl port-forward svc/inference-gateway 8000:80 -n {NAMESPACE} # replace {NAMESPACE} with the namespace where you installed the gateway, e.g. kgateway-system
# in second terminal where you want to send inference requests
GATEWAY_URL=http://localhost:8000
......@@ -427,4 +428,4 @@ The plugins set HTTP headers that are forwarded to the backend workers.
| Header | Description | Set By |
|--------|-------------|--------|
| `x-worker-instance-id` | Primary worker ID (decode worker in disagg mode) | kv-aware-scorer |
| `x-prefill-instance-id` | Prefill worker ID (disaggregated mode only) | kv-aware-scorer |
\ No newline at end of file
| `x-prefill-instance-id` | Prefill worker ID (disaggregated mode only) | kv-aware-scorer |
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
name: qwen-agg
spec:
backendFramework: vllm
services:
Epp:
envFromSecret: hf-token-secret
componentType: epp
replicas: 1
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/epp-image:my-tag
eppConfig:
config:
plugins:
- type: disagg-profile-handler
- name: decode-filter
type: label-filter
parameters:
label: "nvidia.com/dynamo-sub-component-type"
validValues:
- "decode"
allowsNoLabel: true
- name: picker
type: max-score-picker
- name: dyn-decode
type: dyn-decode-scorer
schedulingProfiles:
- name: decode
plugins:
- pluginRef: decode-filter
- pluginRef: dyn-decode
weight: 1
- pluginRef: picker
VllmDecodeWorker:
componentType: worker
envFromSecret: hf-token-secret
sharedMemory:
size: 2Gi
extraPodSpec:
mainContainer:
env:
- name: SERVED_MODEL_NAME
value: "Qwen/Qwen3-0.6B"
- name: MODEL_PATH
value: "Qwen/Qwen3-0.6B"
- name: DYN_STORE_KV
value: "mem"
args:
- "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 1 --data-parallel-size 1 --gpu-memory-utilization 0.90 --no-enable-prefix-caching --block-size 128"
command:
- /bin/sh
- -c
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
workingDir: /workspace/examples/backends/vllm
containers:
- name: frontend
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
command:
- python3
args:
- -m
- dynamo.frontend
- --router-mode
- direct
ports:
- containerPort: 8000
name: http
protocol: TCP
envFrom:
- secretRef:
name: hf-token-secret
env:
- name: DYNAMO_PORT
value: "8000"
- name: DYN_HTTP_PORT
value: "8000"
- name: DYN_NAMESPACE
value: my-model-qwen-agg
- name: DYN_COMPONENT
value: frontend
- name: DYN_DISCOVERY_BACKEND
value: kubernetes
- name: DYN_PARENT_DGD_K8S_NAME
value: qwen-agg
- name: DYN_PARENT_DGD_K8S_NAMESPACE
value: my-model
- name: POD_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: POD_NAMESPACE
valueFrom:
fieldRef:
fieldPath: metadata.namespace
- name: POD_UID
valueFrom:
fieldRef:
fieldPath: metadata.uid
livenessProbe:
httpGet:
path: /live
port: http
initialDelaySeconds: 15
periodSeconds: 10
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 10
replicas: 1
resources:
limits:
gpu: "1"
requests:
gpu: "1"
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: You can remove metadata.namespace if using kubectl apply -n
# The backendRefs.namespace field should match where your InferencePool is deployed
apiVersion: gateway.networking.k8s.io/v1
kind: HTTPRoute
metadata:
name: qwen-agg-route
spec:
parentRefs:
- group: gateway.networking.k8s.io
kind: Gateway
name: inference-gateway
namespace: my-model
rules:
- backendRefs:
- group: inference.networking.k8s.io
kind: InferencePool
name: qwen-agg-pool
port: 8000
weight: 1
matches:
- path:
type: PathPrefix
value: /
timeouts:
request: 300s
......@@ -13,7 +13,7 @@ use std::time::Duration;
use dynamo_llm::kv_router::{protocols::*, publisher::KvEventPublisher};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_runtime::discovery::DiscoveryQuery;
use dynamo_runtime::discovery::{DiscoveryQuery, hash_pod_name};
use dynamo_runtime::{DistributedRuntime, Worker};
use dynamo_runtime::Runtime;
......@@ -24,6 +24,8 @@ use dynamo_llm::kv_router::protocols::WorkerWithDpRank;
use dynamo_llm::kv_router::{KvRouter, PrefillRouter, RouterConfigOverride};
use dynamo_runtime::pipeline::RouterMode;
use std::collections::HashSet;
static WK: OnceCell<Worker> = OnceCell::new();
static DRT: AsyncOnceCell<DistributedRuntime> = AsyncOnceCell::new();
// [FIXME] shouldn't the publisher be an instance that is passed between API calls?
......@@ -425,6 +427,8 @@ pub struct RouterHandles {
impl RouterHandles {
/// Query optimal prefill worker for a request.
///
/// When `allowed_worker_ids` is Some, only workers in that set are considered.
/// Returns worker_id on success.
async fn query_prefill_worker(
&self,
......@@ -433,6 +437,7 @@ impl RouterHandles {
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<u64, QueryRouterResult> {
self.prefill_router
.query_prefill_worker(
......@@ -441,6 +446,7 @@ impl RouterHandles {
update_states,
lora_name,
priority_jump,
allowed_worker_ids,
)
.await
.map(|(worker_id, _dp_rank)| worker_id)
......@@ -454,6 +460,9 @@ impl RouterHandles {
/// For disaggregated mode, set `is_disaggregated` to true to use overlap_score_weight=0
/// (since KV cache is being transferred from prefill, not reused).
///
/// When `allowed_worker_ids` is Some, only workers in that set are considered.
/// This does NOT overwrite the router's internal worker state — it only filters this decision.
///
/// Note: The C bindings are query-only and must not mutate router state during worker
/// selection. State updates require a `context_id` (request id) and are managed via the
/// explicit bookkeeping APIs (`add_request`, `mark_prefill_complete`, `free_request`).
......@@ -462,6 +471,7 @@ impl RouterHandles {
&self,
tokens: &[u32],
is_disaggregated: bool,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> Result<(WorkerWithDpRank, u32), QueryRouterResult> {
// For decode phase in disaggregated mode, use overlap_score_weight=0
// This matches prefill_router.rs
......@@ -483,6 +493,7 @@ impl RouterHandles {
false,
None,
0.0,
allowed_worker_ids,
)
.await
.map_err(|e| {
......@@ -508,14 +519,6 @@ pub enum QueryRouterResult {
}
/// Build a `KvRouterConfig` from defaults, overridden by optional `DYN_*` environment variables.
///
/// Supported env vars (all optional — unset or empty values are ignored):
/// - `DYN_OVERLAP_SCORE_WEIGHT` — Weight for overlap score in worker selection (default: 1.0)
/// - `DYN_ROUTER_TEMPERATURE` — Temperature for worker sampling via softmax (default: 0.0)
/// - `DYN_USE_KV_EVENTS` — Use KV events for cache tracking (default: true)
/// - `DYN_ROUTER_REPLICA_SYNC` — Enable replica synchronization (default: false)
/// - `DYN_ROUTER_TRACK_ACTIVE_BLOCKS` — Track active blocks (default: true)
/// - `DYN_ROUTER_TRACK_OUTPUT_BLOCKS` — Track output blocks during generation (default: false)
fn kv_router_config_from_env() -> KvRouterConfig {
let mut cfg = KvRouterConfig::default();
......@@ -976,157 +979,261 @@ pub unsafe extern "C" fn destroy(handle: RouterHandlesPtr) {
}
}
/// Route a chat completion request in a single call.
///
/// This is the main function for EPP to route a `/v1/chat/completions` request.
/// It combines tokenization and worker selection in one call:
/// 1. Applies the chat template to the request JSON
/// 2. Tokenizes the formatted prompt
/// 3. Queries the prefill router (if disaggregated mode)
/// 4. Queries the decode router
/// 5. Returns worker IDs and token_ids
///
/// After this call, EPP should:
/// - Call `add_request()` to register the request for bookkeeping
/// - Set worker ID headers and forward to backend
/// - Call `mark_prefill_complete()` on first token
/// - Call `free_request()` when the stream ends
/// - Call `free_routing_result()` to free the result
/// Free a routing result.
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `out_result` must be a valid pointer
/// - `result` must be a valid pointer to a CRoutingResult previously returned by route functions
#[unsafe(no_mangle)]
pub unsafe extern "C" fn route_request(
handle: RouterHandlesPtr,
request_json: *const c_char,
out_result: *mut CRoutingResult,
) -> QueryRouterResult {
if handle.is_null() || request_json.is_null() || out_result.is_null() {
return QueryRouterResult::ErrInvalidParam;
pub unsafe extern "C" fn free_routing_result(result: *mut CRoutingResult) {
if result.is_null() {
return;
}
let handles = unsafe { &*handle };
let res = unsafe { &mut *result };
// Get preprocessor
// Free token IDs
if !res.token_ids.is_null() && res.token_count > 0 {
drop(unsafe {
Box::from_raw(std::slice::from_raw_parts_mut(
res.token_ids,
res.token_count,
))
});
res.token_ids = ptr::null_mut();
res.token_count = 0;
}
}
/// Parse a JSON request string, apply the chat template, and tokenize.
/// Returns the token IDs on success, or a `QueryRouterResult` error code.
unsafe fn preprocess_request(
handles: &RouterHandles,
request_json: *const c_char,
) -> Result<Vec<u32>, QueryRouterResult> {
let preprocessor = match &handles.preprocessor {
Some(p) => p,
None => {
tracing::error!("Preprocessor not available");
return QueryRouterResult::ErrInitFailed;
return Err(QueryRouterResult::ErrInitFailed);
}
};
let json_str = match unsafe { CStr::from_ptr(request_json) }.to_str() {
Ok(s) => s,
Err(_) => return QueryRouterResult::ErrInvalidParam,
Err(_) => return Err(QueryRouterResult::ErrInvalidParam),
};
// Parse JSON
let request: dynamo_llm::types::openai::chat_completions::NvCreateChatCompletionRequest =
match serde_json::from_str(json_str) {
Ok(req) => req,
Err(e) => {
tracing::error!(error = ?e, "Failed to parse request JSON");
return QueryRouterResult::ErrInvalidParam;
return Err(QueryRouterResult::ErrInvalidParam);
}
};
// Apply chat template
let formatted_prompt = match preprocessor.apply_template(&request) {
Ok(Some(prompt)) => prompt,
Ok(None) => String::new(),
Err(e) => {
tracing::error!(error = ?e, "Failed to apply chat template");
return QueryRouterResult::ErrQueryFailed;
return Err(QueryRouterResult::ErrQueryFailed);
}
};
// Tokenize
let encoding = match preprocessor.tokenize(&formatted_prompt) {
Ok(enc) => enc,
Err(e) => {
tracing::error!(error = ?e, "Failed to tokenize");
return QueryRouterResult::ErrQueryFailed;
return Err(QueryRouterResult::ErrQueryFailed);
}
};
let tokens = encoding.token_ids();
let token_count = tokens.len();
let is_disaggregated = handles.prefill_router.is_activated();
Ok(encoding.token_ids().to_vec())
}
// Query workers
let result = handles.runtime.secondary().block_on(async {
let prefill_worker_id = if is_disaggregated {
handles
.query_prefill_worker(tokens, None, false, None, 0.0)
.await?
} else {
0
};
/// Convert the EPP-supplied pods JSON into an optional allow-list of worker IDs.
///
/// The input is expected to be a JSON array. For every element the pod name is read
/// from `pod.podName` if present, otherwise from a top-level `podName`; elements with
/// neither (or with a non-string value) are skipped. Each name is mapped to a worker ID
/// with `hash_pod_name`.
///
/// Returns `None` (meaning "do not filter") when the pointer is null, the string is not
/// valid UTF-8 or is empty, the JSON fails to parse as an array, or no pod name could be
/// extracted. Otherwise returns `Some` with the de-duplicated worker IDs.
///
/// # Safety
/// - `pods_json` must be null or point to a valid null-terminated C string.
unsafe fn parse_pods_filter(pods_json: *const c_char) -> Option<HashSet<WorkerId>> {
    if pods_json.is_null() {
        return None;
    }

    // Invalid UTF-8 and the empty string are both treated as "no filter supplied".
    let raw = unsafe { CStr::from_ptr(pods_json) }.to_str().ok()?;
    if raw.is_empty() {
        return None;
    }

    let entries: Vec<serde_json::Value> = match serde_json::from_str(raw) {
        Ok(entries) => entries,
        Err(e) => {
            tracing::error!(error = ?e, "Failed to parse pods JSON");
            return None;
        }
    };

    let allowed: HashSet<WorkerId> = entries
        .iter()
        .filter_map(|entry| {
            // Prefer the nested `pod.podName`, fall back to a flat `podName`.
            entry
                .get("pod")
                .and_then(|p| p.get("podName"))
                .or_else(|| entry.get("podName"))
                .and_then(|v| v.as_str())
        })
        .map(|name| {
            let worker_id = hash_pod_name(name);
            tracing::debug!(
                pod_name = name,
                worker_id = format!("{:x}", worker_id),
                "Mapped EPP pod to worker_id"
            );
            worker_id
        })
        .collect();

    tracing::info!(
        pod_count = entries.len(),
        unique_worker_ids = allowed.len(),
        "Parsed EPP pods into allowed_worker_ids filter"
    );

    // An empty set is reported as "no filter" rather than "filter out everything".
    (!allowed.is_empty()).then_some(allowed)
}
let (decode_worker, _overlap_blocks) = handles
.query_decode_worker(tokens, is_disaggregated)
/// Copy `tokens` into a heap allocation and hand it to `out`.
///
/// `out.token_ids` receives the pointer to the first element of a freshly boxed slice
/// and `out.token_count` its length. The allocation is intentionally not dropped here:
/// ownership passes to the holder of the `CRoutingResult`.
fn write_tokens_to_result(tokens: &[u32], out: &mut CRoutingResult) {
    let owned: Box<[u32]> = Box::from(tokens);
    out.token_count = owned.len();
    // `into_raw` gives up ownership without running the destructor; the fat slice
    // pointer is narrowed to a pointer to its first element for the C side.
    out.token_ids = Box::into_raw(owned).cast::<u32>();
}
/// Route a request to select the best **prefill** worker only.
///
/// This is used in disaggregated mode where the EPP runs separate prefill and decode
/// scoring profiles. It tokenizes the request and queries only the prefill router.
///
/// The returned `CRoutingResult` contains:
/// - `prefill_worker_id`: the selected prefill worker
/// - `decode_worker_id`: 0 (unused — decode is handled by `route_decode_request`)
/// - `is_disaggregated`: always true (this function is only called in disagg mode)
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
///
/// # Safety
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
#[unsafe(no_mangle)]
pub unsafe extern "C" fn route_prefill_request(
handle: RouterHandlesPtr,
request_json: *const c_char,
pods_json: *const c_char,
out_result: *mut CRoutingResult,
) -> QueryRouterResult {
if handle.is_null() || request_json.is_null() || out_result.is_null() {
return QueryRouterResult::ErrInvalidParam;
}
let handles = unsafe { &*handle };
let tokens = match unsafe { preprocess_request(handles, request_json) } {
Ok(t) => t,
Err(code) => return code,
};
let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };
let result = handles.runtime.secondary().block_on(async {
let prefill_worker_id = handles
.query_prefill_worker(&tokens, None, false, None, 0.0, allowed_worker_ids)
.await?;
tracing::info!(
is_disaggregated = is_disaggregated,
prefill_worker_id = prefill_worker_id,
decode_worker_id = decode_worker.worker_id,
decode_dp_rank = decode_worker.dp_rank,
token_count = token_count,
"Routed chat request"
token_count = tokens.len(),
"Routed prefill request"
);
Ok((prefill_worker_id, decode_worker))
Ok(prefill_worker_id)
});
match result {
Ok((prefill_worker_id, decode_worker)) => {
// Allocate and copy token IDs for caller (needed for add_request bookkeeping)
let token_vec: Vec<u32> = tokens.to_vec();
let mut tokens_boxed = token_vec.into_boxed_slice();
let token_ptr = tokens_boxed.as_mut_ptr();
std::mem::forget(tokens_boxed);
unsafe {
*out_result = CRoutingResult {
is_disaggregated,
prefill_worker_id,
decode_worker_id: decode_worker.worker_id,
token_ids: token_ptr,
token_count,
};
}
Ok(prefill_worker_id) => {
let out = unsafe { &mut *out_result };
*out = CRoutingResult::default();
out.is_disaggregated = true;
out.prefill_worker_id = prefill_worker_id;
write_tokens_to_result(&tokens, out);
QueryRouterResult::Ok
}
Err(code) => code,
}
}
/// Free a routing result.
/// Route a request to select the best **decode** worker only.
///
/// This is used in both aggregated and disaggregated modes.
/// - When `is_disaggregated` is true, the decode router uses `overlap_score_weight=0`
/// (KV cache is being transferred from prefill, not reused locally).
/// - When `is_disaggregated` is false, normal KV-aware scoring is used.
///
/// The returned `CRoutingResult` contains:
/// - `decode_worker_id`: the selected decode worker
/// - `prefill_worker_id`: 0 (unused — prefill is handled by `route_prefill_request`)
/// - `is_disaggregated`: mirrors the input parameter
/// - `token_ids` / `token_count`: the tokenized request (caller must free via `free_routing_result`)
///
/// # Safety
/// - `result` must be a valid pointer to a CRoutingResult previously returned by route functions
/// - `handle` must be a valid RouterHandles handle
/// - `request_json` must be a valid null-terminated C string containing JSON
/// - `pods_json` must be a valid null-terminated C string containing JSON, or null
/// - `out_result` must be a valid pointer
#[unsafe(no_mangle)]
pub unsafe extern "C" fn free_routing_result(result: *mut CRoutingResult) {
if result.is_null() {
return;
pub unsafe extern "C" fn route_decode_request(
handle: RouterHandlesPtr,
request_json: *const c_char,
pods_json: *const c_char,
is_disaggregated: bool,
out_result: *mut CRoutingResult,
) -> QueryRouterResult {
if handle.is_null() || request_json.is_null() || out_result.is_null() {
return QueryRouterResult::ErrInvalidParam;
}
let res = unsafe { &mut *result };
let handles = unsafe { &*handle };
// Free token IDs
if !res.token_ids.is_null() && res.token_count > 0 {
drop(unsafe {
Box::from_raw(std::slice::from_raw_parts_mut(
res.token_ids,
res.token_count,
))
});
res.token_ids = ptr::null_mut();
res.token_count = 0;
let tokens = match unsafe { preprocess_request(handles, request_json) } {
Ok(t) => t,
Err(code) => return code,
};
let allowed_worker_ids = unsafe { parse_pods_filter(pods_json) };
let result = handles.runtime.secondary().block_on(async {
let (decode_worker, _overlap_blocks) = handles
.query_decode_worker(&tokens, is_disaggregated, allowed_worker_ids)
.await?;
tracing::info!(
is_disaggregated = is_disaggregated,
decode_worker_id = decode_worker.worker_id,
decode_dp_rank = decode_worker.dp_rank,
token_count = tokens.len(),
"Routed decode request"
);
Ok(decode_worker)
});
match result {
Ok(decode_worker) => {
let out = unsafe { &mut *out_result };
*out = CRoutingResult::default();
out.is_disaggregated = is_disaggregated;
out.decode_worker_id = decode_worker.worker_id;
write_tokens_to_result(&tokens, out);
QueryRouterResult::Ok
}
Err(code) => code,
}
}
......
......@@ -919,6 +919,7 @@ impl KvRouter {
update_states,
lora_name,
0.0,
None, // allowed_worker_ids not exposed in Python API yet
)
.await
.map_err(to_pyerr)?;
......
......@@ -63,6 +63,8 @@ use crate::{
local_model::runtime_config::ModelRuntimeConfig,
};
use std::collections::HashSet;
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
......@@ -351,7 +353,9 @@ impl KvRouter {
/// Given these tokens, find the worker with the best match in its KV cache.
/// Returns the best worker (with dp_rank) and overlap amount in number of blocks.
/// Now also takes optional context_id for request tracking
/// Now also takes optional context_id for request tracking.
///
/// When `allowed_worker_ids` is Some, only workers in that set are considered for selection.
#[allow(clippy::too_many_arguments)]
pub async fn find_best_match(
&self,
......@@ -362,6 +366,7 @@ impl KvRouter {
update_states: bool,
lora_name: Option<String>,
priority_jump: f64,
allowed_worker_ids: Option<HashSet<WorkerId>>,
) -> anyhow::Result<(WorkerWithDpRank, u32)> {
let start = Instant::now();
......@@ -381,13 +386,20 @@ impl KvRouter {
});
let hash_elapsed = start.elapsed();
let overlap_scores = self
let mut overlap_scores = self
.indexer
.find_matches(block_hashes)
.instrument(tracing::info_span!("kv_router.find_matches"))
.await?;
let find_matches_elapsed = start.elapsed();
if let Some(ref allowed_ids) = allowed_worker_ids {
overlap_scores
.scores
.retain(|worker, _| allowed_ids.contains(&worker.worker_id));
}
// Compute seq_hashes only if scheduler needs it for active blocks tracking
let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| {
self.kv_router_config.compute_seq_hashes_for_tracking(
tokens,
......@@ -398,17 +410,18 @@ impl KvRouter {
});
let seq_hash_elapsed = start.elapsed();
let best_worker = self
let response = self
.scheduler
.schedule(
context_id.map(|s| s.to_string()),
isl_tokens,
maybe_seq_hashes,
overlap_scores.clone(),
overlap_scores,
router_config_override,
update_states,
lora_name,
priority_jump,
allowed_worker_ids,
)
.instrument(tracing::info_span!("kv_router.schedule"))
.await?;
......@@ -434,15 +447,7 @@ impl KvRouter {
"find_best_match completed"
);
// Note: Routing decision recording (for approximate mode) is now handled
// by KvPushRouter::generate after select_worker returns.
let overlap_amount = overlap_scores
.scores
.get(&best_worker)
.copied()
.unwrap_or(0);
Ok((best_worker, overlap_amount))
Ok((response.best_worker, response.overlap_blocks))
}
#[allow(clippy::too_many_arguments)]
......@@ -575,6 +580,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
true,
None,
0.0,
None,
)
.await?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment