Unverified Commit 457db719 authored by atchernych's avatar atchernych Committed by GitHub
Browse files

fix: Fix race condition in epp prefill (#7530)


Signed-off-by: default avatarAnna Tchernych <atchernych@nvidia.com>
parent 10fb23d2
......@@ -92,14 +92,16 @@ func DynDecodeScorerFactory(name string, rawParameters json.RawMessage, handle p
return nil, fmt.Errorf("Dynamo FFI init for decode scorer failed: %w", err)
}
return NewDynDecodeScorer(handle.Context()).WithName(name), nil
enforceDisagg := getEnvBoolOrDefault("DYN_ENFORCE_DISAGG", false)
return NewDynDecodeScorer(handle.Context(), enforceDisagg).WithName(name), nil
}
// NewDynDecodeScorer initializes a new DynDecodeScorer.
func NewDynDecodeScorer(ctx context.Context) *DynDecodeScorer {
func NewDynDecodeScorer(ctx context.Context, enforceDisagg bool) *DynDecodeScorer {
return &DynDecodeScorer{
typedName: plugins.TypedName{Type: DynDecodeScorerType, Name: DynDecodeScorerType},
pluginState: plugins.NewPluginState(ctx),
enforceDisagg: enforceDisagg,
}
}
......@@ -116,6 +118,7 @@ func NewDynDecodeScorer(ctx context.Context) *DynDecodeScorer {
type DynDecodeScorer struct {
typedName plugins.TypedName
pluginState *plugins.PluginState
enforceDisagg bool
firstTokenSeen sync.Map
}
......@@ -167,13 +170,15 @@ func (s *DynDecodeScorer) Score(ctx context.Context, cycleState *schedtypes.Cycl
if isDisaggregated {
req.Headers[RoutingModeHeader] = "disaggregated"
// The prefill worker ID header was already set by DynPrefillScorer
// directly on req.Headers during the prefill profile run.
if prefillID, ok := req.Headers[PrefillWorkerIDHeader]; ok {
logger.V(logutil.DEFAULT).Info("DynDecodeScorer: prefill worker header present",
"prefillWorkerID", prefillID)
} else if s.enforceDisagg {
logger.V(logutil.DEFAULT).Error(nil,
"DynDecodeScorer: prefill worker header missing and enforce_disagg=true")
} else {
logger.V(logutil.DEFAULT).Error(nil, "DynDecodeScorer: x-prefill-instance-id header missing — DynPrefillScorer did not set it")
logger.V(logutil.DEFAULT).Error(nil,
"DynDecodeScorer: x-prefill-instance-id header missing — DynPrefillScorer did not set it")
}
} else {
req.Headers[RoutingModeHeader] = "aggregated"
......
......@@ -111,6 +111,11 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
result, err := dynscorer.CallRoutePrefillRequest(requestJSON, podsJSON)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "DynPrefillScorer: FFI prefill routing failed")
// Overwrite PrefillEnabled to false so the decode scorer falls back
// to aggregated routing. Without this, the prefill profile "succeeds"
// (picker picks a pod) but the prefill header is not set, causing
// the sidecar to reject the request in direct routing mode.
cycleState.Write(PrefillEnabledStateKey, &PrefillEnabledState{Enabled: false})
return uniformScores(pods, 0)
}
......@@ -128,9 +133,6 @@ func (s *DynPrefillScorer) Score(ctx context.Context, cycleState *schedtypes.Cyc
}
req.Headers[PrefillWorkerIDHeader] = prefillWorkerID
// Also write to CycleState for any plugin that needs it via the standard API.
cycleState.Write(PrefillWorkerIDStateKey, &PrefillWorkerIDState{WorkerID: prefillWorkerID})
// Score: 1.0 for all pods. The label-filter has already restricted to prefill workers,
// and the FFI router's internal selection is authoritative.
// In the future, we could match worker IDs to pod names for precise scoring.
......
......@@ -21,6 +21,8 @@ import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
log "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
......@@ -29,6 +31,21 @@ import (
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)
func getEnvBoolOrDefault(key string, defaultVal bool) bool {
val, ok := os.LookupEnv(key)
if !ok {
return defaultVal
}
switch strings.ToLower(val) {
case "true", "1", "yes":
return true
case "false", "0", "no":
return false
default:
return defaultVal
}
}
const (
DisaggProfileHandlerType = "disagg-profile-handler"
)
......@@ -47,13 +64,15 @@ func DisaggProfileHandlerFactory(name string, rawParameters json.RawMessage, _ p
return nil, fmt.Errorf("failed to parse %s plugin parameters: %w", DisaggProfileHandlerType, err)
}
}
return NewDisaggProfileHandler().WithName(name), nil
enforceDisagg := getEnvBoolOrDefault("DYN_ENFORCE_DISAGG", false)
return NewDisaggProfileHandler(enforceDisagg).WithName(name), nil
}
// NewDisaggProfileHandler initializes a new DisaggProfileHandler.
func NewDisaggProfileHandler() *DisaggProfileHandler {
func NewDisaggProfileHandler(enforceDisagg bool) *DisaggProfileHandler {
return &DisaggProfileHandler{
typedName: plugins.TypedName{Type: DisaggProfileHandlerType, Name: DisaggProfileHandlerType},
enforceDisagg: enforceDisagg,
}
}
......@@ -95,6 +114,7 @@ func NewDisaggProfileHandler() *DisaggProfileHandler {
// disaggregated routing. If they go down, requests fall back to aggregated mode.
type DisaggProfileHandler struct {
typedName plugins.TypedName
enforceDisagg bool
}
// TypedName returns the type and name tuple of this plugin instance.
......@@ -151,10 +171,14 @@ func (h *DisaggProfileHandler) Pick(ctx context.Context, cycleState *schedtypes.
// Second call: prefill has run, now run decode.
if prefillResult, prefillDone := profileResults[PrefillProfileName]; prefillDone {
if _, decodeDone := profileResults[DecodeProfileName]; !decodeDone {
// If the prefill profile failed (nil result = no prefill pods available),
// update PrefillEnabledState to false so the decode scorer uses normal
// KV cache overlap scoring instead of disaggregated mode (overlap_score_weight=0).
if prefillResult == nil {
if h.enforceDisagg {
// enforce_disagg=true: do not fall back to aggregated mode.
// Stop the scheduling loop — ProcessResults will reject the request.
logger.Info("DisaggProfileHandler: prefill profile failed and enforce_disagg=true, rejecting request")
return map[string]*framework.SchedulerProfile{}
}
// enforce_disagg=false: fall back to aggregated decode.
logger.Info("DisaggProfileHandler: prefill profile failed (no workers?), falling back to aggregated decode")
cycleState.Write(PrefillEnabledStateKey, &PrefillEnabledState{Enabled: false})
}
......@@ -173,9 +197,23 @@ func (h *DisaggProfileHandler) Pick(ctx context.Context, cycleState *schedtypes.
// ProcessResults aggregates the profile run results and designates the primary profile.
// The "decode" profile is always the primary (the pod that handles the request).
func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.CycleState, _ *schedtypes.LLMRequest,
func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.CycleState, req *schedtypes.LLMRequest,
profileResults map[string]*schedtypes.ProfileRunResult) (*schedtypes.SchedulingResult, error) {
// When enforce_disagg=true and the prefill worker ID header was not set
// (prefill router not activated or scorer failed), reject the request
// at the EPP level instead of forwarding it to the sidecar without
// routing headers.
if h.enforceDisagg && (req.Headers == nil || req.Headers[PrefillWorkerIDHeader] == "") {
// Only enforce if a prefill profile was configured and ran.
if _, prefillRan := profileResults[PrefillProfileName]; prefillRan {
return nil, errors.New(
"disaggregated mode enforced (DYN_ENFORCE_DISAGG=true) but prefill workers " +
"are not available; request rejected. Either wait for prefill workers " +
"to register or set DYN_ENFORCE_DISAGG=false to allow aggregated fallback")
}
}
if len(profileResults) == 0 {
return nil, errors.New("disagg profile handler received no profile results")
}
......@@ -191,6 +229,12 @@ func (h *DisaggProfileHandler) ProcessResults(_ context.Context, _ *schedtypes.C
}
if profileResults[primaryProfile] == nil {
if h.enforceDisagg {
return nil, errors.New(
"disaggregated mode enforced (DYN_ENFORCE_DISAGG=true) but prefill workers " +
"are not available; request rejected. Either wait for prefill workers " +
"to register or set DYN_ENFORCE_DISAGG=false to allow aggregated fallback")
}
return nil, fmt.Errorf("primary profile '%s' failed to produce a result", primaryProfile)
}
......
......@@ -40,13 +40,16 @@ const (
PrefillProfileName = "prefill"
DecodeProfileName = "decode"
// PrefillEnabledStateKey is used to communicate prefill-enabled status
// from the DisaggProfileHandler to the scorer plugins via CycleState.
// PrefillEnabledStateKey tracks whether this request should use disaggregated routing.
// Initially set to true by DisaggProfileHandler.Pick() if a "prefill" scheduling
// profile exists in the EPP config. Overwritten to false per-request in two cases:
// - DisaggProfileHandler.Pick(): prefill profile result is nil (no prefill pods
// passed the label-filter).
// - DynPrefillScorer.Score(): prefill FFI routing failed (prefill router not yet
// activated, e.g., worker registered in K8s but not yet in Dynamo discovery).
// The decode scorer reads this to decide whether to use overlap_score_weight=0
// (disaggregated) or normal KV cache overlap scoring (aggregated).
PrefillEnabledStateKey = plugins.StateKey("disagg-prefill-enabled")
// PrefillWorkerIDStateKey communicates the prefill worker ID selected by
// DynPrefillScorer to DynDecodeScorer so it can set the x-prefill-instance-id header.
PrefillWorkerIDStateKey = plugins.StateKey("disagg-prefill-worker-id")
)
// PrefillEnabledState stores whether prefill is enabled for the current scheduling cycle.
......@@ -60,17 +63,6 @@ func (s *PrefillEnabledState) Clone() plugins.StateData {
return &PrefillEnabledState{Enabled: s.Enabled}
}
// PrefillWorkerIDState stores the prefill worker ID selected by DynPrefillScorer.
// Written by DynPrefillScorer, read by DynDecodeScorer to set the header.
type PrefillWorkerIDState struct {
WorkerID string
}
// Clone implements plugins.StateData.
func (s *PrefillWorkerIDState) Clone() plugins.StateData {
return &PrefillWorkerIDState{WorkerID: s.WorkerID}
}
// readPrefillEnabled reads the PrefillEnabledState from CycleState.
// Returns false if the state is not found or not set.
func readPrefillEnabled(cycleState *schedtypes.CycleState) bool {
......
......@@ -251,7 +251,9 @@ To disable the EPP from listening for KV events (e.g., when prefix caching is of
3. **Optionally** set `DYN_OVERLAP_SCORE_WEIGHT=0` on the EPP to skip prefix-overlap scoring altogether, making the router select workers based on load only.
- Set `DYN_BUSY_THRESHOLD` to configure the upper bound on how "full" a worker can be (often derived from kv_active_blocks or other load metrics) before the router skips it. If the selected worker exceeds this value, routing falls back to the next best candidate. By default the value is negative meaning this is not enabled.
- Set `DYN_ENFORCE_DISAGG=true` to strictly enforce disaggregated mode. When enabled, requests fail if prefill workers have not registered yet. Without this, requests arriving before prefill workers are discovered fall through to decode-only routing. Prefill errors always fail requests regardless of this setting.
- Set `DYN_ENFORCE_DISAGG=true` (default: `false`) to control per-request behavior when prefill workers are unavailable:
- **`true` (recommended for disaggregated serving):** Requests fail with an error if prefill workers are not available. Use this when disaggregated serving is required and aggregated fallback is not acceptable.
- **`false` (default):** Requests gracefully fall back to aggregated mode (skip prefill, route directly to decode) when prefill workers are not available. When prefill workers appear later, subsequent requests automatically use disaggregated routing.
- Set `DYN_OVERLAP_SCORE_WEIGHT` to weigh how heavily the score uses token overlap (predicted KV cache hits) versus other factors (load, historical hit rate). Higher weight biases toward reusing workers with similar cached prefixes. (default: 1)
- Set `DYN_ROUTER_TEMPERATURE` to soften or sharpen the selection curve when combining scores. Low temperature makes the router pick the top candidate deterministically; higher temperature lets lower-scoring workers through more often (exploration).
- `DYN_ROUTER_TEMPERATURE` — Temperature for worker sampling via softmax (default: 0.0)
......
......@@ -762,22 +762,16 @@ pub unsafe extern "C" fn create_routers(
}
}
// Create PrefillRouter based on one-time discovery of prefill workers
// Auto-detects disaggregated mode by checking if prefill workers are present
// The prefill workers have to be created before the epp is created.
// Given that we wait first for the decode worker to show up it is reasonable to assume the prefill will be up as well.
let prefill_router = match find_prefill_endpoint(&drt, &namespace_str).await {
Some(prefill_endpoint) => {
tracing::info!("Prefill worker found, running in disaggregated mode");
// Create PrefillRouter with a pending activation channel.
// A background task watches discovery for prefill workers and activates
// the router when one appears. Before activation, requests gracefully
// fallback to decode-only routing.
let mut prefill_config = kv_router_config;
prefill_config.router_track_active_blocks = false;
// Create immediately-resolved channel to activate router
let (tx, rx) = tokio::sync::oneshot::channel();
let _ = tx.send(prefill_endpoint);
PrefillRouter::new(
rx,
let (prefill_tx, prefill_rx) = tokio::sync::oneshot::channel();
let prefill_router = PrefillRouter::new(
prefill_rx,
model_manager.clone(),
RouterMode::KV,
block_size,
......@@ -786,19 +780,13 @@ pub unsafe extern "C" fn create_routers(
model_name.clone(),
actual_namespace.clone(),
enable_eagle,
)
}
None if enforce_disagg => {
tracing::error!(
"Prefill workers required but none found (enforce_disagg is enabled)"
);
return Err(QueryRouterResult::ErrDisaggEnforced);
}
None => {
tracing::info!("No prefill workers found, running in aggregated mode");
PrefillRouter::disabled(model_manager.clone(), RouterMode::KV, enforce_disagg)
}
};
// Spawn background discovery watcher for prefill workers.
// Polls discovery until a prefill-only worker appears in the same
// rolling-update namespace, then sends its endpoint through the channel
// to activate the PrefillRouter.
spawn_prefill_discovery_watcher(drt.clone(), actual_namespace.clone(), prefill_tx);
Ok((
prefill_router,
......@@ -1354,6 +1342,74 @@ async fn init_preprocessor(
Ok(bootstrap)
}
/// Spawn a background task that watches discovery for a prefill-only worker
/// in the given namespace. When found, sends its endpoint through `tx` to
/// activate the PrefillRouter. Polls every 1 second until a match is found.
fn spawn_prefill_discovery_watcher(
drt: DistributedRuntime,
target_namespace: String,
tx: tokio::sync::oneshot::Sender<dynamo_runtime::component::Endpoint>,
) {
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_runtime::discovery::DiscoveryInstance;
tokio::spawn(async move {
let discovery = drt.discovery();
tracing::info!(
namespace = target_namespace,
"Background task: watching for prefill workers to register..."
);
loop {
if let Ok(instances) = discovery.list(DiscoveryQuery::AllModels).await {
for instance in instances {
if let DiscoveryInstance::Model {
namespace,
component,
endpoint,
..
} = &instance
{
if namespace != &target_namespace {
continue;
}
let card = match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => card,
Err(_) => continue,
};
if !card.model_type.supports_prefill()
|| card.model_type.supports_chat()
|| card.model_type.supports_completions()
{
continue;
}
tracing::info!(
model_name = card.name(),
namespace = namespace.as_str(),
"Prefill worker discovered, activating PrefillRouter"
);
if let Ok(ns) = drt.namespace(namespace)
&& let Ok(comp) = ns.component(component)
{
let ep = comp.endpoint(endpoint);
if tx.send(ep).is_err() {
tracing::debug!("PrefillRouter activation channel already closed");
}
return;
}
}
}
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
}
/// Fetch model card via discovery and create preprocessor.
///
/// This function:
......@@ -1431,60 +1487,3 @@ async fn fetch_preprocessor_from_discovery(
actual_namespace,
})
}
/// Find a prefill endpoint from already-discovered instances (one-time filter).
/// Returns the endpoint if a prefill worker is found in the target namespace.
async fn find_prefill_endpoint(
drt: &DistributedRuntime,
target_namespace: &str,
) -> Option<dynamo_runtime::component::Endpoint> {
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_runtime::discovery::DiscoveryInstance;
let discovery = drt.discovery();
let instances = match discovery.list(DiscoveryQuery::AllModels).await {
Ok(instances) => instances,
Err(e) => {
tracing::warn!(error = %e, "Failed to list instances for prefill discovery");
return None;
}
};
for instance in instances {
if let DiscoveryInstance::Model {
namespace,
component,
endpoint,
..
} = &instance
{
if !namespace.starts_with(target_namespace) {
continue;
}
let card = match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => card,
Err(_) => continue,
};
// Only handle prefill models
if !card.model_type.supports_prefill() {
continue;
}
tracing::info!(
model_name = card.name(),
"Prefill worker found in discovered instances"
);
// Build and return the endpoint
if let Ok(ns) = drt.namespace(namespace)
&& let Ok(comp) = ns.component(component)
{
return Some(comp.endpoint(endpoint));
}
}
}
None
}
......@@ -309,6 +309,11 @@ impl PrefillRouter {
pub fn is_activated(&self) -> bool {
self.prefill_router.get().is_some()
}
/// Whether disaggregated mode is strictly enforced (fail if no prefill workers).
pub fn enforce_disagg(&self) -> bool {
self.enforce_disagg
}
}
pub(super) fn link_child_context<T: Send + Sync + 'static>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment