Unverified commit 0980b27f, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: mockers with bootstrap optimization (sglang testing) + CI test (#5121)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent cd8dddee
......@@ -108,12 +108,15 @@ def create_temp_engine_args_file(args) -> Path:
"speedup_ratio": getattr(args, "speedup_ratio", None),
"dp_size": getattr(args, "dp_size", None),
"startup_time": getattr(args, "startup_time", None),
"planner_profile_data": str(getattr(args, "planner_profile_data", None))
if getattr(args, "planner_profile_data", None)
else None,
"planner_profile_data": (
str(getattr(args, "planner_profile_data", None))
if getattr(args, "planner_profile_data", None)
else None
),
"is_prefill": getattr(args, "is_prefill_worker", None),
"is_decode": getattr(args, "is_decode_worker", None),
"enable_local_indexer": getattr(args, "enable_local_indexer", None),
# Note: bootstrap_port is NOT included here - it's set per-worker in launch_workers()
}
# Remove None values to only include explicitly set arguments
......@@ -142,6 +145,13 @@ def validate_worker_type_args(args):
)
def parse_bootstrap_ports(ports_str: str | None) -> list[int]:
"""Parse comma-separated bootstrap ports string into list of integers."""
if not ports_str:
return []
return [int(p.strip()) for p in ports_str.split(",")]
def parse_args():
parser = argparse.ArgumentParser(
description="Mocker engine for testing Dynamo LLM infrastructure with vLLM-style CLI.",
......@@ -291,6 +301,15 @@ def parse_args():
default=False,
help="Enable worker-local KV indexer for tracking this worker's own KV cache state (default: False)",
)
parser.add_argument(
"--bootstrap-ports",
type=str,
default=None,
help="Comma-separated list of bootstrap ports for disaggregated serving rendezvous. "
"One port per worker (must match --num-workers). "
"Prefill workers listen on these ports; decode workers connect to them. "
"If not specified, bootstrap rendezvous is disabled.",
)
parser.add_argument(
"--store-kv",
type=str,
......@@ -313,6 +332,15 @@ def parse_args():
if args.num_workers < 1:
raise ValueError(f"--num-workers must be at least 1, got {args.num_workers}")
# Parse and validate bootstrap_ports
args.bootstrap_ports_list = parse_bootstrap_ports(args.bootstrap_ports)
if args.bootstrap_ports_list:
if len(args.bootstrap_ports_list) != args.num_workers:
raise ValueError(
f"--bootstrap-ports must have exactly --num-workers ({args.num_workers}) ports, "
f"got {len(args.bootstrap_ports_list)}: {args.bootstrap_ports_list}"
)
# Set endpoint default based on worker type if not explicitly provided
if args.endpoint is None:
if args.is_prefill_worker:
......
......@@ -5,9 +5,12 @@
# Now supports vLLM-style individual arguments for MockEngineArgs
import asyncio
import json
import logging
import os
import signal
import tempfile
from pathlib import Path
import uvloop
......@@ -85,6 +88,13 @@ async def launch_workers(args, extra_engine_args_path):
loop = asyncio.get_running_loop()
futures = []
runtimes = []
per_worker_temp_files: list[Path] = []
# Load base engine args if we need to create per-worker files with bootstrap_port
base_engine_args = None
if args.bootstrap_ports_list:
with open(extra_engine_args_path) as f:
base_engine_args = json.load(f)
for worker_id in range(args.num_workers):
logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}")
......@@ -93,13 +103,30 @@ async def launch_workers(args, extra_engine_args_path):
runtime = DistributedRuntime(loop, args.store_kv, args.request_plane)
runtimes.append(runtime)
# Determine which engine args file to use
if args.bootstrap_ports_list:
# Create per-worker temp file with this worker's bootstrap_port
worker_args = base_engine_args.copy()
worker_args["bootstrap_port"] = args.bootstrap_ports_list[worker_id]
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
json.dump(worker_args, f)
worker_engine_args_path = Path(f.name)
per_worker_temp_files.append(worker_engine_args_path)
logger.debug(
f"Worker {worker_id}: using bootstrap_port {args.bootstrap_ports_list[worker_id]}"
)
else:
worker_engine_args_path = extra_engine_args_path
# Create EntrypointArgs for this worker
entrypoint_args = EntrypointArgs(
engine_type=EngineType.Mocker,
model_path=args.model_path,
model_name=args.model_name,
endpoint_id=args.endpoint,
extra_engine_args=extra_engine_args_path,
extra_engine_args=worker_engine_args_path,
is_prefill=args.is_prefill_worker,
)
......@@ -130,6 +157,13 @@ async def launch_workers(args, extra_engine_args_path):
for runtime in runtimes:
runtime.shutdown()
# Clean up per-worker temp files
for temp_file in per_worker_temp_files:
try:
temp_file.unlink()
except Exception:
pass
def main():
uvloop.run(worker())
......
......@@ -96,21 +96,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
bootstrap_room = self._generate_bootstrap_room()
logging.debug(f"Generated bootstrap_room locally: {bootstrap_room}")
bootstrap_info = {
"bootstrap_host": self.bootstrap_host,
"bootstrap_port": self.bootstrap_port,
"bootstrap_room": bootstrap_room,
}
# Yield in LLMEngineOutput format for PrefillRouter compatibility
# The disaggregated_params field contains the bootstrap info
yield {
"token_ids": [],
"text": None,
"finish_reason": None,
"disaggregated_params": bootstrap_info,
}
input_param = self._get_input_param(inner_request)
# Propagate trace context to SGLang
......
......@@ -1080,6 +1080,9 @@ pub fn add_query_instance_id(
///
/// For disaggregated mode: sets `prefill_worker_id` and `decode_worker_id`
/// For aggregated mode: sets `backend_instance_id` (when both IDs are the same)
///
/// Also sets `enable_local_updates: false` since the external caller (EPP/GAIE)
/// will handle bookkeeping via C FFI functions.
pub fn set_worker_ids_for_stage2(
request: &mut NvCreateChatCompletionRequest,
decode_worker_id: Option<i64>,
......@@ -1091,6 +1094,9 @@ pub fn set_worker_ids_for_stage2(
.expect("NvExt builder should not fail")
});
// Disable local updates - external caller handles bookkeeping via C FFI
nvext.enable_local_updates = Some(false);
// Check if this is aggregated mode (same worker for both)
let is_aggregated = prefill_worker_id == decode_worker_id;
......
......@@ -354,71 +354,76 @@ impl KvRouter {
tracing::info!("Worker query client initialized");
// Start KV event subscriber background process (only when use_kv_events is enabled)
// This is spawned as a background task to avoid blocking router startup.
// The task waits for runtime_configs to determine whether to use NATS Core or JetStream.
// We block here until at least one worker runtime config is registered,
// then spawn the subscriber. This ensures the router is ready before accepting requests.
if kv_router_config.use_kv_events
&& let Indexer::KvIndexer(ref kv_indexer) = indexer
{
// Clone everything needed for the background task
let component_clone = component.clone();
let kv_indexer_clone = kv_indexer.clone();
let cancellation_token_clone = cancellation_token.clone();
let mut runtime_configs_rx_clone = runtime_configs_rx.clone();
let worker_query_client_clone =
worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
tokio::spawn(async move {
// Wait for runtime_configs to have at least one entry
let (all_local_indexer, count) = loop {
{
let configs = runtime_configs_rx_clone.borrow();
if !configs.is_empty() {
let all_local_indexer =
configs.values().all(|c| c.enable_local_indexer);
break (all_local_indexer, configs.len());
}
// Wait for at least one worker runtime config to be registered
tracing::info!("Waiting for at least one worker runtime config to be registered...");
let (all_local_indexer, count) = loop {
{
let configs = runtime_configs_rx_clone.borrow();
if !configs.is_empty() {
let all_local_indexer = configs.values().all(|c| c.enable_local_indexer);
break (all_local_indexer, configs.len());
}
}
// Wait for changes to runtime_configs
tokio::select! {
_ = cancellation_token_clone.cancelled() => {
tracing::debug!("Subscriber selection task cancelled");
return;
}
result = runtime_configs_rx_clone.changed() => {
if result.is_err() {
tracing::debug!("Runtime configs channel closed");
return;
}
// Wait for changes to runtime_configs
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::debug!("KvRouter startup cancelled while waiting for workers");
anyhow::bail!("KvRouter startup cancelled");
}
result = runtime_configs_rx_clone.changed() => {
if result.is_err() {
tracing::debug!("Runtime configs channel closed");
anyhow::bail!("Runtime configs channel closed before any workers registered");
}
}
};
}
};
tracing::info!("Found {count} worker runtime config(s), starting KV event subscriber");
if all_local_indexer {
// All workers have local_indexer enabled - use NATS Core
tracing::info!(
"All {count} workers have local_indexer enabled, using NATS Core subscription"
);
// Clone everything needed for the background subscriber task
let component_clone = component.clone();
let kv_indexer_clone = kv_indexer.clone();
let cancellation_token_clone = cancellation_token.clone();
let worker_query_client_clone =
worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
// Spawn subscriber as background task (long-running)
if all_local_indexer {
// All workers have local_indexer enabled - use NATS Core
tracing::info!(
"All {count} workers have local_indexer enabled, using NATS Core subscription"
);
tokio::spawn(async move {
if let Err(e) = start_kv_router_background_nats_core(
component_clone.clone(),
component_clone,
kv_indexer_clone.event_sender(),
kv_indexer_clone.remove_worker_sender(),
cancellation_token_clone.clone(),
cancellation_token_clone,
worker_query_client_clone,
)
.await
{
tracing::error!("Failed to start NATS Core subscriber: {e}");
}
} else {
// Not all workers have local_indexer - use JetStream
tracing::info!(
"Not all workers have local_indexer enabled, using JetStream subscription"
);
});
} else {
// Not all workers have local_indexer - use JetStream
tracing::info!(
"Not all workers have local_indexer enabled, using JetStream subscription"
);
tokio::spawn(async move {
if let Err(e) = start_kv_router_background(
component_clone.clone(),
component_clone,
consumer_id,
kv_indexer_clone.event_sender(),
kv_indexer_clone.remove_worker_sender(),
......@@ -428,7 +433,7 @@ impl KvRouter {
kv_router_config
.router_snapshot_threshold
.map(|_| kv_indexer_clone.snapshot_event_sender()),
cancellation_token_clone.clone(),
cancellation_token_clone,
kv_router_config.router_snapshot_threshold,
kv_router_config.router_reset_states,
)
......@@ -436,8 +441,8 @@ impl KvRouter {
{
tracing::error!("Failed to start JetStream subscriber: {e}");
}
}
});
});
}
}
tracing::info!("KV Routing initialized");
......@@ -815,17 +820,12 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let is_query_only = request.get_annotation_value("query_instance_id").is_some();
// Determine if this router should handle local state updates (add_request, free, etc.)
// Only skip local updates for GAIE Stage 2: when BOTH prefill and decode worker IDs
// are externally specified (indicates external orchestrator handles tracking).
// For internal routing (e.g., bootstrap optimization with only prefill_worker_id set),
// we still handle updates locally.
let routing = request.routing.as_ref();
let handle_local_updates = routing
.map(|r| {
// GAIE Stage 2 sets both worker IDs - external caller handles tracking
// All other cases (including backend_instance_id for routing) - we handle locally
r.prefill_worker_id.is_none() || r.decode_worker_id.is_none()
})
// Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
// an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
let handle_local_updates = request
.routing
.as_ref()
.and_then(|r| r.enable_local_updates)
.unwrap_or(true);
// Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
......@@ -917,9 +917,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = response_stream.context();
let context_for_monitoring = stream_context.clone();
// TODO: When handle_local_updates=false, consider moving mark_prefill_completed
// to an external caller (e.g., sidecar) if they support a first-token hook.
// Currently mark_prefill_completed is called here for all flows.
// Wrap stream with lifecycle management (mark_prefill_completed, free)
// Only perform these operations if handle_local_updates is true.
// When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false;
......@@ -937,7 +937,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
break;
};
if !prefill_marked {
if handle_local_updates && !prefill_marked {
// Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill
let has_tokens = item.data.as_ref()
......@@ -956,8 +956,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
}
// Always call free() - it's idempotent and safe even if already freed or never added
if let Err(e) = chooser.free(&context_id).await {
// Only call free() if we handle local updates.
// When handle_local_updates=false, external caller handles cleanup via C FFI.
if handle_local_updates
&& let Err(e) = chooser.free(&context_id).await
{
tracing::warn!("Failed to free request {context_id}: {e}");
}
});
......
......@@ -6,7 +6,7 @@ use std::sync::{Arc, OnceLock};
use anyhow::Result;
use futures::StreamExt;
use rand::Rng;
use tokio::sync::oneshot;
use tokio::sync::{OwnedSemaphorePermit, oneshot};
use tokio_util::sync::CancellationToken;
use dynamo_runtime::{
......@@ -24,7 +24,6 @@ use crate::{
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::common::timing::{RequestPhase, RequestTracker},
protocols::openai::nvext::WorkerIdInfo,
};
/// Errors that can occur during prefill routing
......@@ -85,10 +84,10 @@ impl InnerPrefillRouter {
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
///
/// Supports regular Dynamo and GAIE integrated mode via query_instance_id state machine:
/// - GAIE Stage 1: query_instance_id transitions "" -> "prefill" -> "decode", returns only worker IDs
/// - GAIE Stage 2: routing.prefill_worker_id/routing.decode_worker_id are set, full execution with specified workers
/// - Non-GAIE: like GAIE Stage 2 but the worker ids have to be determined.
/// Modes:
/// - Query-only: `query_instance_id` annotation present → returns worker IDs without execution
/// - Pre-routed: `prefill_worker_id`/`decode_worker_id` set → routes to specified workers
/// - Normal: Worker IDs determined by router based on KV cache state
pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>,
model_manager: Arc<ModelManager>,
......@@ -232,11 +231,6 @@ impl PrefillRouter {
Ok(())
}
/// Generate a unique bootstrap room ID for disaggregated serving
fn generate_bootstrap_room() -> u64 {
rand::rng().random()
}
/// Build bootstrap_info for disaggregated serving
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
......@@ -250,7 +244,6 @@ impl PrefillRouter {
// Worker selection
let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
// GAIE Stage 2: use pre-selected worker
let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0);
tracing::debug!(
worker_id = id,
......@@ -285,7 +278,7 @@ impl PrefillRouter {
let host = endpoint.bootstrap_host?;
let port = endpoint.bootstrap_port?;
let bootstrap_room = Self::generate_bootstrap_room();
let bootstrap_room: u64 = rand::rng().random();
tracing::info!(
worker_id = worker_id,
......@@ -308,12 +301,18 @@ impl PrefillRouter {
))
}
/// Execute prefill with the given router and extract structured result
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization)
/// Execute prefill with the given router and extract structured result.
///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
///
/// If `phase_permit` is provided, it is dropped as soon as the routing call returns
/// (before the first output is awaited), allowing subsequent `set_phase` calls to proceed.
/// This is used in the bootstrap optimization path to ensure `record_worker` completes
/// before the phase changes.
async fn execute_prefill(
router: Option<InnerPrefillRouter>,
request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
phase_permit: Option<OwnedSemaphorePermit>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
let router = router.ok_or(PrefillError::NotActivated)?;
let mut prefill_response = router
......@@ -321,6 +320,10 @@ impl PrefillRouter {
.await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?;
// Drop phase permit now - routing is complete, record_worker was called in select_worker.
// This unblocks set_phase(Decode) in the main task without waiting for prefill output.
drop(phase_permit);
let Some(first_output) = prefill_response.next().await else {
return Err(PrefillError::PrefillError(
"Prefill router returned no output (stream ended)".to_string(),
......@@ -379,17 +382,24 @@ impl PrefillRouter {
))
}
/// Spawn prefill as a background task
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization)
/// Spawn prefill as a background task.
///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
///
/// The `phase_permit` is passed to the spawned task and dropped once prefill routing
/// completes (before the first output is awaited), allowing the main task's
/// `set_phase(Decode)` to proceed.
fn spawn_prefill_task(
&self,
prefill_request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
phase_permit: OwnedSemaphorePermit,
) {
let router = self.prefill_router.get().cloned();
tokio::spawn(async move {
match Self::execute_prefill(router, prefill_request, target_worker).await {
match Self::execute_prefill(router, prefill_request, target_worker, Some(phase_permit))
.await
{
Ok(_) => {
tracing::debug!("Prefill background task completed");
}
......@@ -400,67 +410,17 @@ impl PrefillRouter {
});
}
/// Call the prefill router and extract structured prefill result and worker ID
/// Call the prefill router and extract structured prefill result and worker ID.
///
/// This is the synchronous prefill path - we wait for prefill to complete before proceeding.
/// No phase permit is needed since `record_worker` completes before we return.
async fn call_prefill(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
// For call_prefill path, routing is handled by the router itself (no direct routing needed)
Self::execute_prefill(self.prefill_router.get().cloned(), request, None).await
}
}
/// GAIE helper functions for preparing prefill requests
impl PrefillRouter {
/// Prepare prefill request for GAIE flows
/// - Stage 1: Sets query_instance_id:prefill annotation
/// - Stage 2: Sets backend_instance_id to target prefill worker
fn prepare_prefill_for_gaie(prefill_req: &mut PreprocessedRequest, is_gaie_stage1: bool) {
if is_gaie_stage1 {
// GAIE Stage 1: Set query_instance_id to "prefill" for prefill worker selection
prefill_req
.annotations
.retain(|a| !a.starts_with("query_instance_id"));
prefill_req
.annotations
.push(format!("query_instance_id:{}", RequestPhase::Prefill));
} else if let Some(prefill_worker_id) = prefill_req
.routing
.as_ref()
.and_then(|r| r.prefill_worker_id)
{
// GAIE Stage 2: Route to pre-selected prefill worker from the stage 1
tracing::debug!(
prefill_worker_id = prefill_worker_id,
"GAIE Stage 2: Routing prefill to pre-selected worker"
);
prefill_req.routing_mut().backend_instance_id = Some(prefill_worker_id);
}
}
/// Prepare decode request for GAIE Stage 1
/// Extracts prefill_worker_id from prefill result and sets decode annotations
fn prepare_decode_for_gaie_stage1(
decode_req: &mut PreprocessedRequest,
prefill_result: &PrefillResult,
) {
let prefill_worker_id = prefill_result
.disaggregated_params
.get("worker_id")
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
.and_then(|info| info.prefill_worker_id);
if let Some(worker_id) = prefill_worker_id {
decode_req
.annotations
.retain(|a| !a.starts_with("query_instance_id"));
decode_req
.annotations
.push(format!("query_instance_id:{}", RequestPhase::Decode));
decode_req
.annotations
.push(format!("prefill_worker_id:{worker_id}"));
}
// No phase permit needed - we wait for completion before changing phase
Self::execute_prefill(self.prefill_router.get().cloned(), request, None, None).await
}
}
......@@ -490,22 +450,14 @@ impl
let request_id = context.id().to_string();
let engine_ctx = context.context();
// GAIE Stage 1: the presence of the empty query_instance_id signals query-only mode
// State machine: "" -> "prefill" -> "decode" (disagg) OR "" -> aggregated worker (agg fallback)
let is_gaie_stage1 = req
.get_annotation_value("query_instance_id")
.is_some_and(|s| s.is_empty());
// Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens;
// GAIE Stage 1: Check if prefill router is activated - if not, skip to decode
if is_gaie_stage1 && self.prefill_router.get().is_none() {
tracing::debug!("GAIE Stage 1: Prefill router not activated, skipping to decode");
// If prefill router is not activated, skip directly to decode
if self.prefill_router.get().is_none() {
if self.enforce_disagg {
return Err(anyhow::anyhow!(PrefillError::NotActivated));
}
// Fall back to decode-only
return next.generate(context.map(|_| req)).await;
}
......@@ -515,47 +467,45 @@ impl
req.tracker = Some(Arc::new(RequestTracker::new()));
}
let tracker = req.tracker.as_ref().unwrap();
tracker.set_phase(RequestPhase::Prefill);
let prefill_phase_permit = tracker.set_phase(RequestPhase::Prefill).await;
tracker.record_prefill_start();
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
// Prepare prefill request for GAIE flows (Stage 1 or Stage 2)
Self::prepare_prefill_for_gaie(&mut prefill_req, is_gaie_stage1);
// Try build_bootstrap_info optimization (skip for GAIE Stage 1 which needs query-only flow)
// For GAIE Stage 2, use prefill_worker_id if provided
// Try build_bootstrap_info optimization: if we can get bootstrap info upfront,
// spawn prefill in background and proceed to decode immediately.
let preselected_worker = prefill_req
.routing
.as_ref()
.and_then(|r| r.prefill_worker_id);
let prefill_result = if !is_gaie_stage1
&& let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
{
// Bootstrap optimization path: spawn prefill in background
let routing = prefill_req.routing_mut();
routing.prefill_worker_id = Some(worker_id);
routing.backend_instance_id = Some(worker_id); // Route prefill to the SAME worker we got bootstrap_info from
routing.dp_rank = Some(dp_rank);
prefill_req.bootstrap_info = Some(bootstrap_info.clone());
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.spawn_prefill_task(prefill_context, Some(worker_id));
// Pass phase permit to spawned task - it is dropped once prefill routing completes (record_worker done)
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);
Ok((None, Some(worker_id), Some(bootstrap_info)))
} else {
// Original prefill path: wait for prefill to complete
tracing::debug!(
is_gaie_stage1 = is_gaie_stage1,
"Using original prefill path"
);
tracing::debug!("Using original prefill path");
// Drop the phase permit before calling call_prefill - we wait for completion
// so there's no race with set_phase(Decode) below
drop(prefill_phase_permit);
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
......@@ -579,20 +529,18 @@ impl
Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => {
tracing::debug!("Prefill completed, proceeding to decode");
// Set phase to Decode for the decode request
// Set phase to Decode for the decode request.
// In bootstrap path, this blocks until the spawned prefill task drops its permit
// (once prefill routing / record_worker completes), ensuring correct phase for routing.
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Decode);
let _decode_permit = tracker.set_phase(RequestPhase::Decode).await;
// Permit is dropped immediately - decode proceeds, no need to hold it
}
let mut decode_req = req;
// Update request with prefill result
if is_gaie_stage1 {
if let Some(ref prefill_result) = maybe_prefill_result {
Self::prepare_decode_for_gaie_stage1(&mut decode_req, prefill_result);
}
} else if let Some(prefill_result) = maybe_prefill_result {
// Normal or GAIE Stage 2: Set prefill_result for decode
if let Some(prefill_result) = maybe_prefill_result {
decode_req.prefill_result = Some(prefill_result);
}
......@@ -611,17 +559,6 @@ impl
..existing_override.unwrap_or_default()
});
// GAIE Stage 2: Route to pre-selected decode worker if specified
if let Some(decode_worker_id) =
decode_req.routing.as_ref().and_then(|r| r.decode_worker_id)
{
decode_req.routing_mut().backend_instance_id = Some(decode_worker_id);
tracing::debug!(
decode_worker_id = decode_worker_id,
"GAIE Stage 2: Routing decode to pre-selected worker"
);
}
// Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await
......
......@@ -651,6 +651,51 @@ pub async fn start_kv_router_background(
Ok(())
}
/// Apply a single worker discovery event (added or removed) to the router.
///
/// - `Added`: replays the worker's local indexer into the router by calling
///   `recover_from_worker` from the beginning with no upper bound. The outcome
///   is only logged; a failure is not propagated to the caller.
/// - `Removed`: forwards the worker id on `remove_worker_tx`; a send failure
///   is logged as a warning.
async fn handle_worker_discovery(
    event: DiscoveryEvent,
    worker_query_client: &WorkerQueryClient,
    kv_events_tx: &mpsc::Sender<RouterEvent>,
    remove_worker_tx: &mpsc::Sender<WorkerId>,
) {
    // Deal with the removal case up front so the recovery path reads linearly.
    let instance = match event {
        DiscoveryEvent::Removed(worker_id) => {
            tracing::warn!("DISCOVERY: Worker {worker_id} removed, removing from router indexer");
            if let Err(e) = remove_worker_tx.send(worker_id).await {
                tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
            }
            return;
        }
        DiscoveryEvent::Added(instance) => instance,
    };

    let worker_id = instance.instance_id();
    tracing::info!("DISCOVERY: Worker {worker_id} added, dumping local indexer into router");

    let recovery = recover_from_worker(
        worker_query_client,
        worker_id,
        None, // Start from beginning
        None, // Get all events
        kv_events_tx,
    )
    .await;

    match recovery {
        Ok(count) => tracing::info!(
            "Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
        ),
        Err(e) => tracing::warn!(
            "Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
        ),
    }
}
/// Start a simplified background task for event consumption using NATS Core.
///
/// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`,
......@@ -660,6 +705,9 @@ pub async fn start_kv_router_background(
/// - On worker Added: dumps worker's local indexer into router
/// - On worker Removed: removes worker from router indexer
///
/// This function first recovers state from all currently registered workers before
/// spawning the background task, ensuring the router is ready before returning.
///
/// This is appropriate when workers have local indexers enabled.
pub async fn start_kv_router_background_nats_core(
component: Component,
......@@ -688,6 +736,40 @@ pub async fn start_kv_router_background_nats_core(
.list_and_watch(generate_discovery_key, Some(cancellation_token.clone()))
.await?;
// Drain and process all existing workers before spawning the background loop.
// list_and_watch returns existing instances first, so we poll with a short timeout
// to process all initial workers synchronously before the router becomes "ready".
loop {
// Use a short timeout to detect when initial discovery events are exhausted
let poll_result =
tokio::time::timeout(Duration::from_millis(100), instance_event_stream.next()).await;
match poll_result {
Ok(Some(Ok(event))) => {
handle_worker_discovery(
event,
&worker_query_client,
&kv_events_tx,
&remove_worker_tx,
)
.await;
}
Ok(Some(Err(e))) => {
tracing::warn!("Error receiving discovery event during initial sync: {e}");
}
Ok(None) => {
// Stream ended
tracing::warn!("Discovery stream ended during initial sync");
break;
}
Err(_) => {
// Timeout - no more initial events
tracing::debug!("Initial worker discovery sync complete");
break;
}
}
}
tokio::spawn(async move {
// Track last received event ID per worker for gap detection
let mut last_event_ids: HashMap<WorkerId, u64> = HashMap::new();
......@@ -703,51 +785,17 @@ pub async fn start_kv_router_background_nats_core(
// Handle generate endpoint instance add/remove events
Some(discovery_event_result) = instance_event_stream.next() => {
let Ok(discovery_event) = discovery_event_result else {
let Ok(event) = discovery_event_result else {
continue;
};
match discovery_event {
DiscoveryEvent::Added(_instance) => {
// Extract worker_id from the instance
let worker_id = _instance.instance_id();
tracing::info!(
"DISCOVERY: Worker {worker_id} added, dumping local indexer into router"
);
// Query worker's local indexer and dump all events
match recover_from_worker(
&worker_query_client,
worker_id,
None, // Start from beginning
None, // Get all events
&kv_events_tx,
)
.await
{
Ok(count) => {
tracing::info!(
"Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
);
}
Err(e) => {
tracing::warn!(
"Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
);
}
}
}
DiscoveryEvent::Removed(worker_id) => {
tracing::warn!(
"DISCOVERY: Worker {worker_id} removed, removing from router indexer"
);
if let Err(e) = remove_worker_tx.send(worker_id).await {
tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
}
}
}
handle_worker_discovery(
event,
&worker_query_client,
&kv_events_tx,
&remove_worker_tx,
)
.await;
}
// Handle event consumption from NATS Core subscription
......
......@@ -10,9 +10,10 @@ use dynamo_runtime::discovery::DiscoverySpec;
use dynamo_runtime::protocols::EndpointId;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::utils::get_http_rpc_host_from_env;
use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs;
use crate::mocker::protocols::{MockEngineArgs, WorkerType};
use crate::model_card::ModelDeploymentCard;
use crate::model_type::{ModelInput, ModelType};
use crate::preprocessor::media::{ImageDecoder, MediaDecoder, MediaFetcher};
......@@ -249,6 +250,22 @@ impl LocalModelBuilder {
video: None,
});
self.media_fetcher = Some(MediaFetcher::default());
// Set bootstrap endpoint for prefill workers with bootstrap_port configured
if mocker_engine_args.worker_type == WorkerType::Prefill
&& let Some(port) = mocker_engine_args.bootstrap_port
{
let host = get_http_rpc_host_from_env();
self.runtime_config.disaggregated_endpoint =
Some(runtime_config::DisaggregatedEndpoint {
bootstrap_host: Some(host),
bootstrap_port: Some(port),
});
tracing::info!(
bootstrap_port = port,
"Mocker prefill worker: publishing bootstrap endpoint to discovery"
);
}
}
// frontend and echo engine don't need a path.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod bootstrap;
pub mod engine;
pub mod evictor;
pub mod kv_manager;
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Bootstrap rendezvous for disaggregated mocker testing.
//!
//! Simulates the SGLang disaggregated serving handshake for KV transfer coordination.
//! Either prefill or decode can arrive first; the rendezvous completes when both are ready.
//!
//! - Prefill: calls `complete_room(room_id)` after first token (KV cache ready)
//! - Decode: connects to prefill's bootstrap server, blocks until prefill completes
//!
//! Wire protocol:
//! - Decode -> Prefill: room_id (8 bytes, little-endian u64)
//! - Prefill -> Decode: ACK (1 byte, 0x01) after prefill completes
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Result, bail};
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
/// Timeout for bootstrap rendezvous operations.
const RENDEZVOUS_TIMEOUT: Duration = Duration::from_secs(30);
/// ACK byte sent from server to decode after prefill completes.
const ACK_BYTE: u8 = 0x01;
/// State for a room in the rendezvous.
///
/// One entry exists per `room_id` in the server's room map. The entry is created by
/// whichever side shows up first: `complete_room` inserts it with `prefill_completed = true`,
/// while `handle_connection` inserts it with a pending `decode_waiting` sender.
struct RoomState {
/// True if prefill has completed (KV cache ready)
prefill_completed: bool,
/// Channel to notify decode when prefill completes (if decode is waiting).
/// `None` when no decode connection is currently parked on this room.
decode_waiting: Option<oneshot::Sender<()>>,
}
/// Bootstrap server for prefill mockers.
/// Handles rendezvous between prefill and decode for KV transfer coordination.
///
/// Created through `BootstrapServer::start`, which binds the listener and spawns the
/// accept loop; the returned handle is used by the prefill side to call `complete_room`.
pub struct BootstrapServer {
// Port the listener is actually bound to (read back from the socket after binding).
port: u16,
// Per-room rendezvous state, shared with the connection-handling tasks.
rooms: Arc<DashMap<u64, RoomState>>,
}
impl BootstrapServer {
/// Start the bootstrap server on the specified port.
pub async fn start(port: u16, cancel_token: CancellationToken) -> Result<Arc<Self>> {
let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
let actual_port = listener.local_addr()?.port();
tracing::info!("Bootstrap server started on port {actual_port}");
let rooms: Arc<DashMap<u64, RoomState>> = Arc::new(DashMap::new());
let server = Arc::new(Self {
port: actual_port,
rooms: rooms.clone(),
});
// Spawn accept loop
tokio::spawn(async move {
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, addr)) => {
tracing::debug!("Bootstrap: accepted connection from {addr}");
let rooms_clone = rooms.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(stream, rooms_clone).await {
tracing::warn!("Bootstrap: connection error: {e}");
}
});
}
Err(e) => {
tracing::warn!("Bootstrap: accept failed: {e}");
}
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Bootstrap server shutting down");
break;
}
}
}
});
Ok(server)
}
/// Handle a connection from decode. Blocks until prefill completes for this room.
async fn handle_connection(
mut stream: TcpStream,
rooms: Arc<DashMap<u64, RoomState>>,
) -> Result<()> {
// Read room_id (8 bytes, little-endian)
let mut buf = [0u8; 8];
stream.read_exact(&mut buf).await?;
let room_id = u64::from_le_bytes(buf);
tracing::debug!("Bootstrap: decode connected for room {room_id}");
// Check room state and wait if needed
let rx = match rooms.entry(room_id) {
Entry::Occupied(mut entry) => {
if entry.get().prefill_completed {
// Prefill already done, immediate ACK
entry.remove();
tracing::debug!("Bootstrap: room {room_id} already completed, immediate ACK");
None
} else {
// Prefill registered but not completed, wait
let (tx, rx) = oneshot::channel();
entry.get_mut().decode_waiting = Some(tx);
tracing::debug!("Bootstrap: room {room_id} waiting for prefill to complete");
Some(rx)
}
}
Entry::Vacant(entry) => {
// Decode arrived first, create entry and wait
let (tx, rx) = oneshot::channel();
entry.insert(RoomState {
prefill_completed: false,
decode_waiting: Some(tx),
});
tracing::debug!("Bootstrap: room {room_id} decode arrived first, waiting");
Some(rx)
}
};
// Wait for prefill to complete if needed
if let Some(rx) = rx {
match tokio::time::timeout(RENDEZVOUS_TIMEOUT, rx).await {
Ok(Ok(())) => {
tracing::debug!("Bootstrap: room {room_id} prefill completed, sending ACK");
}
Ok(Err(_)) => {
bail!("Bootstrap: room {room_id} sender dropped");
}
Err(_) => {
rooms.remove(&room_id);
bail!("Bootstrap: room {room_id} timeout waiting for prefill");
}
}
}
// Send ACK
stream.write_all(&[ACK_BYTE]).await?;
Ok(())
}
/// Mark a room as completed (prefill finished, KV cache ready).
/// If decode is already waiting, unblocks it.
pub fn complete_room(&self, room_id: u64) {
match self.rooms.entry(room_id) {
Entry::Occupied(mut entry) => {
if let Some(sender) = entry.get_mut().decode_waiting.take() {
// Decode is waiting, unblock it
let _ = sender.send(());
entry.remove();
tracing::debug!("Bootstrap: room {room_id} completed, decode unblocked");
} else {
// Decode not connected yet, mark completed
entry.get_mut().prefill_completed = true;
tracing::debug!("Bootstrap: room {room_id} completed, awaiting decode");
}
}
Entry::Vacant(entry) => {
// Decode hasn't connected yet
entry.insert(RoomState {
prefill_completed: true,
decode_waiting: None,
});
tracing::debug!("Bootstrap: room {room_id} completed (no decode yet)");
}
}
}
/// Get the port the server is listening on.
pub fn port(&self) -> u16 {
self.port
}
}
/// Decode-side half of the rendezvous: dial the prefill worker's bootstrap server,
/// announce `room_id`, and wait for the single-byte ACK that signals the KV cache is ready.
///
/// Square brackets around `host` (as used for bracketed IPv6 literals) are stripped before
/// the address string is built. Connecting and waiting for the ACK are each bounded by
/// `RENDEZVOUS_TIMEOUT`.
///
/// # Errors
/// Returns an error on connect timeout/failure, on a failed write of the room id, on ACK
/// timeout/read failure, or when the byte received is not `ACK_BYTE`.
pub async fn connect_to_prefill(host: &str, port: u16, room_id: u64) -> Result<()> {
    let host = host.trim_matches(|c| c == '[' || c == ']');
    let addr = format!("{host}:{port}");
    tracing::debug!("Bootstrap: decode connecting to {addr} for room {room_id}");

    // Dial the server; the outer result is the timeout, the inner one the connect itself.
    let dial = tokio::time::timeout(RENDEZVOUS_TIMEOUT, TcpStream::connect(&addr)).await;
    let mut conn = match dial {
        Ok(Ok(conn)) => conn,
        Ok(Err(e)) => bail!("Bootstrap: connect failed to {addr}: {e}"),
        Err(_) => bail!("Bootstrap: connect timeout to {addr}"),
    };

    // Announce which room we are waiting on (8 bytes, little-endian).
    let room_bytes = room_id.to_le_bytes();
    conn.write_all(&room_bytes).await?;

    // Park until the server answers; it only does so once prefill has completed.
    let mut reply = [0u8; 1];
    match tokio::time::timeout(RENDEZVOUS_TIMEOUT, conn.read_exact(&mut reply)).await {
        Ok(Ok(_)) => {}
        Ok(Err(e)) => bail!("Bootstrap: read ACK failed: {e}"),
        Err(_) => bail!("Bootstrap: ACK timeout for room {room_id}"),
    }

    let [received] = reply;
    if received != ACK_BYTE {
        bail!(
            "Bootstrap: invalid ACK byte {:02x} for room {room_id}",
            received
        );
    }

    tracing::debug!("Bootstrap: decode received ACK for room {room_id}");
    Ok(())
}
// Unit tests for the bootstrap rendezvous. Every test binds port 0 and reads the bound
// port back via `server.port()`. Ordering between the prefill and decode sides is arranged
// with short sleeps, so the "decode first" cases are best-effort rather than strictly ordered.
#[cfg(test)]
mod tests {
use super::*;
// Prefill marks the room complete before decode connects; decode must be ACKed at once.
#[tokio::test]
async fn test_prefill_completes_first() {
let cancel_token = CancellationToken::new();
let server = BootstrapServer::start(0, cancel_token.clone())
.await
.unwrap();
let port = server.port();
let room_id = 1001u64;
// Prefill completes first
server.complete_room(room_id);
// Decode connects - should get immediate ACK
let result = connect_to_prefill("127.0.0.1", port, room_id).await;
assert!(result.is_ok(), "Decode should succeed: {result:?}");
cancel_token.cancel();
}
// Decode connects and waits; the later `complete_room` call must release it.
#[tokio::test]
async fn test_decode_connects_first() {
let cancel_token = CancellationToken::new();
let server = BootstrapServer::start(0, cancel_token.clone())
.await
.unwrap();
let port = server.port();
let room_id = 1002u64;
// Spawn decode (will block waiting for prefill)
let decode_handle =
tokio::spawn(async move { connect_to_prefill("127.0.0.1", port, room_id).await });
// Give decode time to connect and register
tokio::time::sleep(Duration::from_millis(50)).await;
// Prefill completes - should unblock decode
server.complete_room(room_id);
let result = decode_handle.await.unwrap();
assert!(result.is_ok(), "Decode should succeed: {result:?}");
cancel_token.cancel();
}
// Decode starts connecting after ~10ms and prefill completes after ~50ms, so the two
// sides overlap in time; the rendezvous must succeed whichever registers first.
#[tokio::test]
async fn test_interleaved_ordering() {
let cancel_token = CancellationToken::new();
let server = BootstrapServer::start(0, cancel_token.clone())
.await
.unwrap();
let port = server.port();
let room_id = 1003u64;
// Spawn decode
let server_clone = server.clone();
let decode_handle = tokio::spawn(async move {
// Short delay before decode dials in; prefill completes later (at ~50ms below)
tokio::time::sleep(Duration::from_millis(10)).await;
connect_to_prefill("127.0.0.1", port, room_id).await
});
// Prefill completes after decode starts connecting
tokio::time::sleep(Duration::from_millis(50)).await;
server_clone.complete_room(room_id);
let result = decode_handle.await.unwrap();
assert!(result.is_ok(), "Decode should succeed: {result:?}");
cancel_token.cancel();
}
// Three rooms on one server, each with a different arrival order, run concurrently.
#[tokio::test]
async fn test_multiple_rooms_concurrent() {
let cancel_token = CancellationToken::new();
let server = BootstrapServer::start(0, cancel_token.clone())
.await
.unwrap();
let port = server.port();
let mut handles = vec![];
// Room 1: prefill first
let server1 = server.clone();
handles.push(tokio::spawn(async move {
server1.complete_room(2001);
tokio::time::sleep(Duration::from_millis(10)).await;
connect_to_prefill("127.0.0.1", port, 2001).await
}));
// Room 2: decode first
let server2 = server.clone();
handles.push(tokio::spawn(async move {
let decode = tokio::spawn(connect_to_prefill("127.0.0.1", port, 2002));
tokio::time::sleep(Duration::from_millis(50)).await;
server2.complete_room(2002);
decode.await.unwrap()
}));
// Room 3: simultaneous
let server3 = server.clone();
handles.push(tokio::spawn(async move {
let decode = tokio::spawn(connect_to_prefill("127.0.0.1", port, 2003));
server3.complete_room(2003);
decode.await.unwrap()
}));
// Handles were pushed in room order, so index i corresponds to room 2001 + i.
for (i, handle) in handles.into_iter().enumerate() {
let result = handle.await.unwrap();
assert!(
result.is_ok(),
"Room {} should succeed: {result:?}",
2001 + i
);
}
cancel_token.cancel();
}
// With no `complete_room` call, decode must still be waiting when the short outer
// timeout elapses.
#[tokio::test]
async fn test_decode_timeout_no_prefill() {
let cancel_token = CancellationToken::new();
let server = BootstrapServer::start(0, cancel_token.clone())
.await
.unwrap();
let port = server.port();
let room_id = 9999u64;
// Decode connects but prefill never completes - use short timeout
let result = tokio::time::timeout(
Duration::from_millis(100),
connect_to_prefill("127.0.0.1", port, room_id),
)
.await;
// Should timeout (outer timeout, not inner RENDEZVOUS_TIMEOUT)
assert!(result.is_err(), "Should timeout waiting for prefill");
cancel_token.cancel();
}
}
......@@ -28,6 +28,7 @@ use dynamo_runtime::{
};
use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::bootstrap::{BootstrapServer, connect_to_prefill};
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType};
use crate::mocker::scheduler::Scheduler;
......@@ -47,6 +48,8 @@ pub struct MockVllmEngine {
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>,
engine_args: MockEngineArgs,
/// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
}
impl MockVllmEngine {
......@@ -56,6 +59,7 @@ impl MockVllmEngine {
active_requests: Arc::new(Mutex::new(HashMap::new())),
request_senders: Arc::new(OnceCell::new()),
engine_args: args,
bootstrap_server: Arc::new(OnceCell::new()),
}
}
......@@ -73,6 +77,15 @@ impl MockVllmEngine {
tracing::info!("Engine startup simulation completed");
}
// Start bootstrap server for prefill workers in disaggregated mode
if self.engine_args.worker_type == WorkerType::Prefill
&& let Some(port) = self.engine_args.bootstrap_port
{
let server = BootstrapServer::start(port, cancel_token.clone()).await?;
let _ = self.bootstrap_server.set(server);
tracing::info!(port = port, "Bootstrap server started for prefill worker");
}
// Pass component to schedulers only if prefix caching is enabled and not a decode worker
let scheduler_component = if self.engine_args.enable_prefix_caching
&& self.engine_args.worker_type != WorkerType::Decode
......@@ -253,6 +266,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
)));
}
// Bootstrap rendezvous for disaggregated serving
// - Decode: connect to prefill's server, block until prefill completes
// - Prefill: complete_room() is called after first token (see below)
let bootstrap_room = request.bootstrap_info.as_ref().map(|b| b.bootstrap_room);
if let Some(bootstrap_info) = &request.bootstrap_info
&& self.engine_args.worker_type == WorkerType::Decode
{
connect_to_prefill(
&bootstrap_info.bootstrap_host,
bootstrap_info.bootstrap_port,
bootstrap_info.bootstrap_room,
)
.await
.map_err(|e| Error::msg(format!("Bootstrap connection failed: {e}")))?;
}
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
// For prefill workers, override max_tokens to 1
......@@ -288,6 +317,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let active_requests = self.active_requests.clone();
let async_context = ctx.context();
let bootstrap_server = self.bootstrap_server.clone();
// Spawn a task to handle the complex async logic
tokio::spawn(async move {
......@@ -325,6 +355,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
completion_usage: None,
};
// Prefill: after first token, mark room complete (unblocks decode)
if is_prefill
&& token_count == 1
&& let (Some(server), Some(room_id)) = (bootstrap_server.get(), bootstrap_room)
{
server.complete_room(room_id);
}
if signal.completed && token_count < max_output_tokens {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string()));
break;
......
......@@ -124,6 +124,12 @@ pub struct MockEngineArgs {
/// Enable worker-local KV indexer for tracking this worker's own KV cache state
#[builder(default = "false")]
pub enable_local_indexer: bool,
/// Bootstrap port for disaggregated serving rendezvous.
/// Prefill workers listen on this port; decode workers connect to it.
/// If None, bootstrap rendezvous is disabled.
#[builder(default = "None")]
pub bootstrap_port: Option<u16>,
}
impl Default for MockEngineArgs {
......@@ -163,6 +169,7 @@ impl MockEngineArgs {
"is_decode",
"planner_profile_data",
"enable_local_indexer",
"bootstrap_port",
]
.iter()
.cloned()
......@@ -250,6 +257,12 @@ impl MockEngineArgs {
builder = builder.enable_local_indexer(enabled);
}
// Parse the optional bootstrap port. `as_u64` yields a 64-bit value, so convert with a
// range check: a plain `as u16` cast silently wraps values above 65535 (70000 -> 4464),
// which would make the worker use a different port than the one configured.
if let Some(value) = extra_args.get("bootstrap_port")
    && let Some(port) = value.as_u64()
{
    match u16::try_from(port) {
        Ok(port) => builder = builder.bootstrap_port(Some(port)),
        // Out-of-range value: leave the builder's default in place.
        Err(_) => tracing::warn!(
            "Ignoring bootstrap_port {port}: not a valid TCP port (must be <= 65535)"
        ),
    }
}
// Parse worker type from is_prefill and is_decode flags
let is_prefill = extra_args
.get("is_prefill")
......
......@@ -245,6 +245,7 @@ impl OpenAIPreprocessor {
prefill_worker_id: nvext.prefill_worker_id,
decode_worker_id: nvext.decode_worker_id,
dp_rank: None, // dp_rank is set later in the pipeline
enable_local_updates: nvext.enable_local_updates,
};
builder.routing(Some(routing));
}
......
......@@ -34,6 +34,14 @@ pub struct RoutingHints {
/// Data parallel rank for the request
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>,
/// Controls whether the router should manage local bookkeeping (add_request,
/// mark_prefill_completed, free) for this request.
///
/// - `None` or `Some(true)`: Router handles bookkeeping locally (default behavior)
/// - `Some(false)`: External caller (e.g., GAIE sidecar) handles bookkeeping via C FFI
#[serde(default, skip_serializing_if = "Option::is_none")]
pub enable_local_updates: Option<bool>,
}
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
......
......@@ -6,9 +6,12 @@
//! This module provides [`RequestTracker`] for tracking timing and routing information
//! that can be returned to clients via the `nvext` response field.
use serde::{Deserialize, Serialize};
use std::sync::{Mutex, OnceLock};
use std::sync::{Arc, OnceLock};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use utoipa::ToSchema;
use crate::protocols::openai::nvext::WorkerIdInfo;
......@@ -80,6 +83,12 @@ pub struct RequestTracker {
/// Request phase (Prefill/Decode/Aggregated)
phase: Mutex<RequestPhase>,
/// Semaphore for coordinating phase transitions.
/// Acquiring a permit blocks subsequent set_phase calls until the permit is dropped.
/// This prevents race conditions in the bootstrap optimization path where prefill
/// runs in background and needs to complete record_worker before phase changes.
phase_semaphore: Arc<Semaphore>,
}
impl RequestTracker {
......@@ -102,6 +111,7 @@ impl RequestTracker {
prefill_worker_id: OnceLock::new(),
decode_worker_id: OnceLock::new(),
phase: Mutex::new(RequestPhase::Aggregated),
phase_semaphore: Arc::new(Semaphore::new(1)),
}
}
......@@ -175,14 +185,29 @@ impl RequestTracker {
self.decode_worker_id.set(id).is_ok()
}
/// Set the request phase. Can be called multiple times to update the phase.
pub fn set_phase(&self, phase: RequestPhase) {
*self.phase.lock().unwrap() = phase;
/// Set the request phase and return a permit that blocks subsequent phase changes.
///
/// The returned permit must be dropped to allow the next `set_phase` call to proceed.
/// Under normal operation, callers can simply ignore the returned permit (letting it
/// drop immediately). In the bootstrap optimization path, the permit is held and
/// passed to the spawned prefill task, which drops it after `record_worker` completes.
///
/// This prevents the race condition where the phase changes to Decode before the
/// background prefill task has recorded its worker ID.
pub async fn set_phase(&self, phase: RequestPhase) -> OwnedSemaphorePermit {
let permit = self
.phase_semaphore
.clone()
.acquire_owned()
.await
.expect("phase semaphore should never be closed");
*self.phase.lock() = phase;
permit
}
/// Get the current request phase.
pub fn phase(&self) -> RequestPhase {
*self.phase.lock().unwrap()
*self.phase.lock()
}
/// Record worker ID based on the current phase.
......
......@@ -105,6 +105,17 @@ pub struct NvExt {
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
/// Controls whether the router should manage local bookkeeping (add_request,
/// mark_prefill_completed, free) for this request.
///
/// - `None` or `true`: Router handles bookkeeping locally (default behavior)
/// - `false`: External caller (e.g., GAIE sidecar) handles bookkeeping via C FFI
///
/// Set to `false` for GAIE Stage 2 when the EPP/sidecar manages request lifecycle.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub enable_local_updates: Option<bool>,
}
impl Default for NvExt {
......@@ -153,6 +164,7 @@ mod tests {
assert_eq!(nv_ext.extra_fields, None);
assert_eq!(nv_ext.prefill_worker_id, None);
assert_eq!(nv_ext.decode_worker_id, None);
assert_eq!(nv_ext.enable_local_updates, None);
}
// Test valid builder configurations
......
......@@ -230,7 +230,7 @@ impl EndpointConfigBuilder {
/// This function handles both health check and discovery transport building.
/// All transport modes use consistent addressing:
/// - HTTP: Uses full URL path including endpoint name (e.g., http://host:port/v1/rpc/endpoint_name)
/// - TCP: Includes endpoint name for routing (e.g., host:port/endpoint_name)
/// - TCP: Includes instance_id and endpoint name for routing (e.g., host:port/instance_id_hex/endpoint_name)
/// - NATS: Uses subject-based addressing (unique per endpoint)
///
/// # Errors
......@@ -266,9 +266,14 @@ fn build_transport_type_inner(
.and_then(|p| p.parse::<u16>().ok())
.unwrap_or(crate::pipeline::network::manager::get_actual_tcp_rpc_port()?);
// Include endpoint name for proper TCP routing
// TCP client parses this format and adds x-endpoint-path header for server-side routing
let tcp_endpoint = format!("{}:{}/{}", tcp_host, tcp_port, endpoint_id.name);
// Include instance_id and endpoint name for proper TCP routing.
// Format: host:port/instance_id_hex/endpoint_name
// This ensures each worker has a unique routing key when multiple workers
// share the same TCP server (e.g., --num-workers > 1).
let tcp_endpoint = format!(
"{}:{}/{:x}/{}",
tcp_host, tcp_port, connection_id, endpoint_id.name
);
Ok(TransportType::Tcp(tcp_endpoint))
}
......
......@@ -413,9 +413,11 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
component_name: String,
system_health: Arc<Mutex<SystemHealth>>,
) -> Result<()> {
// For TCP, we use endpoint_name as both the endpoint_path (routing key) and endpoint_name
// Include instance_id in the routing key to avoid collisions when multiple workers
// share the same TCP server (e.g., --num-workers > 1 in tests)
let endpoint_path = format!("{instance_id:x}/{endpoint_name}");
self.register_endpoint(
endpoint_name.clone(),
endpoint_path,
service_handler,
instance_id,
namespace,
......@@ -427,7 +429,19 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
}
async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
self.unregister_endpoint(endpoint_name, endpoint_name).await;
// With multiple workers per process, each registers with a unique key
// "{instance_id}/{endpoint_name}". Find and remove all matching entries.
let suffix = format!("/{endpoint_name}");
let keys_to_remove: Vec<String> = self
.handlers
.iter()
.filter(|entry| entry.key().ends_with(&suffix))
.map(|entry| entry.key().clone())
.collect();
for key in keys_to_remove {
self.unregister_endpoint(&key, endpoint_name).await;
}
Ok(())
}
......
......@@ -14,6 +14,7 @@
//! directly accesses transport implementations or configuration.
use super::egress::unified_client::RequestPlaneClient;
use super::ingress::shared_tcp_endpoint::SharedTcpServer;
use super::ingress::unified_server::RequestPlaneServer;
use crate::distributed::RequestPlaneMode;
use anyhow::Result;
......@@ -26,6 +27,17 @@ use tokio_util::sync::CancellationToken;
/// Uses OnceLock since the port is set once when the server binds and never changes.
static ACTUAL_TCP_RPC_PORT: OnceLock<u16> = OnceLock::new();
/// Global storage for the shared TCP server instance.
///
/// When multiple workers run in the same process, they must share a single TCP server
/// to ensure all endpoints are registered on the same server. Without this, each worker
/// would create its own server on a different port, but all would publish the same port
/// (from ACTUAL_TCP_RPC_PORT) to discovery, causing "No handler found" errors.
///
/// Uses `tokio::sync::OnceCell` to support async initialization (binding the TCP socket).
static GLOBAL_TCP_SERVER: tokio::sync::OnceCell<Arc<SharedTcpServer>> =
tokio::sync::OnceCell::const_new();
/// Get the actual TCP RPC port that the server is listening on.
pub fn get_actual_tcp_rpc_port() -> anyhow::Result<u16> {
ACTUAL_TCP_RPC_PORT.get().copied().ok_or_else(|| {
......@@ -300,35 +312,41 @@ impl NetworkManager {
}
async fn create_tcp_server(&self) -> Result<Arc<dyn RequestPlaneServer>> {
use super::ingress::shared_tcp_endpoint::SharedTcpServer;
// Use the global TCP server to ensure all workers in the same process share
// a single server. This is critical for correct endpoint routing.
let server = GLOBAL_TCP_SERVER
.get_or_try_init(|| async {
// Use configured port if specified, otherwise use port 0 (OS assigns free port)
let port = self.config.tcp_port.unwrap_or(0);
let bind_addr = format!("{}:{}", self.config.tcp_host, port)
.parse()
.map_err(|e| anyhow::anyhow!("Invalid TCP bind address: {}", e))?;
// Use configured port if specified, otherwise use port 0 (OS assigns free port)
let port = self.config.tcp_port.unwrap_or(0);
let bind_addr = format!("{}:{}", self.config.tcp_host, port)
.parse()
.map_err(|e| anyhow::anyhow!("Invalid TCP bind address: {}", e))?;
tracing::info!(
bind_addr = %bind_addr,
port_source = if self.config.tcp_port.is_some() { "DYN_TCP_RPC_PORT" } else { "OS-assigned" },
"Creating TCP request plane server"
);
tracing::info!(
bind_addr = %bind_addr,
port_source = if self.config.tcp_port.is_some() { "DYN_TCP_RPC_PORT" } else { "OS-assigned" },
"Creating TCP request plane server"
);
let server = SharedTcpServer::new(bind_addr, self.cancellation_token.clone());
let server = SharedTcpServer::new(bind_addr, self.cancellation_token.clone());
// Bind and start server, getting the actual bound address
let actual_addr = server.clone().bind_and_start().await?;
// Bind and start server, getting the actual bound address
let actual_addr = server.clone().bind_and_start().await?;
// Store the actual bound port globally so build_transport_type() can access it
set_actual_tcp_rpc_port(actual_addr.port());
// Store the actual bound port globally so build_transport_type() can access it
set_actual_tcp_rpc_port(actual_addr.port());
tracing::info!(
actual_addr = %actual_addr,
actual_port = actual_addr.port(),
"TCP request plane server started"
);
tracing::info!(
actual_addr = %actual_addr,
actual_port = actual_addr.port(),
"TCP request plane server started"
);
Ok::<_, anyhow::Error>(server)
})
.await?;
Ok(server as Arc<dyn RequestPlaneServer>)
Ok(server.clone() as Arc<dyn RequestPlaneServer>)
}
async fn create_nats_server(&self) -> Result<Arc<dyn RequestPlaneServer>> {
......
......@@ -6,6 +6,12 @@
# Combined pre_merge wall time (this file):
# - Serialized: 304.01s.
# - Parallel (-n auto): 34.55s (269.46s saved, 8.80x).
#
# NOTE: TCP request plane is NOT tested here. These tests use --num-workers > 1 which spawns
# multiple workers in a single process sharing one TCP server. The shared TCP server uses
# endpoint_path (e.g., "generate") as the routing key, causing handler collisions when multiple
# workers register the same endpoint. This is a test-only limitation; production deployments
# with separate processes per worker work correctly with TCP.
import logging
import os
from typing import Any, Dict, Optional
......@@ -155,6 +161,8 @@ def _build_mocker_command(
command.extend(["--data-parallel-size", str(mocker_args["dp_size"])])
if mocker_args.get("enable_local_indexer"):
command.append("--enable-local-indexer")
if "bootstrap_ports" in mocker_args:
command.extend(["--bootstrap-ports", mocker_args["bootstrap_ports"]])
return command
......@@ -233,6 +241,7 @@ class DisaggMockerProcess:
num_mockers: int = 1,
store_backend: str = "etcd",
request_plane: str = "nats",
enable_bootstrap: bool = False,
):
if worker_type not in ("prefill", "decode"):
raise ValueError(
......@@ -242,6 +251,7 @@ class DisaggMockerProcess:
self.namespace = namespace
self.worker_type = worker_type
self.num_workers = num_mockers
self._bootstrap_ports: list[int] = []
# Set component name and endpoint based on worker type
if worker_type == "prefill":
......@@ -251,7 +261,17 @@ class DisaggMockerProcess:
self.component_name = "backend"
self.endpoint = f"dyn://{self.namespace}.backend.generate"
mocker_args = mocker_args or {}
mocker_args = (mocker_args or {}).copy()
# Allocate bootstrap ports for prefill workers if enabled (one per worker)
if enable_bootstrap and worker_type == "prefill":
self._bootstrap_ports = allocate_ports(num_mockers, BASE_PORT)
mocker_args["bootstrap_ports"] = ",".join(
str(p) for p in self._bootstrap_ports
)
logger.info(
f"Allocated bootstrap ports {self._bootstrap_ports} for {num_mockers} prefill workers"
)
command = _build_mocker_command(
endpoint=self.endpoint,
......@@ -279,6 +299,11 @@ class DisaggMockerProcess:
f"endpoint: {self.endpoint}"
)
@property
def bootstrap_ports(self) -> list[int]:
    """Return the allocated bootstrap ports, if any.

    The list is empty unless ports were allocated for this process in
    ``__init__`` (prefill workers with ``enable_bootstrap=True``); it is
    reset to empty again once ``__exit__`` has deallocated them.
    """
    return self._bootstrap_ports
def __enter__(self):
logger.info(
f"Starting {self.worker_type} mocker process with {self.num_workers} worker(s)"
......@@ -289,6 +314,11 @@ class DisaggMockerProcess:
def __exit__(self, exc_type, exc_val, exc_tb):
    """Stop the mocker process and release any bootstrap ports it reserved.

    The exception triple is forwarded unchanged to the wrapped process's
    ``__exit__``. Returns ``None``, so an exception raised inside the
    ``with`` body is never suppressed.
    """
    logger.info(f"Stopping {self.worker_type} mocker process")
    try:
        self._process.__exit__(exc_type, exc_val, exc_tb)
    finally:
        # Deallocate bootstrap ports if we allocated any. This runs in a
        # ``finally`` so the ports are returned even when stopping the
        # process raises; otherwise they would stay reserved.
        if self._bootstrap_ports:
            deallocate_ports(self._bootstrap_ports)
            logger.info(f"Deallocated bootstrap ports {self._bootstrap_ports}")
            self._bootstrap_ports = []
@pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up
......@@ -487,9 +517,9 @@ def test_kv_push_router_bindings(
],
ids=[
"jetstream",
"nats",
"nats_core",
"file",
], # "nats_core" commented out to match commented test case
],
)
@pytest.mark.timeout(90) # TODO: figure out a timeout
def test_indexers_sync(
......@@ -677,37 +707,43 @@ def test_router_decisions(
mockers.__exit__(None, None, None)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
@pytest.mark.parametrize(
"enable_disagg_bootstrap", [False, True], ids=["no_bootstrap", "with_bootstrap"]
)
@pytest.mark.timeout(59) # ~3x average (~19.51s), rounded up
def test_router_decisions_disagg(
request,
runtime_services_dynamic_ports,
predownload_tokenizers,
registration_order,
request_plane,
enable_disagg_bootstrap,
):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup.
Tests that progressive requests with overlapping prefixes are routed to the
same prefill worker due to KV cache reuse.
Parameterized to test both registration orders:
- prefill_first: prefill workers register before decode workers
- decode_first: decode workers register before prefill workers
Parameterized to test:
- registration_order: prefill_first vs decode_first
- enable_disagg_bootstrap: without vs with bootstrap rendezvous
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
logger.info(
f"Starting disaggregated router prefix reuse test "
f"(registration_order={registration_order})"
f"(registration_order={registration_order}, bootstrap={enable_disagg_bootstrap})"
)
# Generate shared namespace for prefill and decode workers
namespace_suffix = generate_random_suffix()
shared_namespace = f"test-namespace-{namespace_suffix}"
# Create mocker args
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
# Create mocker args - use JetStream for KV events (more reliable than NATS Core)
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": False,
}
prefill_workers = None
decode_workers = None
......@@ -722,7 +758,8 @@ def test_router_decisions_disagg(
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
request_plane=request_plane,
request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap,
)
prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
......@@ -735,7 +772,7 @@ def test_router_decisions_disagg(
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
request_plane=request_plane,
request_plane="nats",
)
decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
......@@ -748,7 +785,7 @@ def test_router_decisions_disagg(
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
request_plane=request_plane,
request_plane="nats",
)
decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
......@@ -761,7 +798,8 @@ def test_router_decisions_disagg(
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
request_plane=request_plane,
request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap,
)
prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
......@@ -779,7 +817,7 @@ def test_router_decisions_disagg(
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
request_plane=request_plane,
request_plane="nats",
)
finally:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment