Unverified Commit 0980b27f authored by Yan Ru Pei, committed by GitHub
Browse files

feat: mockers with bootstrap optimization (sglang testing) + CI test (#5121)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent cd8dddee
...@@ -108,12 +108,15 @@ def create_temp_engine_args_file(args) -> Path: ...@@ -108,12 +108,15 @@ def create_temp_engine_args_file(args) -> Path:
"speedup_ratio": getattr(args, "speedup_ratio", None), "speedup_ratio": getattr(args, "speedup_ratio", None),
"dp_size": getattr(args, "dp_size", None), "dp_size": getattr(args, "dp_size", None),
"startup_time": getattr(args, "startup_time", None), "startup_time": getattr(args, "startup_time", None),
"planner_profile_data": str(getattr(args, "planner_profile_data", None)) "planner_profile_data": (
if getattr(args, "planner_profile_data", None) str(getattr(args, "planner_profile_data", None))
else None, if getattr(args, "planner_profile_data", None)
else None
),
"is_prefill": getattr(args, "is_prefill_worker", None), "is_prefill": getattr(args, "is_prefill_worker", None),
"is_decode": getattr(args, "is_decode_worker", None), "is_decode": getattr(args, "is_decode_worker", None),
"enable_local_indexer": getattr(args, "enable_local_indexer", None), "enable_local_indexer": getattr(args, "enable_local_indexer", None),
# Note: bootstrap_port is NOT included here - it's set per-worker in launch_workers()
} }
# Remove None values to only include explicitly set arguments # Remove None values to only include explicitly set arguments
...@@ -142,6 +145,13 @@ def validate_worker_type_args(args): ...@@ -142,6 +145,13 @@ def validate_worker_type_args(args):
) )
def parse_bootstrap_ports(ports_str: str | None) -> list[int]:
"""Parse comma-separated bootstrap ports string into list of integers."""
if not ports_str:
return []
return [int(p.strip()) for p in ports_str.split(",")]
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Mocker engine for testing Dynamo LLM infrastructure with vLLM-style CLI.", description="Mocker engine for testing Dynamo LLM infrastructure with vLLM-style CLI.",
...@@ -291,6 +301,15 @@ def parse_args(): ...@@ -291,6 +301,15 @@ def parse_args():
default=False, default=False,
help="Enable worker-local KV indexer for tracking this worker's own KV cache state (default: False)", help="Enable worker-local KV indexer for tracking this worker's own KV cache state (default: False)",
) )
parser.add_argument(
"--bootstrap-ports",
type=str,
default=None,
help="Comma-separated list of bootstrap ports for disaggregated serving rendezvous. "
"One port per worker (must match --num-workers). "
"Prefill workers listen on these ports; decode workers connect to them. "
"If not specified, bootstrap rendezvous is disabled.",
)
parser.add_argument( parser.add_argument(
"--store-kv", "--store-kv",
type=str, type=str,
...@@ -313,6 +332,15 @@ def parse_args(): ...@@ -313,6 +332,15 @@ def parse_args():
if args.num_workers < 1: if args.num_workers < 1:
raise ValueError(f"--num-workers must be at least 1, got {args.num_workers}") raise ValueError(f"--num-workers must be at least 1, got {args.num_workers}")
# Parse and validate bootstrap_ports
args.bootstrap_ports_list = parse_bootstrap_ports(args.bootstrap_ports)
if args.bootstrap_ports_list:
if len(args.bootstrap_ports_list) != args.num_workers:
raise ValueError(
f"--bootstrap-ports must have exactly --num-workers ({args.num_workers}) ports, "
f"got {len(args.bootstrap_ports_list)}: {args.bootstrap_ports_list}"
)
# Set endpoint default based on worker type if not explicitly provided # Set endpoint default based on worker type if not explicitly provided
if args.endpoint is None: if args.endpoint is None:
if args.is_prefill_worker: if args.is_prefill_worker:
......
...@@ -5,9 +5,12 @@ ...@@ -5,9 +5,12 @@
# Now supports vLLM-style individual arguments for MockEngineArgs # Now supports vLLM-style individual arguments for MockEngineArgs
import asyncio import asyncio
import json
import logging import logging
import os import os
import signal import signal
import tempfile
from pathlib import Path
import uvloop import uvloop
...@@ -85,6 +88,13 @@ async def launch_workers(args, extra_engine_args_path): ...@@ -85,6 +88,13 @@ async def launch_workers(args, extra_engine_args_path):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
futures = [] futures = []
runtimes = [] runtimes = []
per_worker_temp_files: list[Path] = []
# Load base engine args if we need to create per-worker files with bootstrap_port
base_engine_args = None
if args.bootstrap_ports_list:
with open(extra_engine_args_path) as f:
base_engine_args = json.load(f)
for worker_id in range(args.num_workers): for worker_id in range(args.num_workers):
logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}") logger.info(f"Creating mocker worker {worker_id + 1}/{args.num_workers}")
...@@ -93,13 +103,30 @@ async def launch_workers(args, extra_engine_args_path): ...@@ -93,13 +103,30 @@ async def launch_workers(args, extra_engine_args_path):
runtime = DistributedRuntime(loop, args.store_kv, args.request_plane) runtime = DistributedRuntime(loop, args.store_kv, args.request_plane)
runtimes.append(runtime) runtimes.append(runtime)
# Determine which engine args file to use
if args.bootstrap_ports_list:
# Create per-worker temp file with this worker's bootstrap_port
worker_args = base_engine_args.copy()
worker_args["bootstrap_port"] = args.bootstrap_ports_list[worker_id]
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
json.dump(worker_args, f)
worker_engine_args_path = Path(f.name)
per_worker_temp_files.append(worker_engine_args_path)
logger.debug(
f"Worker {worker_id}: using bootstrap_port {args.bootstrap_ports_list[worker_id]}"
)
else:
worker_engine_args_path = extra_engine_args_path
# Create EntrypointArgs for this worker # Create EntrypointArgs for this worker
entrypoint_args = EntrypointArgs( entrypoint_args = EntrypointArgs(
engine_type=EngineType.Mocker, engine_type=EngineType.Mocker,
model_path=args.model_path, model_path=args.model_path,
model_name=args.model_name, model_name=args.model_name,
endpoint_id=args.endpoint, endpoint_id=args.endpoint,
extra_engine_args=extra_engine_args_path, extra_engine_args=worker_engine_args_path,
is_prefill=args.is_prefill_worker, is_prefill=args.is_prefill_worker,
) )
...@@ -130,6 +157,13 @@ async def launch_workers(args, extra_engine_args_path): ...@@ -130,6 +157,13 @@ async def launch_workers(args, extra_engine_args_path):
for runtime in runtimes: for runtime in runtimes:
runtime.shutdown() runtime.shutdown()
# Clean up per-worker temp files
for temp_file in per_worker_temp_files:
try:
temp_file.unlink()
except Exception:
pass
def main(): def main():
uvloop.run(worker()) uvloop.run(worker())
......
...@@ -96,21 +96,6 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -96,21 +96,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
bootstrap_room = self._generate_bootstrap_room() bootstrap_room = self._generate_bootstrap_room()
logging.debug(f"Generated bootstrap_room locally: {bootstrap_room}") logging.debug(f"Generated bootstrap_room locally: {bootstrap_room}")
bootstrap_info = {
"bootstrap_host": self.bootstrap_host,
"bootstrap_port": self.bootstrap_port,
"bootstrap_room": bootstrap_room,
}
# Yield in LLMEngineOutput format for PrefillRouter compatibility
# The disaggregated_params field contains the bootstrap info
yield {
"token_ids": [],
"text": None,
"finish_reason": None,
"disaggregated_params": bootstrap_info,
}
input_param = self._get_input_param(inner_request) input_param = self._get_input_param(inner_request)
# Propagate trace context to SGLang # Propagate trace context to SGLang
......
...@@ -1080,6 +1080,9 @@ pub fn add_query_instance_id( ...@@ -1080,6 +1080,9 @@ pub fn add_query_instance_id(
/// ///
/// For disaggregated mode: sets `prefill_worker_id` and `decode_worker_id` /// For disaggregated mode: sets `prefill_worker_id` and `decode_worker_id`
/// For aggregated mode: sets `backend_instance_id` (when both IDs are the same) /// For aggregated mode: sets `backend_instance_id` (when both IDs are the same)
///
/// Also sets `enable_local_updates: false` since the external caller (EPP/GAIE)
/// will handle bookkeeping via C FFI functions.
pub fn set_worker_ids_for_stage2( pub fn set_worker_ids_for_stage2(
request: &mut NvCreateChatCompletionRequest, request: &mut NvCreateChatCompletionRequest,
decode_worker_id: Option<i64>, decode_worker_id: Option<i64>,
...@@ -1091,6 +1094,9 @@ pub fn set_worker_ids_for_stage2( ...@@ -1091,6 +1094,9 @@ pub fn set_worker_ids_for_stage2(
.expect("NvExt builder should not fail") .expect("NvExt builder should not fail")
}); });
// Disable local updates - external caller handles bookkeeping via C FFI
nvext.enable_local_updates = Some(false);
// Check if this is aggregated mode (same worker for both) // Check if this is aggregated mode (same worker for both)
let is_aggregated = prefill_worker_id == decode_worker_id; let is_aggregated = prefill_worker_id == decode_worker_id;
......
...@@ -354,71 +354,76 @@ impl KvRouter { ...@@ -354,71 +354,76 @@ impl KvRouter {
tracing::info!("Worker query client initialized"); tracing::info!("Worker query client initialized");
// Start KV event subscriber background process (only when use_kv_events is enabled) // Start KV event subscriber background process (only when use_kv_events is enabled)
// This is spawned as a background task to avoid blocking router startup. // We block here until at least one worker runtime config is registered,
// The task waits for runtime_configs to determine whether to use NATS Core or JetStream. // then spawn the subscriber. This ensures the router is ready before accepting requests.
if kv_router_config.use_kv_events if kv_router_config.use_kv_events
&& let Indexer::KvIndexer(ref kv_indexer) = indexer && let Indexer::KvIndexer(ref kv_indexer) = indexer
{ {
// Clone everything needed for the background task
let component_clone = component.clone();
let kv_indexer_clone = kv_indexer.clone();
let cancellation_token_clone = cancellation_token.clone();
let mut runtime_configs_rx_clone = runtime_configs_rx.clone(); let mut runtime_configs_rx_clone = runtime_configs_rx.clone();
let worker_query_client_clone =
worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
tokio::spawn(async move { // Wait for at least one worker runtime config to be registered
// Wait for runtime_configs to have at least one entry tracing::info!("Waiting for at least one worker runtime config to be registered...");
let (all_local_indexer, count) = loop { let (all_local_indexer, count) = loop {
{ {
let configs = runtime_configs_rx_clone.borrow(); let configs = runtime_configs_rx_clone.borrow();
if !configs.is_empty() { if !configs.is_empty() {
let all_local_indexer = let all_local_indexer = configs.values().all(|c| c.enable_local_indexer);
configs.values().all(|c| c.enable_local_indexer); break (all_local_indexer, configs.len());
break (all_local_indexer, configs.len());
}
} }
}
// Wait for changes to runtime_configs // Wait for changes to runtime_configs
tokio::select! { tokio::select! {
_ = cancellation_token_clone.cancelled() => { _ = cancellation_token.cancelled() => {
tracing::debug!("Subscriber selection task cancelled"); tracing::debug!("KvRouter startup cancelled while waiting for workers");
return; anyhow::bail!("KvRouter startup cancelled");
} }
result = runtime_configs_rx_clone.changed() => { result = runtime_configs_rx_clone.changed() => {
if result.is_err() { if result.is_err() {
tracing::debug!("Runtime configs channel closed"); tracing::debug!("Runtime configs channel closed");
return; anyhow::bail!("Runtime configs channel closed before any workers registered");
}
} }
} }
}; }
};
tracing::info!("Found {count} worker runtime config(s), starting KV event subscriber");
if all_local_indexer { // Clone everything needed for the background subscriber task
// All workers have local_indexer enabled - use NATS Core let component_clone = component.clone();
tracing::info!( let kv_indexer_clone = kv_indexer.clone();
"All {count} workers have local_indexer enabled, using NATS Core subscription" let cancellation_token_clone = cancellation_token.clone();
); let worker_query_client_clone =
worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
// Spawn subscriber as background task (long-running)
if all_local_indexer {
// All workers have local_indexer enabled - use NATS Core
tracing::info!(
"All {count} workers have local_indexer enabled, using NATS Core subscription"
);
tokio::spawn(async move {
if let Err(e) = start_kv_router_background_nats_core( if let Err(e) = start_kv_router_background_nats_core(
component_clone.clone(), component_clone,
kv_indexer_clone.event_sender(), kv_indexer_clone.event_sender(),
kv_indexer_clone.remove_worker_sender(), kv_indexer_clone.remove_worker_sender(),
cancellation_token_clone.clone(), cancellation_token_clone,
worker_query_client_clone, worker_query_client_clone,
) )
.await .await
{ {
tracing::error!("Failed to start NATS Core subscriber: {e}"); tracing::error!("Failed to start NATS Core subscriber: {e}");
} }
} else { });
// Not all workers have local_indexer - use JetStream } else {
tracing::info!( // Not all workers have local_indexer - use JetStream
"Not all workers have local_indexer enabled, using JetStream subscription" tracing::info!(
); "Not all workers have local_indexer enabled, using JetStream subscription"
);
tokio::spawn(async move {
if let Err(e) = start_kv_router_background( if let Err(e) = start_kv_router_background(
component_clone.clone(), component_clone,
consumer_id, consumer_id,
kv_indexer_clone.event_sender(), kv_indexer_clone.event_sender(),
kv_indexer_clone.remove_worker_sender(), kv_indexer_clone.remove_worker_sender(),
...@@ -428,7 +433,7 @@ impl KvRouter { ...@@ -428,7 +433,7 @@ impl KvRouter {
kv_router_config kv_router_config
.router_snapshot_threshold .router_snapshot_threshold
.map(|_| kv_indexer_clone.snapshot_event_sender()), .map(|_| kv_indexer_clone.snapshot_event_sender()),
cancellation_token_clone.clone(), cancellation_token_clone,
kv_router_config.router_snapshot_threshold, kv_router_config.router_snapshot_threshold,
kv_router_config.router_reset_states, kv_router_config.router_reset_states,
) )
...@@ -436,8 +441,8 @@ impl KvRouter { ...@@ -436,8 +441,8 @@ impl KvRouter {
{ {
tracing::error!("Failed to start JetStream subscriber: {e}"); tracing::error!("Failed to start JetStream subscriber: {e}");
} }
} });
}); }
} }
tracing::info!("KV Routing initialized"); tracing::info!("KV Routing initialized");
...@@ -815,17 +820,12 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -815,17 +820,12 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let is_query_only = request.get_annotation_value("query_instance_id").is_some(); let is_query_only = request.get_annotation_value("query_instance_id").is_some();
// Determine if this router should handle local state updates (add_request, free, etc.) // Determine if this router should handle local state updates (add_request, free, etc.)
// Only skip local updates for GAIE Stage 2: when BOTH prefill and decode worker IDs // Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
// are externally specified (indicates external orchestrator handles tracking). // an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
// For internal routing (e.g., bootstrap optimization with only prefill_worker_id set), let handle_local_updates = request
// we still handle updates locally. .routing
let routing = request.routing.as_ref(); .as_ref()
let handle_local_updates = routing .and_then(|r| r.enable_local_updates)
.map(|r| {
// GAIE Stage 2 sets both worker IDs - external caller handles tracking
// All other cases (including backend_instance_id for routing) - we handle locally
r.prefill_worker_id.is_none() || r.decode_worker_id.is_none()
})
.unwrap_or(true); .unwrap_or(true);
// Get phase from tracker (defaults to Aggregated if no tracker or phase not set) // Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
...@@ -917,9 +917,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -917,9 +917,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let stream_context = response_stream.context(); let stream_context = response_stream.context();
let context_for_monitoring = stream_context.clone(); let context_for_monitoring = stream_context.clone();
// TODO: When handle_local_updates=false, consider moving mark_prefill_completed // Wrap stream with lifecycle management (mark_prefill_completed, free)
// to an external caller (e.g., sidecar) if they support a first-token hook. // Only perform these operations if handle_local_updates is true.
// Currently mark_prefill_completed is called here for all flows. // When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
let wrapped_stream = Box::pin(async_stream::stream! { let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false; let mut prefill_marked = false;
...@@ -937,7 +937,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -937,7 +937,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
break; break;
}; };
if !prefill_marked { if handle_local_updates && !prefill_marked {
// Only mark prefill completed when we receive actual tokens, // Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill // not empty bootstrap info (token_ids: []) from disaggregated prefill
let has_tokens = item.data.as_ref() let has_tokens = item.data.as_ref()
...@@ -956,8 +956,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -956,8 +956,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
} }
// Always call free() - it's idempotent and safe even if already freed or never added // Only call free() if we handle local updates.
if let Err(e) = chooser.free(&context_id).await { // When handle_local_updates=false, external caller handles cleanup via C FFI.
if handle_local_updates
&& let Err(e) = chooser.free(&context_id).await
{
tracing::warn!("Failed to free request {context_id}: {e}"); tracing::warn!("Failed to free request {context_id}: {e}");
} }
}); });
......
...@@ -6,7 +6,7 @@ use std::sync::{Arc, OnceLock}; ...@@ -6,7 +6,7 @@ use std::sync::{Arc, OnceLock};
use anyhow::Result; use anyhow::Result;
use futures::StreamExt; use futures::StreamExt;
use rand::Rng; use rand::Rng;
use tokio::sync::oneshot; use tokio::sync::{OwnedSemaphorePermit, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -24,7 +24,6 @@ use crate::{ ...@@ -24,7 +24,6 @@ use crate::{
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult}, protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::common::timing::{RequestPhase, RequestTracker}, protocols::common::timing::{RequestPhase, RequestTracker},
protocols::openai::nvext::WorkerIdInfo,
}; };
/// Errors that can occur during prefill routing /// Errors that can occur during prefill routing
...@@ -85,10 +84,10 @@ impl InnerPrefillRouter { ...@@ -85,10 +84,10 @@ impl InnerPrefillRouter {
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params /// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request. /// from the prefill response and injecting them into the decode request.
/// ///
/// Supports regular Dynamo and GAIE integrated mode via query_instance_id state machine: /// Modes:
/// - GAIE Stage 1: query_instance_id transitions "" -> "prefill" -> "decode", returns only worker IDs /// - Query-only: `query_instance_id` annotation present → returns worker IDs without execution
/// - GAIE Stage 2: routing.prefill_worker_id/routing.decode_worker_id are set, full execution with specified workers /// - Pre-routed: `prefill_worker_id`/`decode_worker_id` set → routes to specified workers
/// - Non-GAIE: like GAIE Stage 2 but the worker ids have to be determined. /// - Normal: Worker IDs determined by router based on KV cache state
pub struct PrefillRouter { pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>, prefill_router: OnceLock<InnerPrefillRouter>,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
...@@ -232,11 +231,6 @@ impl PrefillRouter { ...@@ -232,11 +231,6 @@ impl PrefillRouter {
Ok(()) Ok(())
} }
/// Generate a unique bootstrap room ID for disaggregated serving
fn generate_bootstrap_room() -> u64 {
rand::rng().random()
}
/// Build bootstrap_info for disaggregated serving /// Build bootstrap_info for disaggregated serving
/// If preselected_worker is provided (GAIE Stage 2), use it directly. /// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes). /// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
...@@ -250,7 +244,6 @@ impl PrefillRouter { ...@@ -250,7 +244,6 @@ impl PrefillRouter {
// Worker selection // Worker selection
let (worker_id, dp_rank) = if let Some(id) = preselected_worker { let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
// GAIE Stage 2: use pre-selected worker
let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0); let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0);
tracing::debug!( tracing::debug!(
worker_id = id, worker_id = id,
...@@ -285,7 +278,7 @@ impl PrefillRouter { ...@@ -285,7 +278,7 @@ impl PrefillRouter {
let host = endpoint.bootstrap_host?; let host = endpoint.bootstrap_host?;
let port = endpoint.bootstrap_port?; let port = endpoint.bootstrap_port?;
let bootstrap_room = Self::generate_bootstrap_room(); let bootstrap_room: u64 = rand::rng().random();
tracing::info!( tracing::info!(
worker_id = worker_id, worker_id = worker_id,
...@@ -308,12 +301,18 @@ impl PrefillRouter { ...@@ -308,12 +301,18 @@ impl PrefillRouter {
)) ))
} }
/// Execute prefill with the given router and extract structured result /// Execute prefill with the given router and extract structured result.
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization) ///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
///
/// If `phase_permit` is provided, it is dropped as soon as the prefill request has been
/// routed (before the first output is awaited), allowing subsequent `set_phase` calls to
/// proceed. This is used in the bootstrap optimization path to ensure `record_worker`
/// completes before the phase changes.
async fn execute_prefill( async fn execute_prefill(
router: Option<InnerPrefillRouter>, router: Option<InnerPrefillRouter>,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>, target_worker: Option<u64>,
phase_permit: Option<OwnedSemaphorePermit>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> { ) -> Result<(PrefillResult, Option<u64>), PrefillError> {
let router = router.ok_or(PrefillError::NotActivated)?; let router = router.ok_or(PrefillError::NotActivated)?;
let mut prefill_response = router let mut prefill_response = router
...@@ -321,6 +320,10 @@ impl PrefillRouter { ...@@ -321,6 +320,10 @@ impl PrefillRouter {
.await .await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?; .map_err(|e| PrefillError::PrefillError(e.to_string()))?;
// Drop phase permit now - routing is complete, record_worker was called in select_worker.
// This unblocks set_phase(Decode) in the main task without waiting for prefill output.
drop(phase_permit);
let Some(first_output) = prefill_response.next().await else { let Some(first_output) = prefill_response.next().await else {
return Err(PrefillError::PrefillError( return Err(PrefillError::PrefillError(
"Prefill router returned no output (stream ended)".to_string(), "Prefill router returned no output (stream ended)".to_string(),
...@@ -379,17 +382,24 @@ impl PrefillRouter { ...@@ -379,17 +382,24 @@ impl PrefillRouter {
)) ))
} }
/// Spawn prefill as a background task /// Spawn prefill as a background task.
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization) ///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
///
/// The `phase_permit` is passed to the spawned task and dropped once the prefill request
/// has been routed, allowing the main task's `set_phase(Decode)` to proceed.
fn spawn_prefill_task( fn spawn_prefill_task(
&self, &self,
prefill_request: SingleIn<PreprocessedRequest>, prefill_request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>, target_worker: Option<u64>,
phase_permit: OwnedSemaphorePermit,
) { ) {
let router = self.prefill_router.get().cloned(); let router = self.prefill_router.get().cloned();
tokio::spawn(async move { tokio::spawn(async move {
match Self::execute_prefill(router, prefill_request, target_worker).await { match Self::execute_prefill(router, prefill_request, target_worker, Some(phase_permit))
.await
{
Ok(_) => { Ok(_) => {
tracing::debug!("Prefill background task completed"); tracing::debug!("Prefill background task completed");
} }
...@@ -400,67 +410,17 @@ impl PrefillRouter { ...@@ -400,67 +410,17 @@ impl PrefillRouter {
}); });
} }
/// Call the prefill router and extract structured prefill result and worker ID /// Call the prefill router and extract structured prefill result and worker ID.
///
/// This is the synchronous prefill path - we wait for prefill to complete before proceeding.
/// No phase permit is needed since `record_worker` completes before we return.
async fn call_prefill( async fn call_prefill(
&self, &self,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> { ) -> Result<(PrefillResult, Option<u64>), PrefillError> {
// For call_prefill path, routing is handled by the router itself (no direct routing needed) // For call_prefill path, routing is handled by the router itself (no direct routing needed)
Self::execute_prefill(self.prefill_router.get().cloned(), request, None).await // No phase permit needed - we wait for completion before changing phase
} Self::execute_prefill(self.prefill_router.get().cloned(), request, None, None).await
}
/// GAIE helper functions for preparing prefill requests
impl PrefillRouter {
/// Prepare prefill request for GAIE flows
/// - Stage 1: Sets query_instance_id:prefill annotation
/// - Stage 2: Sets backend_instance_id to target prefill worker
fn prepare_prefill_for_gaie(prefill_req: &mut PreprocessedRequest, is_gaie_stage1: bool) {
if is_gaie_stage1 {
// GAIE Stage 1: Set query_instance_id to "prefill" for prefill worker selection
prefill_req
.annotations
.retain(|a| !a.starts_with("query_instance_id"));
prefill_req
.annotations
.push(format!("query_instance_id:{}", RequestPhase::Prefill));
} else if let Some(prefill_worker_id) = prefill_req
.routing
.as_ref()
.and_then(|r| r.prefill_worker_id)
{
// GAIE Stage 2: Route to pre-selected prefill worker from the stage 1
tracing::debug!(
prefill_worker_id = prefill_worker_id,
"GAIE Stage 2: Routing prefill to pre-selected worker"
);
prefill_req.routing_mut().backend_instance_id = Some(prefill_worker_id);
}
}
/// Prepare decode request for GAIE Stage 1
/// Extracts prefill_worker_id from prefill result and sets decode annotations
fn prepare_decode_for_gaie_stage1(
decode_req: &mut PreprocessedRequest,
prefill_result: &PrefillResult,
) {
let prefill_worker_id = prefill_result
.disaggregated_params
.get("worker_id")
.and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
.and_then(|info| info.prefill_worker_id);
if let Some(worker_id) = prefill_worker_id {
decode_req
.annotations
.retain(|a| !a.starts_with("query_instance_id"));
decode_req
.annotations
.push(format!("query_instance_id:{}", RequestPhase::Decode));
decode_req
.annotations
.push(format!("prefill_worker_id:{worker_id}"));
}
} }
} }
...@@ -490,22 +450,14 @@ impl ...@@ -490,22 +450,14 @@ impl
let request_id = context.id().to_string(); let request_id = context.id().to_string();
let engine_ctx = context.context(); let engine_ctx = context.context();
// GAIE Stage 1: the presence of the empty query_instance_id signals query-only mode
// State machine: "" -> "prefill" -> "decode" (disagg) OR "" -> aggregated worker (agg fallback)
let is_gaie_stage1 = req
.get_annotation_value("query_instance_id")
.is_some_and(|s| s.is_empty());
// Save original max_tokens for decode // Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens; let original_max_tokens = req.stop_conditions.max_tokens;
// GAIE Stage 1: Check if prefill router is activated - if not, skip to decode // If prefill router is not activated, skip directly to decode
if is_gaie_stage1 && self.prefill_router.get().is_none() { if self.prefill_router.get().is_none() {
tracing::debug!("GAIE Stage 1: Prefill router not activated, skipping to decode");
if self.enforce_disagg { if self.enforce_disagg {
return Err(anyhow::anyhow!(PrefillError::NotActivated)); return Err(anyhow::anyhow!(PrefillError::NotActivated));
} }
// Fall back to decode-only
return next.generate(context.map(|_| req)).await; return next.generate(context.map(|_| req)).await;
} }
...@@ -515,47 +467,45 @@ impl ...@@ -515,47 +467,45 @@ impl
req.tracker = Some(Arc::new(RequestTracker::new())); req.tracker = Some(Arc::new(RequestTracker::new()));
} }
let tracker = req.tracker.as_ref().unwrap(); let tracker = req.tracker.as_ref().unwrap();
tracker.set_phase(RequestPhase::Prefill); let prefill_phase_permit = tracker.set_phase(RequestPhase::Prefill).await;
tracker.record_prefill_start(); tracker.record_prefill_start();
// Prepare prefill request with max_tokens = 1 (clone after tracker is set) // Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let mut prefill_req = req.clone(); let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1); prefill_req.stop_conditions.max_tokens = Some(1);
// Prepare prefill request for GAIE flows (Stage 1 or Stage 2) // Try build_bootstrap_info optimization: if we can get bootstrap info upfront,
Self::prepare_prefill_for_gaie(&mut prefill_req, is_gaie_stage1); // spawn prefill in background and proceed to decode immediately.
// Try build_bootstrap_info optimization (skip for GAIE Stage 1 which needs query-only flow)
// For GAIE Stage 2, use prefill_worker_id if provided
let preselected_worker = prefill_req let preselected_worker = prefill_req
.routing .routing
.as_ref() .as_ref()
.and_then(|r| r.prefill_worker_id); .and_then(|r| r.prefill_worker_id);
let prefill_result = if !is_gaie_stage1 let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) = self
&& let Some((worker_id, dp_rank, bootstrap_info)) = self .build_bootstrap_info(&prefill_req, preselected_worker)
.build_bootstrap_info(&prefill_req, preselected_worker) .await
.await
{ {
// Bootstrap optimization path: spawn prefill in background // Bootstrap optimization path: spawn prefill in background
let routing = prefill_req.routing_mut(); let routing = prefill_req.routing_mut();
routing.prefill_worker_id = Some(worker_id); routing.prefill_worker_id = Some(worker_id);
routing.backend_instance_id = Some(worker_id); // Route prefill to the SAME worker we got bootstrap_info from
routing.dp_rank = Some(dp_rank); routing.dp_rank = Some(dp_rank);
prefill_req.bootstrap_info = Some(bootstrap_info.clone()); prefill_req.bootstrap_info = Some(bootstrap_info.clone());
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
self.spawn_prefill_task(prefill_context, Some(worker_id)); // Pass phase permit to spawned task - it drops after first output (record_worker complete)
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);
Ok((None, Some(worker_id), Some(bootstrap_info))) Ok((None, Some(worker_id), Some(bootstrap_info)))
} else { } else {
// Original prefill path: wait for prefill to complete // Original prefill path: wait for prefill to complete
tracing::debug!( tracing::debug!("Using original prefill path");
is_gaie_stage1 = is_gaie_stage1,
"Using original prefill path" // Drop the phase permit before calling call_prefill - we wait for completion
); // so there's no race with set_phase(Decode) below
drop(prefill_phase_permit);
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
...@@ -579,20 +529,18 @@ impl ...@@ -579,20 +529,18 @@ impl
Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => { Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => {
tracing::debug!("Prefill completed, proceeding to decode"); tracing::debug!("Prefill completed, proceeding to decode");
// Set phase to Decode for the decode request // Set phase to Decode for the decode request.
// In bootstrap path, this blocks until the spawned prefill task drops its permit
// (after first output / record_worker completes), ensuring correct phase for routing.
if let Some(ref tracker) = req.tracker { if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Decode); let _decode_permit = tracker.set_phase(RequestPhase::Decode).await;
// Permit is dropped immediately - decode proceeds, no need to hold it
} }
let mut decode_req = req; let mut decode_req = req;
// Update request with prefill result // Update request with prefill result
if is_gaie_stage1 { if let Some(prefill_result) = maybe_prefill_result {
if let Some(ref prefill_result) = maybe_prefill_result {
Self::prepare_decode_for_gaie_stage1(&mut decode_req, prefill_result);
}
} else if let Some(prefill_result) = maybe_prefill_result {
// Normal or GAIE Stage 2: Set prefill_result for decode
decode_req.prefill_result = Some(prefill_result); decode_req.prefill_result = Some(prefill_result);
} }
...@@ -611,17 +559,6 @@ impl ...@@ -611,17 +559,6 @@ impl
..existing_override.unwrap_or_default() ..existing_override.unwrap_or_default()
}); });
// GAIE Stage 2: Route to pre-selected decode worker if specified
if let Some(decode_worker_id) =
decode_req.routing.as_ref().and_then(|r| r.decode_worker_id)
{
decode_req.routing_mut().backend_instance_id = Some(decode_worker_id);
tracing::debug!(
decode_worker_id = decode_worker_id,
"GAIE Stage 2: Routing decode to pre-selected worker"
);
}
// Map the modified request through with preserved context // Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req); let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await next.generate(decode_request).await
......
...@@ -651,6 +651,51 @@ pub async fn start_kv_router_background( ...@@ -651,6 +651,51 @@ pub async fn start_kv_router_background(
Ok(()) Ok(())
} }
/// React to a single discovery event for a worker's generate endpoint.
///
/// - `Added`: asks `recover_from_worker` to replay the worker's local indexer
///   (from the beginning, with no upper bound) into `kv_events_tx`.
/// - `Removed`: forwards the worker id on `remove_worker_tx`.
///
/// Failures in either path are logged and swallowed; this function never
/// returns an error to its caller.
async fn handle_worker_discovery(
    event: DiscoveryEvent,
    worker_query_client: &WorkerQueryClient,
    kv_events_tx: &mpsc::Sender<RouterEvent>,
    remove_worker_tx: &mpsc::Sender<WorkerId>,
) {
    // Deal with removals up front so the rest of the function is the "added" path.
    let instance = match event {
        DiscoveryEvent::Added(instance) => instance,
        DiscoveryEvent::Removed(worker_id) => {
            tracing::warn!("DISCOVERY: Worker {worker_id} removed, removing from router indexer");
            let send_outcome = remove_worker_tx.send(worker_id).await;
            if let Err(e) = send_outcome {
                tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
            }
            return;
        }
    };

    let worker_id = instance.instance_id();
    tracing::info!("DISCOVERY: Worker {worker_id} added, dumping local indexer into router");

    // `None, None` = no start bound and no end bound: request every event.
    let recovery =
        recover_from_worker(worker_query_client, worker_id, None, None, kv_events_tx).await;
    match recovery {
        Ok(count) => tracing::info!(
            "Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
        ),
        Err(e) => tracing::warn!(
            "Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
        ),
    }
}
/// Start a simplified background task for event consumption using NATS Core. /// Start a simplified background task for event consumption using NATS Core.
/// ///
/// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`, /// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`,
...@@ -660,6 +705,9 @@ pub async fn start_kv_router_background( ...@@ -660,6 +705,9 @@ pub async fn start_kv_router_background(
/// - On worker Added: dumps worker's local indexer into router /// - On worker Added: dumps worker's local indexer into router
/// - On worker Removed: removes worker from router indexer /// - On worker Removed: removes worker from router indexer
/// ///
/// This function first recovers state from all currently registered workers before
/// spawning the background task, ensuring the router is ready before returning.
///
/// This is appropriate when workers have local indexers enabled. /// This is appropriate when workers have local indexers enabled.
pub async fn start_kv_router_background_nats_core( pub async fn start_kv_router_background_nats_core(
component: Component, component: Component,
...@@ -688,6 +736,40 @@ pub async fn start_kv_router_background_nats_core( ...@@ -688,6 +736,40 @@ pub async fn start_kv_router_background_nats_core(
.list_and_watch(generate_discovery_key, Some(cancellation_token.clone())) .list_and_watch(generate_discovery_key, Some(cancellation_token.clone()))
.await?; .await?;
// Drain and process all existing workers before spawning the background loop.
// list_and_watch returns existing instances first, so we poll with a short timeout
// to process all initial workers synchronously before the router becomes "ready".
loop {
// Use a short timeout to detect when initial discovery events are exhausted
let poll_result =
tokio::time::timeout(Duration::from_millis(100), instance_event_stream.next()).await;
match poll_result {
Ok(Some(Ok(event))) => {
handle_worker_discovery(
event,
&worker_query_client,
&kv_events_tx,
&remove_worker_tx,
)
.await;
}
Ok(Some(Err(e))) => {
tracing::warn!("Error receiving discovery event during initial sync: {e}");
}
Ok(None) => {
// Stream ended
tracing::warn!("Discovery stream ended during initial sync");
break;
}
Err(_) => {
// Timeout - no more initial events
tracing::debug!("Initial worker discovery sync complete");
break;
}
}
}
tokio::spawn(async move { tokio::spawn(async move {
// Track last received event ID per worker for gap detection // Track last received event ID per worker for gap detection
let mut last_event_ids: HashMap<WorkerId, u64> = HashMap::new(); let mut last_event_ids: HashMap<WorkerId, u64> = HashMap::new();
...@@ -703,51 +785,17 @@ pub async fn start_kv_router_background_nats_core( ...@@ -703,51 +785,17 @@ pub async fn start_kv_router_background_nats_core(
// Handle generate endpoint instance add/remove events // Handle generate endpoint instance add/remove events
Some(discovery_event_result) = instance_event_stream.next() => { Some(discovery_event_result) = instance_event_stream.next() => {
let Ok(discovery_event) = discovery_event_result else { let Ok(event) = discovery_event_result else {
continue; continue;
}; };
match discovery_event { handle_worker_discovery(
DiscoveryEvent::Added(_instance) => { event,
// Extract worker_id from the instance &worker_query_client,
let worker_id = _instance.instance_id(); &kv_events_tx,
&remove_worker_tx,
tracing::info!( )
"DISCOVERY: Worker {worker_id} added, dumping local indexer into router" .await;
);
// Query worker's local indexer and dump all events
match recover_from_worker(
&worker_query_client,
worker_id,
None, // Start from beginning
None, // Get all events
&kv_events_tx,
)
.await
{
Ok(count) => {
tracing::info!(
"Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
);
}
Err(e) => {
tracing::warn!(
"Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
);
}
}
}
DiscoveryEvent::Removed(worker_id) => {
tracing::warn!(
"DISCOVERY: Worker {worker_id} removed, removing from router indexer"
);
if let Err(e) = remove_worker_tx.send(worker_id).await {
tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
}
}
}
} }
// Handle event consumption from NATS Core subscription // Handle event consumption from NATS Core subscription
......
...@@ -10,9 +10,10 @@ use dynamo_runtime::discovery::DiscoverySpec; ...@@ -10,9 +10,10 @@ use dynamo_runtime::discovery::DiscoverySpec;
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
use dynamo_runtime::slug::Slug; use dynamo_runtime::slug::Slug;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::utils::get_http_rpc_host_from_env;
use crate::entrypoint::RouterConfig; use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs; use crate::mocker::protocols::{MockEngineArgs, WorkerType};
use crate::model_card::ModelDeploymentCard; use crate::model_card::ModelDeploymentCard;
use crate::model_type::{ModelInput, ModelType}; use crate::model_type::{ModelInput, ModelType};
use crate::preprocessor::media::{ImageDecoder, MediaDecoder, MediaFetcher}; use crate::preprocessor::media::{ImageDecoder, MediaDecoder, MediaFetcher};
...@@ -249,6 +250,22 @@ impl LocalModelBuilder { ...@@ -249,6 +250,22 @@ impl LocalModelBuilder {
video: None, video: None,
}); });
self.media_fetcher = Some(MediaFetcher::default()); self.media_fetcher = Some(MediaFetcher::default());
// Set bootstrap endpoint for prefill workers with bootstrap_port configured
if mocker_engine_args.worker_type == WorkerType::Prefill
&& let Some(port) = mocker_engine_args.bootstrap_port
{
let host = get_http_rpc_host_from_env();
self.runtime_config.disaggregated_endpoint =
Some(runtime_config::DisaggregatedEndpoint {
bootstrap_host: Some(host),
bootstrap_port: Some(port),
});
tracing::info!(
bootstrap_port = port,
"Mocker prefill worker: publishing bootstrap endpoint to discovery"
);
}
} }
// frontend and echo engine don't need a path. // frontend and echo engine don't need a path.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub mod bootstrap;
pub mod engine; pub mod engine;
pub mod evictor; pub mod evictor;
pub mod kv_manager; pub mod kv_manager;
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Bootstrap rendezvous for disaggregated mocker testing.
//!
//! Simulates the SGLang disaggregated serving handshake for KV transfer coordination.
//! Either prefill or decode can arrive first; the rendezvous completes when both are ready.
//!
//! - Prefill: calls `complete_room(room_id)` after first token (KV cache ready)
//! - Decode: connects to prefill's bootstrap server, blocks until prefill completes
//!
//! Wire protocol:
//! - Decode -> Prefill: room_id (8 bytes, little-endian u64)
//! - Prefill -> Decode: ACK (1 byte, 0x01) after prefill completes
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Result, bail};
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
/// Timeout for bootstrap rendezvous operations.
///
/// Applied separately to each bounded wait in this module, e.g. the server
/// waiting for prefill to complete a room, and the decode side connecting to
/// the server and waiting for the ACK.
const RENDEZVOUS_TIMEOUT: Duration = Duration::from_secs(30);
/// ACK byte sent from server to decode after prefill completes.
///
/// The decode side rejects any other value as an invalid ACK.
const ACK_BYTE: u8 = 0x01;
/// State for a room in the rendezvous.
///
/// A room is keyed by its `u64` room id in `BootstrapServer::rooms` and is
/// created by whichever side (prefill or decode) reaches the server first.
struct RoomState {
    /// True if prefill has completed (KV cache ready).
    /// Set when `complete_room` runs before any decode connection is waiting.
    prefill_completed: bool,
    /// Channel to notify decode when prefill completes (if decode is waiting).
    /// Stored by the connection handler; taken and fired by `complete_room`.
    decode_waiting: Option<oneshot::Sender<()>>,
}
/// Bootstrap server for prefill mockers.
/// Handles rendezvous between prefill and decode for KV transfer coordination.
pub struct BootstrapServer {
    /// Port the listener is actually bound to (read back from the socket, so
    /// it is the OS-assigned port when the server was started with port 0).
    port: u16,
    /// Per-room rendezvous state, shared with the accept loop and the
    /// per-connection handler tasks.
    rooms: Arc<DashMap<u64, RoomState>>,
}
impl BootstrapServer {
/// Start the bootstrap server on the specified port.
pub async fn start(port: u16, cancel_token: CancellationToken) -> Result<Arc<Self>> {
let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
let actual_port = listener.local_addr()?.port();
tracing::info!("Bootstrap server started on port {actual_port}");
let rooms: Arc<DashMap<u64, RoomState>> = Arc::new(DashMap::new());
let server = Arc::new(Self {
port: actual_port,
rooms: rooms.clone(),
});
// Spawn accept loop
tokio::spawn(async move {
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, addr)) => {
tracing::debug!("Bootstrap: accepted connection from {addr}");
let rooms_clone = rooms.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(stream, rooms_clone).await {
tracing::warn!("Bootstrap: connection error: {e}");
}
});
}
Err(e) => {
tracing::warn!("Bootstrap: accept failed: {e}");
}
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Bootstrap server shutting down");
break;
}
}
}
});
Ok(server)
}
/// Handle a connection from decode. Blocks until prefill completes for this room.
async fn handle_connection(
mut stream: TcpStream,
rooms: Arc<DashMap<u64, RoomState>>,
) -> Result<()> {
// Read room_id (8 bytes, little-endian)
let mut buf = [0u8; 8];
stream.read_exact(&mut buf).await?;
let room_id = u64::from_le_bytes(buf);
tracing::debug!("Bootstrap: decode connected for room {room_id}");
// Check room state and wait if needed
let rx = match rooms.entry(room_id) {
Entry::Occupied(mut entry) => {
if entry.get().prefill_completed {
// Prefill already done, immediate ACK
entry.remove();
tracing::debug!("Bootstrap: room {room_id} already completed, immediate ACK");
None
} else {
// Prefill registered but not completed, wait
let (tx, rx) = oneshot::channel();
entry.get_mut().decode_waiting = Some(tx);
tracing::debug!("Bootstrap: room {room_id} waiting for prefill to complete");
Some(rx)
}
}
Entry::Vacant(entry) => {
// Decode arrived first, create entry and wait
let (tx, rx) = oneshot::channel();
entry.insert(RoomState {
prefill_completed: false,
decode_waiting: Some(tx),
});
tracing::debug!("Bootstrap: room {room_id} decode arrived first, waiting");
Some(rx)
}
};
// Wait for prefill to complete if needed
if let Some(rx) = rx {
match tokio::time::timeout(RENDEZVOUS_TIMEOUT, rx).await {
Ok(Ok(())) => {
tracing::debug!("Bootstrap: room {room_id} prefill completed, sending ACK");
}
Ok(Err(_)) => {
bail!("Bootstrap: room {room_id} sender dropped");
}
Err(_) => {
rooms.remove(&room_id);
bail!("Bootstrap: room {room_id} timeout waiting for prefill");
}
}
}
// Send ACK
stream.write_all(&[ACK_BYTE]).await?;
Ok(())
}
/// Mark a room as completed (prefill finished, KV cache ready).
/// If decode is already waiting, unblocks it.
pub fn complete_room(&self, room_id: u64) {
match self.rooms.entry(room_id) {
Entry::Occupied(mut entry) => {
if let Some(sender) = entry.get_mut().decode_waiting.take() {
// Decode is waiting, unblock it
let _ = sender.send(());
entry.remove();
tracing::debug!("Bootstrap: room {room_id} completed, decode unblocked");
} else {
// Decode not connected yet, mark completed
entry.get_mut().prefill_completed = true;
tracing::debug!("Bootstrap: room {room_id} completed, awaiting decode");
}
}
Entry::Vacant(entry) => {
// Decode hasn't connected yet
entry.insert(RoomState {
prefill_completed: true,
decode_waiting: None,
});
tracing::debug!("Bootstrap: room {room_id} completed (no decode yet)");
}
}
}
/// Get the port the server is listening on.
pub fn port(&self) -> u16 {
self.port
}
}
/// Decode-side half of the rendezvous: connect to a prefill worker's bootstrap
/// server, announce `room_id`, and wait until the server acknowledges it.
///
/// Surrounding `[`/`]` are stripped from `host` before the `host:port` address
/// is built. The TCP connect and the ACK read are each bounded by
/// `RENDEZVOUS_TIMEOUT`.
///
/// # Errors
/// Fails on connect timeout/failure, on a socket write error, on ACK
/// timeout/read failure, or if the received byte is not `ACK_BYTE`.
pub async fn connect_to_prefill(host: &str, port: u16, room_id: u64) -> Result<()> {
    let host = host.trim_matches(['[', ']']);
    let addr = format!("{host}:{port}");
    tracing::debug!("Bootstrap: decode connecting to {addr} for room {room_id}");

    // Connect with timeout
    let connect_outcome =
        tokio::time::timeout(RENDEZVOUS_TIMEOUT, TcpStream::connect(&addr)).await;
    let mut stream = match connect_outcome {
        Ok(Ok(connected)) => connected,
        Ok(Err(e)) => bail!("Bootstrap: connect failed to {addr}: {e}"),
        Err(_) => bail!("Bootstrap: connect timeout to {addr}"),
    };

    // Send room_id as 8 little-endian bytes
    let room_bytes = room_id.to_le_bytes();
    stream.write_all(&room_bytes).await?;

    // Wait for ACK (blocks until prefill completes)
    let mut ack = [0u8; 1];
    match tokio::time::timeout(RENDEZVOUS_TIMEOUT, stream.read_exact(&mut ack)).await {
        Ok(Ok(_)) => {}
        Ok(Err(e)) => bail!("Bootstrap: read ACK failed: {e}"),
        Err(_) => bail!("Bootstrap: ACK timeout for room {room_id}"),
    }

    let [received] = ack;
    if received == ACK_BYTE {
        tracing::debug!("Bootstrap: decode received ACK for room {room_id}");
        return Ok(());
    }
    bail!(
        "Bootstrap: invalid ACK byte {:02x} for room {room_id}",
        received
    );
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback host every decode-side connection in these tests dials.
    const LOCALHOST: &str = "127.0.0.1";

    /// Start a server on an OS-assigned port and return it together with that
    /// port and the token that shuts its accept loop down.
    async fn start_test_server() -> (Arc<BootstrapServer>, u16, CancellationToken) {
        let cancel_token = CancellationToken::new();
        let server = BootstrapServer::start(0, cancel_token.clone())
            .await
            .unwrap();
        let port = server.port();
        (server, port, cancel_token)
    }

    /// Prefill finishes before decode connects: decode gets an immediate ACK.
    #[tokio::test]
    async fn test_prefill_completes_first() {
        let (server, port, cancel_token) = start_test_server().await;
        let room_id = 1001u64;

        server.complete_room(room_id);

        let result = connect_to_prefill(LOCALHOST, port, room_id).await;
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        cancel_token.cancel();
    }

    /// Decode connects and blocks; completing the room afterwards releases it.
    #[tokio::test]
    async fn test_decode_connects_first() {
        let (server, port, cancel_token) = start_test_server().await;
        let room_id = 1002u64;

        // Decode runs in the background and waits for prefill.
        let decode_handle = tokio::spawn(connect_to_prefill(LOCALHOST, port, room_id));

        // Leave decode enough time to connect and register its room.
        tokio::time::sleep(Duration::from_millis(50)).await;

        server.complete_room(room_id);

        let result = decode_handle.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        cancel_token.cancel();
    }

    /// Decode starts connecting after a short delay while prefill completes later.
    #[tokio::test]
    async fn test_interleaved_ordering() {
        let (server, port, cancel_token) = start_test_server().await;
        let room_id = 1003u64;

        let decode_handle = tokio::spawn(async move {
            // Decode begins 10ms in ...
            tokio::time::sleep(Duration::from_millis(10)).await;
            connect_to_prefill(LOCALHOST, port, room_id).await
        });

        // ... and prefill completes at 50ms, after decode has started connecting.
        tokio::time::sleep(Duration::from_millis(50)).await;
        server.complete_room(room_id);

        let result = decode_handle.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        cancel_token.cancel();
    }

    /// Three rooms with different arrival orders all complete on one server.
    #[tokio::test]
    async fn test_multiple_rooms_concurrent() {
        let (server, port, cancel_token) = start_test_server().await;

        // Room 2001: prefill first, decode shortly after.
        let prefill_first = {
            let server = server.clone();
            tokio::spawn(async move {
                server.complete_room(2001);
                tokio::time::sleep(Duration::from_millis(10)).await;
                connect_to_prefill(LOCALHOST, port, 2001).await
            })
        };

        // Room 2002: decode first, prefill after a delay.
        let decode_first = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOCALHOST, port, 2002));
                tokio::time::sleep(Duration::from_millis(50)).await;
                server.complete_room(2002);
                decode.await.unwrap()
            })
        };

        // Room 2003: both sides race.
        let simultaneous = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOCALHOST, port, 2003));
                server.complete_room(2003);
                decode.await.unwrap()
            })
        };

        let handles = vec![prefill_first, decode_first, simultaneous];
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            assert!(
                result.is_ok(),
                "Room {} should succeed: {result:?}",
                2001 + i
            );
        }

        cancel_token.cancel();
    }

    /// With no prefill completion, decode is still blocked when a short outer
    /// timeout (far below `RENDEZVOUS_TIMEOUT`) expires.
    #[tokio::test]
    async fn test_decode_timeout_no_prefill() {
        let (_server, port, cancel_token) = start_test_server().await;
        let room_id = 9999u64;

        let pending_decode = connect_to_prefill(LOCALHOST, port, room_id);
        let result = tokio::time::timeout(Duration::from_millis(100), pending_decode).await;

        // The outer timeout fires, not the inner RENDEZVOUS_TIMEOUT.
        assert!(result.is_err(), "Should timeout waiting for prefill");

        cancel_token.cancel();
    }
}
...@@ -28,6 +28,7 @@ use dynamo_runtime::{ ...@@ -28,6 +28,7 @@ use dynamo_runtime::{
}; };
use crate::kv_router::publisher::WorkerMetricsPublisher; use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::bootstrap::{BootstrapServer, connect_to_prefill};
use crate::mocker::protocols::DirectRequest; use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType}; use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType};
use crate::mocker::scheduler::Scheduler; use crate::mocker::scheduler::Scheduler;
...@@ -47,6 +48,8 @@ pub struct MockVllmEngine { ...@@ -47,6 +48,8 @@ pub struct MockVllmEngine {
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>, active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>, request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>,
engine_args: MockEngineArgs, engine_args: MockEngineArgs,
/// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
} }
impl MockVllmEngine { impl MockVllmEngine {
...@@ -56,6 +59,7 @@ impl MockVllmEngine { ...@@ -56,6 +59,7 @@ impl MockVllmEngine {
active_requests: Arc::new(Mutex::new(HashMap::new())), active_requests: Arc::new(Mutex::new(HashMap::new())),
request_senders: Arc::new(OnceCell::new()), request_senders: Arc::new(OnceCell::new()),
engine_args: args, engine_args: args,
bootstrap_server: Arc::new(OnceCell::new()),
} }
} }
...@@ -73,6 +77,15 @@ impl MockVllmEngine { ...@@ -73,6 +77,15 @@ impl MockVllmEngine {
tracing::info!("Engine startup simulation completed"); tracing::info!("Engine startup simulation completed");
} }
// Start bootstrap server for prefill workers in disaggregated mode
if self.engine_args.worker_type == WorkerType::Prefill
&& let Some(port) = self.engine_args.bootstrap_port
{
let server = BootstrapServer::start(port, cancel_token.clone()).await?;
let _ = self.bootstrap_server.set(server);
tracing::info!(port = port, "Bootstrap server started for prefill worker");
}
// Pass component to schedulers only if prefix caching is enabled and not a decode worker // Pass component to schedulers only if prefix caching is enabled and not a decode worker
let scheduler_component = if self.engine_args.enable_prefix_caching let scheduler_component = if self.engine_args.enable_prefix_caching
&& self.engine_args.worker_type != WorkerType::Decode && self.engine_args.worker_type != WorkerType::Decode
...@@ -253,6 +266,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -253,6 +266,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
))); )));
} }
// Bootstrap rendezvous for disaggregated serving
// - Decode: connect to prefill's server, block until prefill completes
// - Prefill: complete_room() is called after first token (see below)
let bootstrap_room = request.bootstrap_info.as_ref().map(|b| b.bootstrap_room);
if let Some(bootstrap_info) = &request.bootstrap_info
&& self.engine_args.worker_type == WorkerType::Decode
{
connect_to_prefill(
&bootstrap_info.bootstrap_host,
bootstrap_info.bootstrap_port,
bootstrap_info.bootstrap_room,
)
.await
.map_err(|e| Error::msg(format!("Bootstrap connection failed: {e}")))?;
}
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4()); let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
// For prefill workers, override max_tokens to 1 // For prefill workers, override max_tokens to 1
...@@ -288,6 +317,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -288,6 +317,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let active_requests = self.active_requests.clone(); let active_requests = self.active_requests.clone();
let async_context = ctx.context(); let async_context = ctx.context();
let bootstrap_server = self.bootstrap_server.clone();
// Spawn a task to handle the complex async logic // Spawn a task to handle the complex async logic
tokio::spawn(async move { tokio::spawn(async move {
...@@ -325,6 +355,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -325,6 +355,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
completion_usage: None, completion_usage: None,
}; };
// Prefill: after first token, mark room complete (unblocks decode)
if is_prefill
&& token_count == 1
&& let (Some(server), Some(room_id)) = (bootstrap_server.get(), bootstrap_room)
{
server.complete_room(room_id);
}
if signal.completed && token_count < max_output_tokens { if signal.completed && token_count < max_output_tokens {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string())); let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string()));
break; break;
......
...@@ -124,6 +124,12 @@ pub struct MockEngineArgs { ...@@ -124,6 +124,12 @@ pub struct MockEngineArgs {
/// Enable worker-local KV indexer for tracking this worker's own KV cache state /// Enable worker-local KV indexer for tracking this worker's own KV cache state
#[builder(default = "false")] #[builder(default = "false")]
pub enable_local_indexer: bool, pub enable_local_indexer: bool,
/// Bootstrap port for disaggregated serving rendezvous.
/// Prefill workers listen on this port; decode workers connect to it.
/// If None, bootstrap rendezvous is disabled.
#[builder(default = "None")]
pub bootstrap_port: Option<u16>,
} }
impl Default for MockEngineArgs { impl Default for MockEngineArgs {
...@@ -163,6 +169,7 @@ impl MockEngineArgs { ...@@ -163,6 +169,7 @@ impl MockEngineArgs {
"is_decode", "is_decode",
"planner_profile_data", "planner_profile_data",
"enable_local_indexer", "enable_local_indexer",
"bootstrap_port",
] ]
.iter() .iter()
.cloned() .cloned()
...@@ -250,6 +257,12 @@ impl MockEngineArgs { ...@@ -250,6 +257,12 @@ impl MockEngineArgs {
builder = builder.enable_local_indexer(enabled); builder = builder.enable_local_indexer(enabled);
} }
if let Some(value) = extra_args.get("bootstrap_port")
&& let Some(port) = value.as_u64()
{
builder = builder.bootstrap_port(Some(port as u16));
}
// Parse worker type from is_prefill and is_decode flags // Parse worker type from is_prefill and is_decode flags
let is_prefill = extra_args let is_prefill = extra_args
.get("is_prefill") .get("is_prefill")
......
...@@ -245,6 +245,7 @@ impl OpenAIPreprocessor { ...@@ -245,6 +245,7 @@ impl OpenAIPreprocessor {
prefill_worker_id: nvext.prefill_worker_id, prefill_worker_id: nvext.prefill_worker_id,
decode_worker_id: nvext.decode_worker_id, decode_worker_id: nvext.decode_worker_id,
dp_rank: None, // dp_rank is set later in the pipeline dp_rank: None, // dp_rank is set later in the pipeline
enable_local_updates: nvext.enable_local_updates,
}; };
builder.routing(Some(routing)); builder.routing(Some(routing));
} }
......
...@@ -34,6 +34,14 @@ pub struct RoutingHints { ...@@ -34,6 +34,14 @@ pub struct RoutingHints {
/// Data parallel rank for the request /// Data parallel rank for the request
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>, pub dp_rank: Option<u32>,
/// Controls whether the router should manage local bookkeeping (add_request,
/// mark_prefill_completed, free) for this request.
///
/// - `None` or `Some(true)`: Router handles bookkeeping locally (default behavior)
/// - `Some(false)`: External caller (e.g., GAIE sidecar) handles bookkeeping via C FFI
#[serde(default, skip_serializing_if = "Option::is_none")]
pub enable_local_updates: Option<bool>,
} }
#[derive(Serialize, Deserialize, Debug, Clone, Default)] #[derive(Serialize, Deserialize, Debug, Clone, Default)]
......
...@@ -6,9 +6,12 @@ ...@@ -6,9 +6,12 @@
//! This module provides [`RequestTracker`] for tracking timing and routing information //! This module provides [`RequestTracker`] for tracking timing and routing information
//! that can be returned to clients via the `nvext` response field. //! that can be returned to clients via the `nvext` response field.
use serde::{Deserialize, Serialize}; use std::sync::{Arc, OnceLock};
use std::sync::{Mutex, OnceLock};
use std::time::{Instant, SystemTime, UNIX_EPOCH}; use std::time::{Instant, SystemTime, UNIX_EPOCH};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use utoipa::ToSchema; use utoipa::ToSchema;
use crate::protocols::openai::nvext::WorkerIdInfo; use crate::protocols::openai::nvext::WorkerIdInfo;
...@@ -80,6 +83,12 @@ pub struct RequestTracker { ...@@ -80,6 +83,12 @@ pub struct RequestTracker {
/// Request phase (Prefill/Decode/Aggregated) /// Request phase (Prefill/Decode/Aggregated)
phase: Mutex<RequestPhase>, phase: Mutex<RequestPhase>,
/// Semaphore for coordinating phase transitions.
/// Acquiring a permit blocks subsequent set_phase calls until the permit is dropped.
/// This prevents race conditions in the bootstrap optimization path where prefill
/// runs in background and needs to complete record_worker before phase changes.
phase_semaphore: Arc<Semaphore>,
} }
impl RequestTracker { impl RequestTracker {
...@@ -102,6 +111,7 @@ impl RequestTracker { ...@@ -102,6 +111,7 @@ impl RequestTracker {
prefill_worker_id: OnceLock::new(), prefill_worker_id: OnceLock::new(),
decode_worker_id: OnceLock::new(), decode_worker_id: OnceLock::new(),
phase: Mutex::new(RequestPhase::Aggregated), phase: Mutex::new(RequestPhase::Aggregated),
phase_semaphore: Arc::new(Semaphore::new(1)),
} }
} }
...@@ -175,14 +185,29 @@ impl RequestTracker { ...@@ -175,14 +185,29 @@ impl RequestTracker {
self.decode_worker_id.set(id).is_ok() self.decode_worker_id.set(id).is_ok()
} }
/// Set the request phase. Can be called multiple times to update the phase. /// Set the request phase and return a permit that blocks subsequent phase changes.
pub fn set_phase(&self, phase: RequestPhase) { ///
*self.phase.lock().unwrap() = phase; /// The returned permit must be dropped to allow the next `set_phase` call to proceed.
/// Under normal operation, callers can simply ignore the returned permit (letting it
/// drop immediately). In the bootstrap optimization path, the permit is held and
/// passed to the spawned prefill task, which drops it after `record_worker` completes.
///
/// This prevents the race condition where the phase changes to Decode before the
/// background prefill task has recorded its worker ID.
pub async fn set_phase(&self, phase: RequestPhase) -> OwnedSemaphorePermit {
let permit = self
.phase_semaphore
.clone()
.acquire_owned()
.await
.expect("phase semaphore should never be closed");
*self.phase.lock() = phase;
permit
} }
/// Get the current request phase. /// Get the current request phase.
pub fn phase(&self) -> RequestPhase { pub fn phase(&self) -> RequestPhase {
*self.phase.lock().unwrap() *self.phase.lock()
} }
/// Record worker ID based on the current phase. /// Record worker ID based on the current phase.
......
...@@ -105,6 +105,17 @@ pub struct NvExt { ...@@ -105,6 +105,17 @@ pub struct NvExt {
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>, pub decode_worker_id: Option<u64>,
/// Controls whether the router should manage local bookkeeping (add_request,
/// mark_prefill_completed, free) for this request.
///
/// - `None` or `true`: Router handles bookkeeping locally (default behavior)
/// - `false`: External caller (e.g., GAIE sidecar) handles bookkeeping via C FFI
///
/// Set to `false` for GAIE Stage 2 when the EPP/sidecar manages request lifecycle.
#[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub enable_local_updates: Option<bool>,
} }
impl Default for NvExt { impl Default for NvExt {
...@@ -153,6 +164,7 @@ mod tests { ...@@ -153,6 +164,7 @@ mod tests {
assert_eq!(nv_ext.extra_fields, None); assert_eq!(nv_ext.extra_fields, None);
assert_eq!(nv_ext.prefill_worker_id, None); assert_eq!(nv_ext.prefill_worker_id, None);
assert_eq!(nv_ext.decode_worker_id, None); assert_eq!(nv_ext.decode_worker_id, None);
assert_eq!(nv_ext.enable_local_updates, None);
} }
// Test valid builder configurations // Test valid builder configurations
......
...@@ -230,7 +230,7 @@ impl EndpointConfigBuilder { ...@@ -230,7 +230,7 @@ impl EndpointConfigBuilder {
/// This function handles both health check and discovery transport building. /// This function handles both health check and discovery transport building.
/// All transport modes use consistent addressing: /// All transport modes use consistent addressing:
/// - HTTP: Uses full URL path including endpoint name (e.g., http://host:port/v1/rpc/endpoint_name) /// - HTTP: Uses full URL path including endpoint name (e.g., http://host:port/v1/rpc/endpoint_name)
/// - TCP: Includes endpoint name for routing (e.g., host:port/endpoint_name) /// - TCP: Includes instance_id and endpoint name for routing (e.g., host:port/instance_id_hex/endpoint_name)
/// - NATS: Uses subject-based addressing (unique per endpoint) /// - NATS: Uses subject-based addressing (unique per endpoint)
/// ///
/// # Errors /// # Errors
...@@ -266,9 +266,14 @@ fn build_transport_type_inner( ...@@ -266,9 +266,14 @@ fn build_transport_type_inner(
.and_then(|p| p.parse::<u16>().ok()) .and_then(|p| p.parse::<u16>().ok())
.unwrap_or(crate::pipeline::network::manager::get_actual_tcp_rpc_port()?); .unwrap_or(crate::pipeline::network::manager::get_actual_tcp_rpc_port()?);
// Include endpoint name for proper TCP routing // Include instance_id and endpoint name for proper TCP routing.
// TCP client parses this format and adds x-endpoint-path header for server-side routing // Format: host:port/instance_id_hex/endpoint_name
let tcp_endpoint = format!("{}:{}/{}", tcp_host, tcp_port, endpoint_id.name); // This ensures each worker has a unique routing key when multiple workers
// share the same TCP server (e.g., --num-workers > 1).
let tcp_endpoint = format!(
"{}:{}/{:x}/{}",
tcp_host, tcp_port, connection_id, endpoint_id.name
);
Ok(TransportType::Tcp(tcp_endpoint)) Ok(TransportType::Tcp(tcp_endpoint))
} }
......
...@@ -413,9 +413,11 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer { ...@@ -413,9 +413,11 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
component_name: String, component_name: String,
system_health: Arc<Mutex<SystemHealth>>, system_health: Arc<Mutex<SystemHealth>>,
) -> Result<()> { ) -> Result<()> {
// For TCP, we use endpoint_name as both the endpoint_path (routing key) and endpoint_name // Include instance_id in the routing key to avoid collisions when multiple workers
// share the same TCP server (e.g., --num-workers > 1 in tests)
let endpoint_path = format!("{instance_id:x}/{endpoint_name}");
self.register_endpoint( self.register_endpoint(
endpoint_name.clone(), endpoint_path,
service_handler, service_handler,
instance_id, instance_id,
namespace, namespace,
...@@ -427,7 +429,19 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer { ...@@ -427,7 +429,19 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
} }
async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> { async fn unregister_endpoint(&self, endpoint_name: &str) -> Result<()> {
self.unregister_endpoint(endpoint_name, endpoint_name).await; // With multiple workers per process, each registers with a unique key
// "{instance_id}/{endpoint_name}". Find and remove all matching entries.
let suffix = format!("/{endpoint_name}");
let keys_to_remove: Vec<String> = self
.handlers
.iter()
.filter(|entry| entry.key().ends_with(&suffix))
.map(|entry| entry.key().clone())
.collect();
for key in keys_to_remove {
self.unregister_endpoint(&key, endpoint_name).await;
}
Ok(()) Ok(())
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
//! directly accesses transport implementations or configuration. //! directly accesses transport implementations or configuration.
use super::egress::unified_client::RequestPlaneClient; use super::egress::unified_client::RequestPlaneClient;
use super::ingress::shared_tcp_endpoint::SharedTcpServer;
use super::ingress::unified_server::RequestPlaneServer; use super::ingress::unified_server::RequestPlaneServer;
use crate::distributed::RequestPlaneMode; use crate::distributed::RequestPlaneMode;
use anyhow::Result; use anyhow::Result;
...@@ -26,6 +27,17 @@ use tokio_util::sync::CancellationToken; ...@@ -26,6 +27,17 @@ use tokio_util::sync::CancellationToken;
/// Uses OnceLock since the port is set once when the server binds and never changes. /// Uses OnceLock since the port is set once when the server binds and never changes.
static ACTUAL_TCP_RPC_PORT: OnceLock<u16> = OnceLock::new(); static ACTUAL_TCP_RPC_PORT: OnceLock<u16> = OnceLock::new();
/// Global storage for the shared TCP server instance.
///
/// When multiple workers run in the same process, they must share a single TCP server
/// to ensure all endpoints are registered on the same server. Without this, each worker
/// would create its own server on a different port, but all would publish the same port
/// (from ACTUAL_TCP_RPC_PORT) to discovery, causing "No handler found" errors.
///
/// Uses `tokio::sync::OnceCell` to support async initialization (binding the TCP socket).
///
/// The cell is populated lazily by `create_tcp_server` through `get_or_try_init`, so the
/// bind address and cancellation token are taken from whichever `NetworkManager` runs the
/// initializer; subsequent callers receive a clone of the same `Arc`.
/// NOTE(review): this means a later manager's `tcp_host`/`tcp_port` settings and
/// cancellation token are not applied to the shared server -- confirm that is intended.
static GLOBAL_TCP_SERVER: tokio::sync::OnceCell<Arc<SharedTcpServer>> =
    tokio::sync::OnceCell::const_new();
/// Get the actual TCP RPC port that the server is listening on. /// Get the actual TCP RPC port that the server is listening on.
pub fn get_actual_tcp_rpc_port() -> anyhow::Result<u16> { pub fn get_actual_tcp_rpc_port() -> anyhow::Result<u16> {
ACTUAL_TCP_RPC_PORT.get().copied().ok_or_else(|| { ACTUAL_TCP_RPC_PORT.get().copied().ok_or_else(|| {
...@@ -300,35 +312,41 @@ impl NetworkManager { ...@@ -300,35 +312,41 @@ impl NetworkManager {
} }
async fn create_tcp_server(&self) -> Result<Arc<dyn RequestPlaneServer>> { async fn create_tcp_server(&self) -> Result<Arc<dyn RequestPlaneServer>> {
use super::ingress::shared_tcp_endpoint::SharedTcpServer; // Use the global TCP server to ensure all workers in the same process share
// a single server. This is critical for correct endpoint routing.
let server = GLOBAL_TCP_SERVER
.get_or_try_init(|| async {
// Use configured port if specified, otherwise use port 0 (OS assigns free port)
let port = self.config.tcp_port.unwrap_or(0);
let bind_addr = format!("{}:{}", self.config.tcp_host, port)
.parse()
.map_err(|e| anyhow::anyhow!("Invalid TCP bind address: {}", e))?;
// Use configured port if specified, otherwise use port 0 (OS assigns free port) tracing::info!(
let port = self.config.tcp_port.unwrap_or(0); bind_addr = %bind_addr,
let bind_addr = format!("{}:{}", self.config.tcp_host, port) port_source = if self.config.tcp_port.is_some() { "DYN_TCP_RPC_PORT" } else { "OS-assigned" },
.parse() "Creating TCP request plane server"
.map_err(|e| anyhow::anyhow!("Invalid TCP bind address: {}", e))?; );
tracing::info!( let server = SharedTcpServer::new(bind_addr, self.cancellation_token.clone());
bind_addr = %bind_addr,
port_source = if self.config.tcp_port.is_some() { "DYN_TCP_RPC_PORT" } else { "OS-assigned" },
"Creating TCP request plane server"
);
let server = SharedTcpServer::new(bind_addr, self.cancellation_token.clone()); // Bind and start server, getting the actual bound address
let actual_addr = server.clone().bind_and_start().await?;
// Bind and start server, getting the actual bound address // Store the actual bound port globally so build_transport_type() can access it
let actual_addr = server.clone().bind_and_start().await?; set_actual_tcp_rpc_port(actual_addr.port());
// Store the actual bound port globally so build_transport_type() can access it tracing::info!(
set_actual_tcp_rpc_port(actual_addr.port()); actual_addr = %actual_addr,
actual_port = actual_addr.port(),
"TCP request plane server started"
);
tracing::info!( Ok::<_, anyhow::Error>(server)
actual_addr = %actual_addr, })
actual_port = actual_addr.port(), .await?;
"TCP request plane server started"
);
Ok(server as Arc<dyn RequestPlaneServer>) Ok(server.clone() as Arc<dyn RequestPlaneServer>)
} }
async fn create_nats_server(&self) -> Result<Arc<dyn RequestPlaneServer>> { async fn create_nats_server(&self) -> Result<Arc<dyn RequestPlaneServer>> {
......
...@@ -6,6 +6,12 @@ ...@@ -6,6 +6,12 @@
# Combined pre_merge wall time (this file): # Combined pre_merge wall time (this file):
# - Serialized: 304.01s. # - Serialized: 304.01s.
# - Parallel (-n auto): 34.55s (269.46s saved, 8.80x). # - Parallel (-n auto): 34.55s (269.46s saved, 8.80x).
#
# NOTE: TCP request plane is NOT tested here. These tests use --num-workers > 1, which spawns
# multiple workers in a single process sharing one TCP server. The shared TCP server uses
# endpoint_path (e.g., "generate") as the routing key, causing handler collisions when multiple
# workers register the same endpoint. This is a test-only limitation; production deployments
# with separate processes per worker work correctly with TCP.
# NOTE(review): the shared TCP server appears to register handlers under
# "{instance_id:x}/{endpoint_name}" now, so this limitation may no longer hold -- confirm and
# re-enable TCP coverage if so.
import logging import logging
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
...@@ -155,6 +161,8 @@ def _build_mocker_command( ...@@ -155,6 +161,8 @@ def _build_mocker_command(
command.extend(["--data-parallel-size", str(mocker_args["dp_size"])]) command.extend(["--data-parallel-size", str(mocker_args["dp_size"])])
if mocker_args.get("enable_local_indexer"): if mocker_args.get("enable_local_indexer"):
command.append("--enable-local-indexer") command.append("--enable-local-indexer")
if "bootstrap_ports" in mocker_args:
command.extend(["--bootstrap-ports", mocker_args["bootstrap_ports"]])
return command return command
...@@ -233,6 +241,7 @@ class DisaggMockerProcess: ...@@ -233,6 +241,7 @@ class DisaggMockerProcess:
num_mockers: int = 1, num_mockers: int = 1,
store_backend: str = "etcd", store_backend: str = "etcd",
request_plane: str = "nats", request_plane: str = "nats",
enable_bootstrap: bool = False,
): ):
if worker_type not in ("prefill", "decode"): if worker_type not in ("prefill", "decode"):
raise ValueError( raise ValueError(
...@@ -242,6 +251,7 @@ class DisaggMockerProcess: ...@@ -242,6 +251,7 @@ class DisaggMockerProcess:
self.namespace = namespace self.namespace = namespace
self.worker_type = worker_type self.worker_type = worker_type
self.num_workers = num_mockers self.num_workers = num_mockers
self._bootstrap_ports: list[int] = []
# Set component name and endpoint based on worker type # Set component name and endpoint based on worker type
if worker_type == "prefill": if worker_type == "prefill":
...@@ -251,7 +261,17 @@ class DisaggMockerProcess: ...@@ -251,7 +261,17 @@ class DisaggMockerProcess:
self.component_name = "backend" self.component_name = "backend"
self.endpoint = f"dyn://{self.namespace}.backend.generate" self.endpoint = f"dyn://{self.namespace}.backend.generate"
mocker_args = mocker_args or {} mocker_args = (mocker_args or {}).copy()
# Allocate bootstrap ports for prefill workers if enabled (one per worker)
if enable_bootstrap and worker_type == "prefill":
self._bootstrap_ports = allocate_ports(num_mockers, BASE_PORT)
mocker_args["bootstrap_ports"] = ",".join(
str(p) for p in self._bootstrap_ports
)
logger.info(
f"Allocated bootstrap ports {self._bootstrap_ports} for {num_mockers} prefill workers"
)
command = _build_mocker_command( command = _build_mocker_command(
endpoint=self.endpoint, endpoint=self.endpoint,
...@@ -279,6 +299,11 @@ class DisaggMockerProcess: ...@@ -279,6 +299,11 @@ class DisaggMockerProcess:
f"endpoint: {self.endpoint}" f"endpoint: {self.endpoint}"
) )
@property
def bootstrap_ports(self) -> list[int]:
    """Return the allocated bootstrap ports, if any.

    The list is empty unless ports were allocated in ``__init__`` (bootstrap
    enabled on a prefill worker), and it is reset to empty by ``__exit__``
    after the ports are deallocated. The internal list object itself is
    returned, not a copy.
    """
    return self._bootstrap_ports
def __enter__(self): def __enter__(self):
logger.info( logger.info(
f"Starting {self.worker_type} mocker process with {self.num_workers} worker(s)" f"Starting {self.worker_type} mocker process with {self.num_workers} worker(s)"
...@@ -289,6 +314,11 @@ class DisaggMockerProcess: ...@@ -289,6 +314,11 @@ class DisaggMockerProcess:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
logger.info(f"Stopping {self.worker_type} mocker process") logger.info(f"Stopping {self.worker_type} mocker process")
self._process.__exit__(exc_type, exc_val, exc_tb) self._process.__exit__(exc_type, exc_val, exc_tb)
# Deallocate bootstrap ports if we allocated any
if self._bootstrap_ports:
deallocate_ports(self._bootstrap_ports)
logger.info(f"Deallocated bootstrap ports {self._bootstrap_ports}")
self._bootstrap_ports = []
@pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up @pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up
...@@ -487,9 +517,9 @@ def test_kv_push_router_bindings( ...@@ -487,9 +517,9 @@ def test_kv_push_router_bindings(
], ],
ids=[ ids=[
"jetstream", "jetstream",
"nats", "nats_core",
"file", "file",
], # "nats_core" commented out to match commented test case ],
) )
@pytest.mark.timeout(90) # TODO: figure out a timeout @pytest.mark.timeout(90) # TODO: figure out a timeout
def test_indexers_sync( def test_indexers_sync(
...@@ -677,37 +707,43 @@ def test_router_decisions( ...@@ -677,37 +707,43 @@ def test_router_decisions(
mockers.__exit__(None, None, None) mockers.__exit__(None, None, None)
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"]) @pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
@pytest.mark.parametrize(
"enable_disagg_bootstrap", [False, True], ids=["no_bootstrap", "with_bootstrap"]
)
@pytest.mark.timeout(59) # ~3x average (~19.51s), rounded up @pytest.mark.timeout(59) # ~3x average (~19.51s), rounded up
def test_router_decisions_disagg( def test_router_decisions_disagg(
request, request,
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
predownload_tokenizers, predownload_tokenizers,
registration_order, registration_order,
request_plane, enable_disagg_bootstrap,
): ):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup. """Validate KV cache prefix reuse in disaggregated prefill-decode setup.
Tests that progressive requests with overlapping prefixes are routed to the Tests that progressive requests with overlapping prefixes are routed to the
same prefill worker due to KV cache reuse. same prefill worker due to KV cache reuse.
Parameterized to test both registration orders: Parameterized to test:
- prefill_first: prefill workers register before decode workers - registration_order: prefill_first vs decode_first
- decode_first: decode workers register before prefill workers - enable_disagg_bootstrap: without vs with bootstrap rendezvous
""" """
# runtime_services_dynamic_ports handles NATS and etcd startup # runtime_services_dynamic_ports handles NATS and etcd startup
logger.info( logger.info(
f"Starting disaggregated router prefix reuse test " f"Starting disaggregated router prefix reuse test "
f"(registration_order={registration_order})" f"(registration_order={registration_order}, bootstrap={enable_disagg_bootstrap})"
) )
# Generate shared namespace for prefill and decode workers # Generate shared namespace for prefill and decode workers
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
shared_namespace = f"test-namespace-{namespace_suffix}" shared_namespace = f"test-namespace-{namespace_suffix}"
# Create mocker args # Create mocker args - use JetStream for KV events (more reliable than NATS Core)
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE} mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": False,
}
prefill_workers = None prefill_workers = None
decode_workers = None decode_workers = None
...@@ -722,7 +758,8 @@ def test_router_decisions_disagg( ...@@ -722,7 +758,8 @@ def test_router_decisions_disagg(
worker_type="prefill", worker_type="prefill",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane=request_plane, request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap,
) )
prefill_workers.__enter__() prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}") logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
...@@ -735,7 +772,7 @@ def test_router_decisions_disagg( ...@@ -735,7 +772,7 @@ def test_router_decisions_disagg(
worker_type="decode", worker_type="decode",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane=request_plane, request_plane="nats",
) )
decode_workers.__enter__() decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}") logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
...@@ -748,7 +785,7 @@ def test_router_decisions_disagg( ...@@ -748,7 +785,7 @@ def test_router_decisions_disagg(
worker_type="decode", worker_type="decode",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane=request_plane, request_plane="nats",
) )
decode_workers.__enter__() decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}") logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
...@@ -761,7 +798,8 @@ def test_router_decisions_disagg( ...@@ -761,7 +798,8 @@ def test_router_decisions_disagg(
worker_type="prefill", worker_type="prefill",
mocker_args=mocker_args, mocker_args=mocker_args,
num_mockers=4, num_mockers=4,
request_plane=request_plane, request_plane="nats",
enable_bootstrap=enable_disagg_bootstrap,
) )
prefill_workers.__enter__() prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}") logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
...@@ -779,7 +817,7 @@ def test_router_decisions_disagg( ...@@ -779,7 +817,7 @@ def test_router_decisions_disagg(
request=request, request=request,
frontend_port=frontend_port, frontend_port=frontend_port,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
request_plane=request_plane, request_plane="nats",
) )
finally: finally:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment