Unverified Commit 902eabd9 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: per dp rank gap detection (#5873)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 04f32fe2
......@@ -102,14 +102,16 @@ class ZmqKvEventPublisher:
def publish_stored(
self,
event_id: int,
token_ids: list[int],
num_block_tokens: list[int],
block_hashes: list[int],
lora_id: int = 0,
parent_hash: Optional[int] = None,
):
"""Publish a BlockStored event."""
"""Publish a BlockStored event.
Note: event_id is managed internally via self.sequence counter.
"""
# Convert block hashes to signed i64 format
block_hashes_signed = [_to_signed_i64(h) for h in block_hashes]
parent_hash_signed = (
......@@ -129,8 +131,11 @@ class ZmqKvEventPublisher:
self._publish_event(event)
def publish_removed(self, event_id: int, block_hashes: list[int]):
"""Publish a BlockRemoved event."""
def publish_removed(self, block_hashes: list[int]):
"""Publish a BlockRemoved event.
Note: event_id is managed internally via self.sequence counter.
"""
# Convert block hashes to signed i64 format (vLLM compatibility)
block_hashes_signed = [_to_signed_i64(h) for h in block_hashes]
......@@ -307,6 +312,8 @@ class Publisher:
self.partial_block_hashes: set[int] = set()
self.error_queue: Queue = Queue()
self._stop_event = threading.Event()
# Track the last engine event_id to assert consecutive event IDs from the engine
self._last_engine_event_id: Optional[int] = None
# Initialize ZMQ publisher if endpoint is provided (consolidator enabled)
if zmq_endpoint:
......@@ -476,6 +483,16 @@ class Publisher:
return
event_id = event["event_id"]
# Check for consecutive event IDs from the engine
if self._last_engine_event_id is not None:
expected_id = self._last_engine_event_id + 1
if event_id != expected_id:
logging.warning(
f"Non-consecutive engine event_id: expected {expected_id}, got {event_id}"
)
self._last_engine_event_id = event_id
data = event["data"]
if data["type"] == "stored":
self.processing_initial_created_events = False
......@@ -513,13 +530,13 @@ class Publisher:
lora_id = data.get("lora_id", 0)
logging.debug(
f"publish stored event: event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
f"publish stored event: engine_event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
)
# Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
# Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
if self.zmq_kv_event_publisher:
# Consolidator enabled: publish to ZMQ only
self.zmq_kv_event_publisher.publish_stored(
event_id,
token_ids,
num_block_tokens,
block_hashes,
......@@ -529,7 +546,6 @@ class Publisher:
elif self.kv_event_publisher:
# No consolidator: publish to NATS (router subscribes directly)
self.kv_event_publisher.publish_stored(
event_id,
token_ids,
num_block_tokens,
block_hashes,
......@@ -552,17 +568,16 @@ class Publisher:
removed_block_hashes.append(block_hash)
logging.debug(
f"publish removed event: event_id: {event_id}, block_hashes: {removed_block_hashes}"
f"publish removed event: engine_event_id: {event_id}, block_hashes: {removed_block_hashes}"
)
# Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
# Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
if self.zmq_kv_event_publisher:
# Consolidator enabled: publish to ZMQ only
self.zmq_kv_event_publisher.publish_removed(
event_id, removed_block_hashes
)
self.zmq_kv_event_publisher.publish_removed(removed_block_hashes)
elif self.kv_event_publisher:
# No consolidator: publish to NATS (router subscribes directly)
self.kv_event_publisher.publish_removed(event_id, removed_block_hashes)
self.kv_event_publisher.publish_removed(removed_block_hashes)
elif data["type"] == "created" and self.processing_initial_created_events:
self.update_max_window_size(event)
......
......@@ -279,6 +279,7 @@ def setup_kv_event_publisher(
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
enable_local_indexer=config.enable_local_indexer,
dp_rank=dp_rank,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
kv_publishers.append(kv_publisher)
......
......@@ -115,6 +115,8 @@ pub struct ZmqKvEventPublisherConfig {
#[pyo3(get, set)]
pub enable_local_indexer: bool, // whether the underlying KvEventPublisher publishes to
// both global and worker-local KvIndexers
#[pyo3(get, set)]
pub dp_rank: DpRank, // data parallel rank for this publisher
}
#[pymethods]
......@@ -125,7 +127,8 @@ impl ZmqKvEventPublisherConfig {
kv_block_size,
zmq_endpoint = "tcp://127.0.0.1:5557".to_string(),
zmq_topic = "".to_string(),
enable_local_indexer = false
enable_local_indexer = false,
dp_rank = 0
))]
pub fn new(
worker_id: WorkerId,
......@@ -133,6 +136,7 @@ impl ZmqKvEventPublisherConfig {
zmq_endpoint: String,
zmq_topic: String,
enable_local_indexer: bool,
dp_rank: DpRank,
) -> Self {
Self {
worker_id,
......@@ -140,6 +144,7 @@ impl ZmqKvEventPublisherConfig {
zmq_endpoint,
zmq_topic,
enable_local_indexer,
dp_rank,
}
}
}
......@@ -161,6 +166,7 @@ impl ZmqKvEventPublisher {
topic: config.zmq_topic,
}),
config.enable_local_indexer,
config.dp_rank,
)
.map_err(to_pyerr)?;
Ok(Self { inner })
......@@ -192,6 +198,8 @@ impl ZmqKvEventListener {
runtime.block_on(async {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<KvCacheEvent>();
let shutdown_token = tokio_util::sync::CancellationToken::new();
// Standalone listener needs its own event ID counter
let next_event_id = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
tokio::spawn(start_zmq_listener(
zmq_endpoint,
......@@ -199,6 +207,7 @@ impl ZmqKvEventListener {
tx,
shutdown_token.clone(),
kv_block_size as u32,
next_event_id,
));
Ok(Self {
......@@ -273,6 +282,7 @@ impl KvEventPublisher {
kv_block_size as u32,
None,
enable_local_indexer,
dp_rank,
)
.map_err(to_pyerr)?;
......@@ -285,11 +295,10 @@ impl KvEventPublisher {
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, block_mm_infos=None))]
#[pyo3(signature = (token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, block_mm_infos=None))]
fn publish_stored(
&mut self,
&self,
py: Python,
event_id: u64,
token_ids: Vec<u32>,
num_block_tokens: Vec<u64>,
block_hashes: Vec<i64>,
......@@ -302,6 +311,9 @@ impl KvEventPublisher {
let warning_count = self.warning_count.clone();
let inner = self.inner.clone();
// Use shared monotonic event_id counter from the inner publisher
let event_id = inner.next_event_id();
// Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>>
let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
.as_ref()
......@@ -338,10 +350,13 @@ impl KvEventPublisher {
})
}
fn publish_removed(&self, py: Python, event_id: u64, block_hashes: Vec<i64>) -> PyResult<()> {
fn publish_removed(&self, py: Python, block_hashes: Vec<i64>) -> PyResult<()> {
let dp_rank = self.dp_rank;
let inner = self.inner.clone();
// Use shared monotonic event_id counter from the inner publisher
let event_id = inner.next_event_id();
py.allow_threads(|| {
let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
.into_iter()
......
......@@ -753,7 +753,6 @@ class KvEventPublisher:
def publish_stored(
self,
event_id: int,
token_ids: List[int],
num_block_tokens: List[int],
block_hashes: List[int],
......@@ -763,8 +762,9 @@ class KvEventPublisher:
"""
Publish a KV stored event.
Event IDs are managed internally by the publisher using a monotonic counter.
Args:
event_id: The event ID
token_ids: List of token IDs
num_block_tokens: Number of tokens per block
block_hashes: List of block hashes (signed 64-bit integers)
......@@ -773,12 +773,13 @@ class KvEventPublisher:
"""
...
def publish_removed(self, event_id: int, block_hashes: List[int]) -> None:
def publish_removed(self, block_hashes: List[int]) -> None:
"""
Publish a KV removed event.
Event IDs are managed internally by the publisher using a monotonic counter.
Args:
event_id: The event ID
block_hashes: List of block hashes to remove (signed 64-bit integers)
"""
...
......@@ -790,7 +791,8 @@ class ZmqKvEventPublisherConfig:
kv_block_size: int,
zmq_endpoint: str = "tcp://127.0.0.1:5557",
zmq_topic: str = "",
enable_local_indexer: bool = False
enable_local_indexer: bool = False,
dp_rank: int = 0
) -> None:
"""
Configuration for the ZmqKvEventPublisher.
......@@ -800,6 +802,7 @@ class ZmqKvEventPublisherConfig:
:param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557".
:param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string.
:param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to False.
:param dp_rank: The data parallel rank for this publisher. Defaults to 0.
"""
...
......
......@@ -280,31 +280,30 @@ async def test_approx_kv_indexer(distributed_runtime):
class EventPublisher:
def __init__(self, component: Component, worker_id: int, kv_block_size: int):
self.publisher = KvEventPublisher(component, worker_id, kv_block_size)
self.event_id_counter = 0
# Counter for generating unique block hashes (event_id is now managed internally by publisher)
self.block_hash_counter = 0
self.block_hashes: List[int] = []
def store_event(self, tokens, lora_id):
parent_hash = self.event_id_counter if self.event_id_counter > 0 else None
# Parent hash should reference the last published block, not the current one
parent_hash = self.block_hashes[-1] if self.block_hashes else None
self.publisher.publish_stored(
self.event_id_counter, # event_id
tokens, # token_ids
[
len(tokens),
], # num_block_tokens
[
self.event_id_counter,
self.block_hash_counter,
], # block_hashes
lora_id, # lora_id
parent_hash, # parent_hash
)
self.block_hashes.append(self.event_id_counter)
self.event_id_counter += 1
self.block_hashes.append(self.block_hash_counter)
self.block_hash_counter += 1
def remove_event(self):
self.publisher.publish_removed(
self.event_id_counter, # event_id
[
self.block_hashes[-1],
], # block_hashes
)
self.event_id_counter += 1
......@@ -46,7 +46,7 @@ use crate::{
approx::PruneConfig,
indexer::{KvIndexer, KvIndexerInterface, KvRouterError},
protocols::{
LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
compute_block_hash_for_seq, compute_seq_hash_for_block,
},
......@@ -80,9 +80,14 @@ pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";
// for worker-local kvindexer query
pub const WORKER_KV_INDEXER_QUERY_ENDPOINT: &str = "worker_kv_indexer_query";
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer
/// Generates a dp_rank-specific endpoint name for the worker KV indexer query service.
/// Each dp_rank has its own LocalKvIndexer and query endpoint to ensure per-dp_rank monotonicity.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
format!("worker_kv_indexer_query_dp{dp_rank}")
}
// for router discovery registration
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
pub const KV_ROUTER_ENDPOINT: &str = "generate";
......@@ -627,6 +632,7 @@ impl KvRouter {
pub async fn query_worker_local_kv(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> {
......@@ -636,11 +642,11 @@ impl KvRouter {
.ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;
query_client
.query_worker(worker_id, start_event_id, end_event_id)
.query_worker(worker_id, dp_rank, start_event_id, end_event_id)
.await
}
/// Recover missed KV events from a specific worker.
/// Recover missed KV events from a specific worker's dp_rank.
///
/// Queries the worker's local KV indexer for events starting from
/// `start_event_id` and applies them to the router's indexer.
......@@ -648,11 +654,13 @@ impl KvRouter {
/// # Arguments
///
/// * `worker_id` - The worker to recover from
/// * `dp_rank` - The data parallel rank to recover from
/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
pub async fn recover_from_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<usize> {
......@@ -668,13 +676,8 @@ impl KvRouter {
}
};
subscriber::recover_from_worker(
query_client,
worker_id,
start_event_id,
end_event_id,
&event_tx,
)
query_client
.recover_from_worker(worker_id, dp_rank, start_event_id, end_event_id, &event_tx)
.await
}
}
......
......@@ -3,7 +3,7 @@
use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
use anyhow::Result;
......@@ -77,6 +77,7 @@ impl KvEventSource {
source_config: KvEventSourceConfig,
cancellation_token: CancellationToken,
tx: mpsc::UnboundedSender<KvCacheEvent>,
next_event_id: Arc<AtomicU64>,
) -> Result<Self> {
match source_config {
KvEventSourceConfig::Zmq { endpoint, topic } => {
......@@ -90,6 +91,7 @@ impl KvEventSource {
tx,
cancellation_token.clone(),
kv_block_size,
next_event_id,
));
Ok(KvEventSource::Zmq { zmq_handle })
......@@ -117,6 +119,9 @@ pub struct KvEventPublisher {
cancellation_token: CancellationToken,
/// The channel to send events to.
tx: mpsc::UnboundedSender<KvCacheEvent>,
/// Internal monotonic event ID counter - ensures each event gets a unique, incrementing ID.
/// Shared with the ZMQ listener (if any) to maintain consistency.
next_event_id: Arc<AtomicU64>,
}
impl KvEventPublisher {
......@@ -125,7 +130,7 @@ impl KvEventPublisher {
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
) -> Result<Self> {
Self::new_with_local_indexer(component, kv_block_size, source_config, false)
Self::new_with_local_indexer(component, kv_block_size, source_config, false, 0)
}
pub fn new_with_local_indexer(
......@@ -133,6 +138,7 @@ impl KvEventPublisher {
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
enable_local_indexer: bool,
dp_rank: DpRank,
) -> Result<Self> {
let cancellation_token = CancellationToken::new();
......@@ -152,6 +158,9 @@ impl KvEventPublisher {
);
}
// Internal monotonic event ID counter - shared with ZMQ listener if any
let next_event_id = Arc::new(AtomicU64::new(0));
// Create our event source (if any)
let mut source = None;
if let Some(config) = source_config {
......@@ -161,6 +170,7 @@ impl KvEventPublisher {
config,
cancellation_token.clone(),
tx.clone(),
next_event_id.clone(),
)?);
}
......@@ -189,6 +199,7 @@ impl KvEventPublisher {
.spawn(start_worker_kv_query_endpoint(
component,
worker_id,
dp_rank,
local_indexer,
))
});
......@@ -253,6 +264,7 @@ impl KvEventPublisher {
source,
cancellation_token,
tx,
next_event_id,
})
}
......@@ -260,6 +272,12 @@ impl KvEventPublisher {
self.tx.send(event)
}
/// Get and increment the next event ID atomically.
/// Use this to assign monotonically increasing event IDs to events before publishing.
pub fn next_event_id(&self) -> u64 {
self.next_event_id.fetch_add(1, Ordering::SeqCst)
}
pub fn kv_block_size(&self) -> u32 {
self.kv_block_size
}
......@@ -406,6 +424,7 @@ pub async fn start_zmq_listener(
tx: mpsc::UnboundedSender<KvCacheEvent>,
cancellation_token: CancellationToken,
kv_block_size: u32,
next_event_id: Arc<AtomicU64>,
) {
tracing::debug!(
"KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')",
......@@ -496,7 +515,9 @@ pub async fn start_zmq_listener(
continue;
}
let seq = u64::from_be_bytes(seq_bytes.try_into().unwrap());
// Note: We extract the engine's sequence number for logging but use our own
// internal monotonic counter for event_id to ensure per-dp_rank monotonicity
let engine_seq = u64::from_be_bytes(seq_bytes.try_into().unwrap());
// Decode our batch of events.
let batch_result = rmps::from_slice::<KvEventBatch>(&payload);
......@@ -507,16 +528,19 @@ pub async fn start_zmq_listener(
};
tracing::trace!(
"ZMQ listener on {} received batch with {} events (seq={}, dp_rank={})",
"ZMQ listener on {} received batch with {} events (engine_seq={}, dp_rank={})",
zmq_endpoint,
batch.events.len(),
seq,
engine_seq,
batch.data_parallel_rank.unwrap_or(0)
);
let dp_rank = batch.data_parallel_rank.unwrap_or(0) as u32;
for raw_event in batch.events.into_iter() {
let event = convert_event(raw_event, seq, kv_block_size, dp_rank, &warning_count);
// Use shared monotonic event_id counter instead of engine's sequence number
let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
let event = convert_event(raw_event, event_id, kv_block_size, dp_rank, &warning_count);
if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped");
exit_reason = "channel receiver dropped";
......@@ -1558,11 +1582,13 @@ mod tests_startup_helpers {
// Cancellation token so we can stop the listener
let token = dynamo_runtime::CancellationToken::new();
// Event ID counter for the test listener
let next_event_id = Arc::new(AtomicU64::new(0));
// Spawn async listener (connects to publisher bound above)
let listener_handle = tokio::spawn({
let token = token.clone();
start_zmq_listener(endpoint.to_string(), topic, tx, token, 4)
start_zmq_listener(endpoint.to_string(), topic, tx, token, 4, next_event_id)
});
// Give time for the connection to establish
......
......@@ -19,8 +19,8 @@ use tokio_util::sync::CancellationToken;
use crate::kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
indexer::{DumpRequest, GetWorkersRequest, WorkerKvQueryResponse},
protocols::{RouterEvent, WorkerId},
indexer::{DumpRequest, GetWorkersRequest},
protocols::{DpRank, RouterEvent, WorkerId},
router_discovery_query,
worker_query::WorkerQueryClient,
};
......@@ -47,10 +47,6 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1);
const CHECK_INTERVAL_JITTER_MS: i64 = 100;
// Worker query retry configuration
const WORKER_QUERY_MAX_RETRIES: u32 = 8;
const WORKER_QUERY_INITIAL_BACKOFF_MS: u64 = 200;
// ============================================================================
// Discovery Helpers
// ============================================================================
......@@ -79,205 +75,6 @@ async fn get_instance_discovery_stream(
Ok(Box::pin(stream))
}
// ============================================================================
// Local KvIndexer-based Recovery
// ============================================================================
/// Recover missed events from all workers with local indexers.
///
/// This function should be called on router startup to catch up on any events
/// that were missed while the router was offline.
///
/// # Arguments
///
/// * `worker_query_client` - Client for querying worker local indexers
/// * `last_received_event_ids` - Map of worker ID to last received event ID
/// * `worker_ids` - List of worker IDs to recover from
/// * `event_tx` - Channel to send recovered events to the indexer
///
/// # Returns
///
/// Total number of events recovered across all workers
pub async fn recover_from_all_workers(
worker_query_client: &WorkerQueryClient,
last_received_event_ids: &HashMap<WorkerId, u64>,
worker_ids: &Vec<WorkerId>,
event_tx: &mpsc::Sender<RouterEvent>,
) -> usize {
let mut total_recovered = 0;
let mut successful_workers = 0;
let mut failed_workers = 0;
for &worker_id in worker_ids {
// Skip workers without local indexer
if !worker_query_client.has_local_indexer(worker_id) {
tracing::debug!(
"Skipping recovery - worker {worker_id} does not have local indexer enabled"
);
continue;
}
// If we haven't seen any events from this worker, start from beginning (None)
// If we've seen events, start from last_known_id + 1
let start_event_id = last_received_event_ids
.get(&worker_id)
.map(|&last_id| last_id + 1);
match recover_from_worker(
worker_query_client,
worker_id,
start_event_id,
None, // Get all events after start_event_id
event_tx,
)
.await
{
Ok(count) => {
total_recovered += count;
if count > 0 {
successful_workers += 1;
}
}
Err(_) => {
failed_workers += 1;
}
}
}
// Log summary
if total_recovered > 0 || failed_workers > 0 {
tracing::info!(
"Startup recovery completed: {total_recovered} events recovered from {successful_workers} workers, {failed_workers} workers failed"
);
}
total_recovered
}
/// Recover missed KV events from a specific worker.
///
/// # Arguments
///
/// * `worker_query_client` - Client for querying worker local indexers
/// * `worker_id` - The worker to recover from
/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
/// * `event_tx` - Channel to send recovered events to the indexer
///
/// # Returns
///
/// Number of events recovered, or error if recovery failed
pub async fn recover_from_worker(
worker_query_client: &WorkerQueryClient,
worker_id: WorkerId,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
event_tx: &mpsc::Sender<RouterEvent>,
) -> Result<usize> {
if worker_query_client.has_local_indexer(worker_id) {
tracing::debug!(
"Attempting recovery from worker {worker_id}, start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
);
} else {
tracing::warn!("Worker {worker_id} does not have local indexer enabled, skipping recovery");
return Ok(0);
}
// Query worker for events in range, with retry logic for transient failures
// (e.g., worker's query service not yet re-subscribed after NATS restart)
let mut response = None;
let mut last_error = None;
for attempt in 0..WORKER_QUERY_MAX_RETRIES {
match worker_query_client
.query_worker(worker_id, start_event_id, end_event_id)
.await
{
Ok(resp) => {
if attempt > 0 {
tracing::info!("Worker {worker_id} query succeeded after retry {attempt}");
}
response = Some(resp);
break;
}
Err(e) => {
last_error = Some(e);
if attempt < WORKER_QUERY_MAX_RETRIES - 1 {
let backoff_ms = WORKER_QUERY_INITIAL_BACKOFF_MS * 2_u64.pow(attempt);
tracing::warn!(
"Worker {worker_id} query failed on attempt {attempt}, retrying after {backoff_ms}ms"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
}
}
}
}
let response = match response {
Some(r) => r,
None => return Err(last_error.unwrap_or_else(|| anyhow::anyhow!("No response"))),
};
// Handle response variants
let events = match response {
WorkerKvQueryResponse::Events(events) => {
tracing::debug!(
"Got {count} buffered events from worker {worker_id}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TreeDump(events) => {
tracing::info!(
"Got tree dump from worker {worker_id} (range too old or unspecified), count: {count}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TooNew {
requested_start,
requested_end,
newest_available,
} => {
tracing::warn!(
"Worker {worker_id} requested range is newer than available data: requested_start: {requested_start:?}, requested_end: {requested_end:?}, newest_available: {newest_available}"
);
return Ok(0);
}
WorkerKvQueryResponse::InvalidRange { start_id, end_id } => {
anyhow::bail!("Invalid range: end_id ({end_id}) < start_id ({start_id})");
}
WorkerKvQueryResponse::Error(message) => {
anyhow::bail!("Worker {worker_id} query failed: {message}");
}
};
let events_count = events.len();
if events_count == 0 {
tracing::debug!(
"No events to recover from worker {worker_id}, start_event_id: {start_event_id:?}"
);
return Ok(0);
}
tracing::info!(
"Recovered {events_count} events from worker {worker_id}, start_event_id: {start_event_id:?}"
);
// Apply recovered events to the indexer
for event in events {
if let Err(e) = event_tx.send(event).await {
tracing::error!(
"Failed to send recovered event to indexer for worker {worker_id}: {e}"
);
anyhow::bail!("Failed to send recovered event: {e}");
}
}
Ok(events_count)
}
// ============================================================================
// Snapshot Management
// ============================================================================
......@@ -712,27 +509,16 @@ async fn handle_worker_discovery(
"DISCOVERY: Worker {worker_id} added, dumping local indexer into router"
);
match recover_from_worker(
worker_query_client,
worker_id,
None, // Start from beginning
None, // Get all events
kv_events_tx,
)
.await
{
Ok(count) => {
let total_recovered = worker_query_client
.recover_all_dp_ranks(worker_id, kv_events_tx)
.await;
if total_recovered > 0 {
tracing::info!(
"Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
);
}
Err(e) => {
tracing::warn!(
"Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
"DISCOVERY: Worker {worker_id} total recovered {total_recovered} events"
);
}
}
}
DiscoveryEvent::Removed(id) => {
let worker_id = id.instance_id();
tracing::warn!("DISCOVERY: Worker {worker_id} removed, removing from router indexer");
......@@ -801,19 +587,12 @@ pub async fn start_kv_router_background_event_plane(
ready_workers.len()
);
// Recover initial state from all ready workers
// Recover initial state from all ready workers (all dp_ranks)
for worker_id in &ready_workers {
if worker_query_client.has_local_indexer(*worker_id) {
match recover_from_worker(&worker_query_client, *worker_id, None, None, &kv_events_tx)
.await
{
Ok(count) => {
tracing::info!("Successfully recovered {count} events from worker {worker_id}");
}
Err(e) => {
tracing::warn!("Failed to recover from worker {worker_id}: {e}");
}
}
worker_query_client
.recover_all_dp_ranks(*worker_id, &kv_events_tx)
.await;
}
}
......@@ -822,8 +601,9 @@ pub async fn start_kv_router_background_event_plane(
get_instance_discovery_stream(&component, &cancellation_token).await?;
tokio::spawn(async move {
// Track last received event ID per worker for gap detection
let mut last_event_ids: HashMap<WorkerId, u64> = HashMap::new();
// Track last received event ID per (worker, dp_rank) for gap detection
// Each dp_rank has its own monotonic event ID sequence
let mut last_event_ids: HashMap<(WorkerId, DpRank), u64> = HashMap::new();
loop {
tokio::select! {
......@@ -860,7 +640,9 @@ pub async fn start_kv_router_background_event_plane(
};
let worker_id = event.worker_id;
let dp_rank = event.event.dp_rank;
let event_id = event.event.event_id;
let event_key = (worker_id, dp_rank);
// Use envelope metadata for additional debugging
tracing::trace!(
......@@ -869,9 +651,9 @@ pub async fn start_kv_router_background_event_plane(
envelope.sequence
);
// Gap detection: check if event ID is monotonically increasing per worker
// Gap detection: check if event ID is monotonically increasing per (worker, dp_rank)
// Note: event_id <= last_id is duplicate/out-of-order, apply anyway (idempotent)
if let Some(&last_id) = last_event_ids.get(&worker_id)
if let Some(&last_id) = last_event_ids.get(&event_key)
&& event_id > last_id + 1
{
// Gap detected - recover missing events before processing current
......@@ -879,32 +661,29 @@ pub async fn start_kv_router_background_event_plane(
let gap_end = event_id - 1;
let gap_size = gap_end - gap_start + 1;
tracing::warn!(
"Event ID gap detected for worker {worker_id}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
"Event ID gap detected for worker {worker_id} dp_rank {dp_rank}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
);
// Note: While recovering, new events may queue in the subscriber's
// internal buffer. We don't explicitly buffer them here for simplicity.
// The subscriber will process them in order after recovery completes.
if let Err(e) = recover_from_worker(
&worker_query_client,
worker_id,
Some(gap_start),
Some(gap_end),
&kv_events_tx,
).await {
if let Err(e) = worker_query_client
.recover_from_worker(worker_id, dp_rank, Some(gap_start), Some(gap_end), &kv_events_tx)
.await
{
tracing::error!(
"Failed to recover gap events for worker {worker_id} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}"
"Failed to recover gap events for worker {worker_id} dp_rank {dp_rank} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}"
);
// Note: If recovery fails, we still apply the current event.
// The tree will have a gap, but it's better than dropping the event.
}
}
// First event from this worker is always valid - we accept whatever ID it has.
// First event from this (worker, dp_rank) is always valid - we accept whatever ID it has.
// This handles initial startup and worker restarts without requiring event 0.
// Update last seen event ID (use max to handle out-of-order)
last_event_ids
.entry(worker_id)
.entry(event_key)
.and_modify(|id| *id = (*id).max(event_id))
.or_insert(event_id);
......
......@@ -2,33 +2,43 @@
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use dashmap::DashMap;
use dynamo_runtime::component::Component;
use dynamo_runtime::pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, PushRouter, ResponseStream, RouterMode,
SingleIn, async_trait, network::Ingress,
};
use dynamo_runtime::protocols::maybe_error::MaybeError;
use tokio::sync::OnceCell;
use dynamo_runtime::stream;
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use crate::discovery::RuntimeConfigsSubscriber;
use crate::kv_router::WORKER_KV_INDEXER_QUERY_ENDPOINT;
use crate::kv_router::indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse};
use crate::kv_router::protocols::WorkerId;
use dynamo_runtime::stream;
use crate::kv_router::protocols::{DpRank, RouterEvent, WorkerId};
use crate::kv_router::worker_kv_indexer_query_endpoint;
// Recovery retry configuration
const RECOVERY_MAX_RETRIES: u32 = 8;
const RECOVERY_INITIAL_BACKOFF_MS: u64 = 200;
/// Router-side client for querying worker local KV indexers
///
/// Performs request/reply communication with workers via request plane endpoint routing.
/// (Only queries workers that have `enable_local_indexer=true` in their MDC user_data)
/// The client is spawned by KvRouter; it uses a subscriber from RuntimeConfigs.
///
/// Each dp_rank has its own LocalKvIndexer and query endpoint, so we maintain separate
/// routers per dp_rank to ensure queries go to the correct endpoint.
pub struct WorkerQueryClient {
component: Component,
/// Subscriber for runtime configs (includes shared configs DashMap)
subscriber: RuntimeConfigsSubscriber,
router: OnceCell<Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>,
/// Routers keyed by dp_rank - each dp_rank has its own endpoint
routers: DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>,
}
impl WorkerQueryClient {
......@@ -37,7 +47,7 @@ impl WorkerQueryClient {
Self {
component,
subscriber,
router: OnceCell::new(),
routers: DashMap::new(),
}
}
......@@ -56,11 +66,46 @@ impl WorkerQueryClient {
.unwrap_or(false)
}
/// Query a specific worker's local KV indexer and return its buffered events.
/// Get the data_parallel_size for a worker (defaults to 1 if not found)
pub fn get_data_parallel_size(&self, worker_id: WorkerId) -> u32 {
self.subscriber
.configs
.get(&worker_id)
.and_then(|entry| entry.value().as_ref().map(|c| c.data_parallel_size))
.unwrap_or(1)
}
/// Get or create a router for the specified dp_rank's endpoint
async fn get_router_for_dp_rank(
&self,
dp_rank: DpRank,
) -> Result<Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>> {
// Fast path: check if router already exists
if let Some(router) = self.routers.get(&dp_rank) {
return Ok(router.clone());
}
// Slow path: create new router
let endpoint_name = worker_kv_indexer_query_endpoint(dp_rank);
let endpoint = self.component.endpoint(&endpoint_name);
let client = endpoint.client().await?;
let router = Arc::new(PushRouter::from_client(client, RouterMode::RoundRobin).await?);
// Insert and return (if another thread inserted first, use theirs)
Ok(self
.routers
.entry(dp_rank)
.or_insert(router)
.value()
.clone())
}
/// Query a specific worker's local KV indexer for a specific dp_rank and return its buffered events.
/// Returns an error if the worker does not have enable_local_indexer=true.
pub async fn query_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> {
......@@ -71,15 +116,7 @@ impl WorkerQueryClient {
);
}
let router = self
.router
.get_or_try_init(|| async {
let endpoint = self.component.endpoint(WORKER_KV_INDEXER_QUERY_ENDPOINT);
let client = endpoint.client().await?;
let router = PushRouter::from_client(client, RouterMode::RoundRobin).await?;
Ok::<_, anyhow::Error>(Arc::new(router))
})
.await?;
let router = self.get_router_for_dp_rank(dp_rank).await?;
let request = WorkerKvQueryRequest {
worker_id,
......@@ -90,7 +127,7 @@ impl WorkerQueryClient {
.direct(SingleIn::new(request), worker_id)
.await
.with_context(|| {
format!("Failed to send worker KV query request to worker {worker_id} via endpoint")
format!("Failed to send worker KV query request to worker {worker_id} dp_rank {dp_rank} via endpoint")
})?;
let response = stream
......@@ -104,12 +141,170 @@ impl WorkerQueryClient {
Ok(response)
}
/// Recover events from all dp_ranks of a single worker.
///
/// # Returns
/// Total number of events recovered across all dp_ranks
pub async fn recover_all_dp_ranks(
&self,
worker_id: WorkerId,
event_tx: &mpsc::Sender<RouterEvent>,
) -> usize {
let dp_size = self.get_data_parallel_size(worker_id);
let mut total_recovered = 0;
for dp_rank in 0..dp_size {
match self
.recover_from_worker(worker_id, dp_rank, None, None, event_tx)
.await
{
Ok(count) => {
total_recovered += count;
if count > 0 {
tracing::info!(
"Recovered {count} events from worker {worker_id} dp_rank {dp_rank}"
);
}
}
Err(e) => {
tracing::warn!(
"Failed to recover from worker {worker_id} dp_rank {dp_rank}: {e}"
);
}
}
}
total_recovered
}
/// Recover missed KV events from a specific worker's dp_rank with retry logic.
///
/// # Returns
/// Number of events recovered, or error if recovery failed after all retries
pub async fn recover_from_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
event_tx: &mpsc::Sender<RouterEvent>,
) -> Result<usize> {
if !self.has_local_indexer(worker_id) {
tracing::debug!(
"Worker {worker_id} does not have local indexer enabled, skipping recovery"
);
return Ok(0);
}
tracing::debug!(
"Attempting recovery from worker {worker_id} dp_rank {dp_rank}, \
start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
);
// Query worker with retry logic for transient failures
let mut response = None;
let mut last_error = None;
for attempt in 0..RECOVERY_MAX_RETRIES {
match self
.query_worker(worker_id, dp_rank, start_event_id, end_event_id)
.await
{
Ok(resp) => {
if attempt > 0 {
tracing::info!(
"Worker {worker_id} dp_rank {dp_rank} query succeeded after retry {attempt}"
);
}
response = Some(resp);
break;
}
Err(e) => {
last_error = Some(e);
if attempt < RECOVERY_MAX_RETRIES - 1 {
let backoff_ms = RECOVERY_INITIAL_BACKOFF_MS * 2_u64.pow(attempt);
tracing::warn!(
"Worker {worker_id} dp_rank {dp_rank} query failed on attempt {attempt}, \
retrying after {backoff_ms}ms"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
}
}
}
}
let response = match response {
Some(r) => r,
None => return Err(last_error.unwrap_or_else(|| anyhow::anyhow!("No response"))),
};
// Handle response variants
let events = match response {
WorkerKvQueryResponse::Events(events) => {
tracing::debug!(
"Got {count} buffered events from worker {worker_id} dp_rank {dp_rank}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TreeDump(events) => {
tracing::info!(
"Got tree dump from worker {worker_id} dp_rank {dp_rank} \
(range too old or unspecified), count: {count}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TooNew {
requested_start,
requested_end,
newest_available,
} => {
tracing::warn!(
"Requested range [{requested_start:?}, {requested_end:?}] is newer than \
available (newest: {newest_available}) for worker {worker_id} dp_rank {dp_rank}"
);
return Ok(0);
}
WorkerKvQueryResponse::InvalidRange { start_id, end_id } => {
anyhow::bail!(
"Invalid range for worker {worker_id} dp_rank {dp_rank}: \
end_id ({end_id}) < start_id ({start_id})"
);
}
WorkerKvQueryResponse::Error(msg) => {
anyhow::bail!("Worker {worker_id} dp_rank {dp_rank} query error: {msg}");
}
};
// Send recovered events to the indexer
let count = events.len();
if count == 0 {
tracing::debug!("No events to recover from worker {worker_id} dp_rank {dp_rank}");
return Ok(0);
}
tracing::info!("Recovered {count} events from worker {worker_id} dp_rank {dp_rank}");
for event in events {
if let Err(e) = event_tx.send(event).await {
tracing::error!(
"Failed to send recovered event to indexer for worker {worker_id} dp_rank {dp_rank}: {e}"
);
anyhow::bail!("Failed to send recovered event: {e}");
}
}
Ok(count)
}
}
// Worker-side endpoint registration for Router -> LocalKvIndexer query service
pub(crate) async fn start_worker_kv_query_endpoint(
component: Component,
worker_id: u64,
dp_rank: DpRank,
local_indexer: Arc<LocalKvIndexer>,
) {
let engine = Arc::new(WorkerKvQueryEngine {
......@@ -121,26 +316,28 @@ pub(crate) async fn start_worker_kv_query_endpoint(
Ok(ingress) => ingress,
Err(e) => {
tracing::error!(
"Failed to build WorkerKvQuery endpoint handler for worker {worker_id}: {e}"
"Failed to build WorkerKvQuery endpoint handler for worker {worker_id} dp_rank {dp_rank}: {e}"
);
return;
}
};
let endpoint_name = worker_kv_indexer_query_endpoint(dp_rank);
tracing::info!(
"WorkerKvQuery endpoint starting for worker {worker_id} on endpoint '{}'",
WORKER_KV_INDEXER_QUERY_ENDPOINT
"WorkerKvQuery endpoint starting for worker {worker_id} dp_rank {dp_rank} on endpoint '{endpoint_name}'"
);
if let Err(e) = component
.endpoint(WORKER_KV_INDEXER_QUERY_ENDPOINT)
.endpoint(&endpoint_name)
.endpoint_builder()
.handler(ingress)
.graceful_shutdown(true)
.start()
.await
{
tracing::error!("WorkerKvQuery endpoint failed for worker {worker_id}: {e}");
tracing::error!(
"WorkerKvQuery endpoint failed for worker {worker_id} dp_rank {dp_rank}: {e}"
);
}
}
......
......@@ -91,7 +91,7 @@ impl KvManager {
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}, enable_local_indexer={enable_local_indexer}"
);
Arc::new(
KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer)
KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer, dp_rank)
.expect("Failed to create KV event publisher"),
)
});
......
......@@ -1940,23 +1940,7 @@ def _test_router_decisions(
# Use async to manage the test flow
async def test_sync():
# Calculate expected number of instances
# With data parallelism:
# - vLLM/SGLang: each DP rank registers as a separate instance
# - Mockers: all DP ranks share the same worker instance ID (instance_ids returns worker IDs)
if test_dp_rank:
if (
hasattr(engine_workers, "data_parallel_size")
and engine_workers.data_parallel_size is not None
):
# vLLM/SGLang: each DP rank registers as a separate instance
expected_num_instances = (
engine_workers.num_workers * engine_workers.data_parallel_size
)
else:
# Mockers with dp_size or no DP: instance_ids() returns worker IDs
expected_num_instances = engine_workers.num_workers
else:
# Workers register one instance per process (not per dp_rank)
expected_num_instances = engine_workers.num_workers
# Wait for workers to be ready and get their instance IDs
......
......@@ -187,6 +187,8 @@ class MockerProcess:
mocker_args = mocker_args or {}
# Store dp_size for DP-aware test functions
self.dp_size = mocker_args.get("dp_size")
# Alias for consistency with vLLM/SGLang workers
self.data_parallel_size = self.dp_size
command = _build_mocker_command(
endpoint=self.endpoint,
......@@ -586,15 +588,17 @@ def test_indexers_sync(
nats_process, _etcd_process = runtime_services_dynamic_ports
# Create mocker args dictionary
# Use 2 DP ranks to test per-dp_rank event ID tracking and recovery
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core,
"dp_size": 2,
}
try:
# Start mocker instances
logger.info(f"Starting {NUM_MOCKERS} mocker instances")
# Start mocker instances (2 workers x 2 DP ranks = 4 independent event streams)
logger.info(f"Starting {NUM_MOCKERS} mocker instances with dp_size=2")
mockers = MockerProcess(
request,
mocker_args=mocker_args,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment