Unverified Commit 902eabd9 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: per dp rank gap detection (#5873)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 04f32fe2
...@@ -102,14 +102,16 @@ class ZmqKvEventPublisher: ...@@ -102,14 +102,16 @@ class ZmqKvEventPublisher:
def publish_stored( def publish_stored(
self, self,
event_id: int,
token_ids: list[int], token_ids: list[int],
num_block_tokens: list[int], num_block_tokens: list[int],
block_hashes: list[int], block_hashes: list[int],
lora_id: int = 0, lora_id: int = 0,
parent_hash: Optional[int] = None, parent_hash: Optional[int] = None,
): ):
"""Publish a BlockStored event.""" """Publish a BlockStored event.
Note: event_id is managed internally via self.sequence counter.
"""
# Convert block hashes to signed i64 format # Convert block hashes to signed i64 format
block_hashes_signed = [_to_signed_i64(h) for h in block_hashes] block_hashes_signed = [_to_signed_i64(h) for h in block_hashes]
parent_hash_signed = ( parent_hash_signed = (
...@@ -129,8 +131,11 @@ class ZmqKvEventPublisher: ...@@ -129,8 +131,11 @@ class ZmqKvEventPublisher:
self._publish_event(event) self._publish_event(event)
def publish_removed(self, event_id: int, block_hashes: list[int]): def publish_removed(self, block_hashes: list[int]):
"""Publish a BlockRemoved event.""" """Publish a BlockRemoved event.
Note: event_id is managed internally via self.sequence counter.
"""
# Convert block hashes to signed i64 format (vLLM compatibility) # Convert block hashes to signed i64 format (vLLM compatibility)
block_hashes_signed = [_to_signed_i64(h) for h in block_hashes] block_hashes_signed = [_to_signed_i64(h) for h in block_hashes]
...@@ -307,6 +312,8 @@ class Publisher: ...@@ -307,6 +312,8 @@ class Publisher:
self.partial_block_hashes: set[int] = set() self.partial_block_hashes: set[int] = set()
self.error_queue: Queue = Queue() self.error_queue: Queue = Queue()
self._stop_event = threading.Event() self._stop_event = threading.Event()
# Track the last engine event_id to assert consecutive event IDs from the engine
self._last_engine_event_id: Optional[int] = None
# Initialize ZMQ publisher if endpoint is provided (consolidator enabled) # Initialize ZMQ publisher if endpoint is provided (consolidator enabled)
if zmq_endpoint: if zmq_endpoint:
...@@ -476,6 +483,16 @@ class Publisher: ...@@ -476,6 +483,16 @@ class Publisher:
return return
event_id = event["event_id"] event_id = event["event_id"]
# Check for consecutive event IDs from the engine
if self._last_engine_event_id is not None:
expected_id = self._last_engine_event_id + 1
if event_id != expected_id:
logging.warning(
f"Non-consecutive engine event_id: expected {expected_id}, got {event_id}"
)
self._last_engine_event_id = event_id
data = event["data"] data = event["data"]
if data["type"] == "stored": if data["type"] == "stored":
self.processing_initial_created_events = False self.processing_initial_created_events = False
...@@ -513,13 +530,13 @@ class Publisher: ...@@ -513,13 +530,13 @@ class Publisher:
lora_id = data.get("lora_id", 0) lora_id = data.get("lora_id", 0)
logging.debug( logging.debug(
f"publish stored event: event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}" f"publish stored event: engine_event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
) )
# Publish to ZMQ if consolidator is enabled, otherwise publish to NATS # Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
# Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
if self.zmq_kv_event_publisher: if self.zmq_kv_event_publisher:
# Consolidator enabled: publish to ZMQ only # Consolidator enabled: publish to ZMQ only
self.zmq_kv_event_publisher.publish_stored( self.zmq_kv_event_publisher.publish_stored(
event_id,
token_ids, token_ids,
num_block_tokens, num_block_tokens,
block_hashes, block_hashes,
...@@ -529,7 +546,6 @@ class Publisher: ...@@ -529,7 +546,6 @@ class Publisher:
elif self.kv_event_publisher: elif self.kv_event_publisher:
# No consolidator: publish to NATS (router subscribes directly) # No consolidator: publish to NATS (router subscribes directly)
self.kv_event_publisher.publish_stored( self.kv_event_publisher.publish_stored(
event_id,
token_ids, token_ids,
num_block_tokens, num_block_tokens,
block_hashes, block_hashes,
...@@ -552,17 +568,16 @@ class Publisher: ...@@ -552,17 +568,16 @@ class Publisher:
removed_block_hashes.append(block_hash) removed_block_hashes.append(block_hash)
logging.debug( logging.debug(
f"publish removed event: event_id: {event_id}, block_hashes: {removed_block_hashes}" f"publish removed event: engine_event_id: {event_id}, block_hashes: {removed_block_hashes}"
) )
# Publish to ZMQ if consolidator is enabled, otherwise publish to NATS # Publish to ZMQ if consolidator is enabled, otherwise publish to NATS
# Note: event_id is managed internally by the publisher (monotonic counter per dp_rank)
if self.zmq_kv_event_publisher: if self.zmq_kv_event_publisher:
# Consolidator enabled: publish to ZMQ only # Consolidator enabled: publish to ZMQ only
self.zmq_kv_event_publisher.publish_removed( self.zmq_kv_event_publisher.publish_removed(removed_block_hashes)
event_id, removed_block_hashes
)
elif self.kv_event_publisher: elif self.kv_event_publisher:
# No consolidator: publish to NATS (router subscribes directly) # No consolidator: publish to NATS (router subscribes directly)
self.kv_event_publisher.publish_removed(event_id, removed_block_hashes) self.kv_event_publisher.publish_removed(removed_block_hashes)
elif data["type"] == "created" and self.processing_initial_created_events: elif data["type"] == "created" and self.processing_initial_created_events:
self.update_max_window_size(event) self.update_max_window_size(event)
......
...@@ -279,6 +279,7 @@ def setup_kv_event_publisher( ...@@ -279,6 +279,7 @@ def setup_kv_event_publisher(
kv_block_size=vllm_config.cache_config.block_size, kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint, zmq_endpoint=zmq_endpoint,
enable_local_indexer=config.enable_local_indexer, enable_local_indexer=config.enable_local_indexer,
dp_rank=dp_rank,
) )
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config) kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
kv_publishers.append(kv_publisher) kv_publishers.append(kv_publisher)
......
...@@ -114,7 +114,9 @@ pub struct ZmqKvEventPublisherConfig { ...@@ -114,7 +114,9 @@ pub struct ZmqKvEventPublisherConfig {
pub zmq_topic: String, pub zmq_topic: String,
#[pyo3(get, set)] #[pyo3(get, set)]
pub enable_local_indexer: bool, // whether the underlying KvEventPublisher publishes to pub enable_local_indexer: bool, // whether the underlying KvEventPublisher publishes to
// both global and worker-local KvIndexers // both global and worker-local KvIndexers
#[pyo3(get, set)]
pub dp_rank: DpRank, // data parallel rank for this publisher
} }
#[pymethods] #[pymethods]
...@@ -125,7 +127,8 @@ impl ZmqKvEventPublisherConfig { ...@@ -125,7 +127,8 @@ impl ZmqKvEventPublisherConfig {
kv_block_size, kv_block_size,
zmq_endpoint = "tcp://127.0.0.1:5557".to_string(), zmq_endpoint = "tcp://127.0.0.1:5557".to_string(),
zmq_topic = "".to_string(), zmq_topic = "".to_string(),
enable_local_indexer = false enable_local_indexer = false,
dp_rank = 0
))] ))]
pub fn new( pub fn new(
worker_id: WorkerId, worker_id: WorkerId,
...@@ -133,6 +136,7 @@ impl ZmqKvEventPublisherConfig { ...@@ -133,6 +136,7 @@ impl ZmqKvEventPublisherConfig {
zmq_endpoint: String, zmq_endpoint: String,
zmq_topic: String, zmq_topic: String,
enable_local_indexer: bool, enable_local_indexer: bool,
dp_rank: DpRank,
) -> Self { ) -> Self {
Self { Self {
worker_id, worker_id,
...@@ -140,6 +144,7 @@ impl ZmqKvEventPublisherConfig { ...@@ -140,6 +144,7 @@ impl ZmqKvEventPublisherConfig {
zmq_endpoint, zmq_endpoint,
zmq_topic, zmq_topic,
enable_local_indexer, enable_local_indexer,
dp_rank,
} }
} }
} }
...@@ -161,6 +166,7 @@ impl ZmqKvEventPublisher { ...@@ -161,6 +166,7 @@ impl ZmqKvEventPublisher {
topic: config.zmq_topic, topic: config.zmq_topic,
}), }),
config.enable_local_indexer, config.enable_local_indexer,
config.dp_rank,
) )
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(Self { inner }) Ok(Self { inner })
...@@ -192,6 +198,8 @@ impl ZmqKvEventListener { ...@@ -192,6 +198,8 @@ impl ZmqKvEventListener {
runtime.block_on(async { runtime.block_on(async {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<KvCacheEvent>();
let shutdown_token = tokio_util::sync::CancellationToken::new(); let shutdown_token = tokio_util::sync::CancellationToken::new();
// Standalone listener needs its own event ID counter
let next_event_id = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
tokio::spawn(start_zmq_listener( tokio::spawn(start_zmq_listener(
zmq_endpoint, zmq_endpoint,
...@@ -199,6 +207,7 @@ impl ZmqKvEventListener { ...@@ -199,6 +207,7 @@ impl ZmqKvEventListener {
tx, tx,
shutdown_token.clone(), shutdown_token.clone(),
kv_block_size as u32, kv_block_size as u32,
next_event_id,
)); ));
Ok(Self { Ok(Self {
...@@ -273,6 +282,7 @@ impl KvEventPublisher { ...@@ -273,6 +282,7 @@ impl KvEventPublisher {
kv_block_size as u32, kv_block_size as u32,
None, None,
enable_local_indexer, enable_local_indexer,
dp_rank,
) )
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
...@@ -285,11 +295,10 @@ impl KvEventPublisher { ...@@ -285,11 +295,10 @@ impl KvEventPublisher {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, block_mm_infos=None))] #[pyo3(signature = (token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, block_mm_infos=None))]
fn publish_stored( fn publish_stored(
&mut self, &self,
py: Python, py: Python,
event_id: u64,
token_ids: Vec<u32>, token_ids: Vec<u32>,
num_block_tokens: Vec<u64>, num_block_tokens: Vec<u64>,
block_hashes: Vec<i64>, block_hashes: Vec<i64>,
...@@ -302,6 +311,9 @@ impl KvEventPublisher { ...@@ -302,6 +311,9 @@ impl KvEventPublisher {
let warning_count = self.warning_count.clone(); let warning_count = self.warning_count.clone();
let inner = self.inner.clone(); let inner = self.inner.clone();
// Use shared monotonic event_id counter from the inner publisher
let event_id = inner.next_event_id();
// Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>> // Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>>
let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
.as_ref() .as_ref()
...@@ -338,10 +350,13 @@ impl KvEventPublisher { ...@@ -338,10 +350,13 @@ impl KvEventPublisher {
}) })
} }
fn publish_removed(&self, py: Python, event_id: u64, block_hashes: Vec<i64>) -> PyResult<()> { fn publish_removed(&self, py: Python, block_hashes: Vec<i64>) -> PyResult<()> {
let dp_rank = self.dp_rank; let dp_rank = self.dp_rank;
let inner = self.inner.clone(); let inner = self.inner.clone();
// Use shared monotonic event_id counter from the inner publisher
let event_id = inner.next_event_id();
py.allow_threads(|| { py.allow_threads(|| {
let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
.into_iter() .into_iter()
......
...@@ -753,7 +753,6 @@ class KvEventPublisher: ...@@ -753,7 +753,6 @@ class KvEventPublisher:
def publish_stored( def publish_stored(
self, self,
event_id: int,
token_ids: List[int], token_ids: List[int],
num_block_tokens: List[int], num_block_tokens: List[int],
block_hashes: List[int], block_hashes: List[int],
...@@ -763,8 +762,9 @@ class KvEventPublisher: ...@@ -763,8 +762,9 @@ class KvEventPublisher:
""" """
Publish a KV stored event. Publish a KV stored event.
Event IDs are managed internally by the publisher using a monotonic counter.
Args: Args:
event_id: The event ID
token_ids: List of token IDs token_ids: List of token IDs
num_block_tokens: Number of tokens per block num_block_tokens: Number of tokens per block
block_hashes: List of block hashes (signed 64-bit integers) block_hashes: List of block hashes (signed 64-bit integers)
...@@ -773,12 +773,13 @@ class KvEventPublisher: ...@@ -773,12 +773,13 @@ class KvEventPublisher:
""" """
... ...
def publish_removed(self, event_id: int, block_hashes: List[int]) -> None: def publish_removed(self, block_hashes: List[int]) -> None:
""" """
Publish a KV removed event. Publish a KV removed event.
Event IDs are managed internally by the publisher using a monotonic counter.
Args: Args:
event_id: The event ID
block_hashes: List of block hashes to remove (signed 64-bit integers) block_hashes: List of block hashes to remove (signed 64-bit integers)
""" """
... ...
...@@ -790,7 +791,8 @@ class ZmqKvEventPublisherConfig: ...@@ -790,7 +791,8 @@ class ZmqKvEventPublisherConfig:
kv_block_size: int, kv_block_size: int,
zmq_endpoint: str = "tcp://127.0.0.1:5557", zmq_endpoint: str = "tcp://127.0.0.1:5557",
zmq_topic: str = "", zmq_topic: str = "",
enable_local_indexer: bool = False enable_local_indexer: bool = False,
dp_rank: int = 0
) -> None: ) -> None:
""" """
Configuration for the ZmqKvEventPublisher. Configuration for the ZmqKvEventPublisher.
...@@ -800,6 +802,7 @@ class ZmqKvEventPublisherConfig: ...@@ -800,6 +802,7 @@ class ZmqKvEventPublisherConfig:
:param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557". :param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557".
:param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string. :param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string.
:param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to False. :param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to False.
:param dp_rank: The data parallel rank for this publisher. Defaults to 0.
""" """
... ...
......
...@@ -280,31 +280,30 @@ async def test_approx_kv_indexer(distributed_runtime): ...@@ -280,31 +280,30 @@ async def test_approx_kv_indexer(distributed_runtime):
class EventPublisher: class EventPublisher:
def __init__(self, component: Component, worker_id: int, kv_block_size: int): def __init__(self, component: Component, worker_id: int, kv_block_size: int):
self.publisher = KvEventPublisher(component, worker_id, kv_block_size) self.publisher = KvEventPublisher(component, worker_id, kv_block_size)
self.event_id_counter = 0 # Counter for generating unique block hashes (event_id is now managed internally by publisher)
self.block_hash_counter = 0
self.block_hashes: List[int] = [] self.block_hashes: List[int] = []
def store_event(self, tokens, lora_id): def store_event(self, tokens, lora_id):
parent_hash = self.event_id_counter if self.event_id_counter > 0 else None # Parent hash should reference the last published block, not the current one
parent_hash = self.block_hashes[-1] if self.block_hashes else None
self.publisher.publish_stored( self.publisher.publish_stored(
self.event_id_counter, # event_id
tokens, # token_ids tokens, # token_ids
[ [
len(tokens), len(tokens),
], # num_block_tokens ], # num_block_tokens
[ [
self.event_id_counter, self.block_hash_counter,
], # block_hashes ], # block_hashes
lora_id, # lora_id lora_id, # lora_id
parent_hash, # parent_hash parent_hash, # parent_hash
) )
self.block_hashes.append(self.event_id_counter) self.block_hashes.append(self.block_hash_counter)
self.event_id_counter += 1 self.block_hash_counter += 1
def remove_event(self): def remove_event(self):
self.publisher.publish_removed( self.publisher.publish_removed(
self.event_id_counter, # event_id
[ [
self.block_hashes[-1], self.block_hashes[-1],
], # block_hashes ], # block_hashes
) )
self.event_id_counter += 1
...@@ -46,7 +46,7 @@ use crate::{ ...@@ -46,7 +46,7 @@ use crate::{
approx::PruneConfig, approx::PruneConfig,
indexer::{KvIndexer, KvIndexerInterface, KvRouterError}, indexer::{KvIndexer, KvIndexerInterface, KvRouterError},
protocols::{ protocols::{
LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse, DpRank, LocalBlockHash, OverlapScores, RouterEvent, RouterRequest, RouterResponse,
TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank, TokensWithHashes, WorkerId, WorkerSelectionResult, WorkerWithDpRank,
compute_block_hash_for_seq, compute_seq_hash_for_block, compute_block_hash_for_seq, compute_seq_hash_for_block,
}, },
...@@ -80,9 +80,14 @@ pub const RADIX_STATE_BUCKET: &str = "radix-bucket"; ...@@ -80,9 +80,14 @@ pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state"; pub const RADIX_STATE_FILE: &str = "radix-state";
// for worker-local kvindexer query // for worker-local kvindexer query
pub const WORKER_KV_INDEXER_QUERY_ENDPOINT: &str = "worker_kv_indexer_query";
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer
/// Generates a dp_rank-specific endpoint name for the worker KV indexer query service.
/// Each dp_rank has its own LocalKvIndexer and query endpoint to ensure per-dp_rank monotonicity.
pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String {
format!("worker_kv_indexer_query_dp{dp_rank}")
}
// for router discovery registration // for router discovery registration
pub const KV_ROUTER_COMPONENT: &str = "kv-router"; pub const KV_ROUTER_COMPONENT: &str = "kv-router";
pub const KV_ROUTER_ENDPOINT: &str = "generate"; pub const KV_ROUTER_ENDPOINT: &str = "generate";
...@@ -627,6 +632,7 @@ impl KvRouter { ...@@ -627,6 +632,7 @@ impl KvRouter {
pub async fn query_worker_local_kv( pub async fn query_worker_local_kv(
&self, &self,
worker_id: WorkerId, worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>, start_event_id: Option<u64>,
end_event_id: Option<u64>, end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> { ) -> Result<WorkerKvQueryResponse> {
...@@ -636,11 +642,11 @@ impl KvRouter { ...@@ -636,11 +642,11 @@ impl KvRouter {
.ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?; .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;
query_client query_client
.query_worker(worker_id, start_event_id, end_event_id) .query_worker(worker_id, dp_rank, start_event_id, end_event_id)
.await .await
} }
/// Recover missed KV events from a specific worker. /// Recover missed KV events from a specific worker's dp_rank.
/// ///
/// Queries the worker's local KV indexer for events starting from /// Queries the worker's local KV indexer for events starting from
/// `start_event_id` and applies them to the router's indexer. /// `start_event_id` and applies them to the router's indexer.
...@@ -648,11 +654,13 @@ impl KvRouter { ...@@ -648,11 +654,13 @@ impl KvRouter {
/// # Arguments /// # Arguments
/// ///
/// * `worker_id` - The worker to recover from /// * `worker_id` - The worker to recover from
/// * `dp_rank` - The data parallel rank to recover from
/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
pub async fn recover_from_worker( pub async fn recover_from_worker(
&self, &self,
worker_id: WorkerId, worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>, start_event_id: Option<u64>,
end_event_id: Option<u64>, end_event_id: Option<u64>,
) -> Result<usize> { ) -> Result<usize> {
...@@ -668,14 +676,9 @@ impl KvRouter { ...@@ -668,14 +676,9 @@ impl KvRouter {
} }
}; };
subscriber::recover_from_worker( query_client
query_client, .recover_from_worker(worker_id, dp_rank, start_event_id, end_event_id, &event_tx)
worker_id, .await
start_event_id,
end_event_id,
&event_tx,
)
.await
} }
} }
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
use std::fmt; use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration; use std::time::Duration;
use anyhow::Result; use anyhow::Result;
...@@ -77,6 +77,7 @@ impl KvEventSource { ...@@ -77,6 +77,7 @@ impl KvEventSource {
source_config: KvEventSourceConfig, source_config: KvEventSourceConfig,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<KvCacheEvent>,
next_event_id: Arc<AtomicU64>,
) -> Result<Self> { ) -> Result<Self> {
match source_config { match source_config {
KvEventSourceConfig::Zmq { endpoint, topic } => { KvEventSourceConfig::Zmq { endpoint, topic } => {
...@@ -90,6 +91,7 @@ impl KvEventSource { ...@@ -90,6 +91,7 @@ impl KvEventSource {
tx, tx,
cancellation_token.clone(), cancellation_token.clone(),
kv_block_size, kv_block_size,
next_event_id,
)); ));
Ok(KvEventSource::Zmq { zmq_handle }) Ok(KvEventSource::Zmq { zmq_handle })
...@@ -117,6 +119,9 @@ pub struct KvEventPublisher { ...@@ -117,6 +119,9 @@ pub struct KvEventPublisher {
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
/// The channel to send events to. /// The channel to send events to.
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<KvCacheEvent>,
/// Internal monotonic event ID counter - ensures each event gets a unique, incrementing ID.
/// Shared with the ZMQ listener (if any) to maintain consistency.
next_event_id: Arc<AtomicU64>,
} }
impl KvEventPublisher { impl KvEventPublisher {
...@@ -125,7 +130,7 @@ impl KvEventPublisher { ...@@ -125,7 +130,7 @@ impl KvEventPublisher {
kv_block_size: u32, kv_block_size: u32,
source_config: Option<KvEventSourceConfig>, source_config: Option<KvEventSourceConfig>,
) -> Result<Self> { ) -> Result<Self> {
Self::new_with_local_indexer(component, kv_block_size, source_config, false) Self::new_with_local_indexer(component, kv_block_size, source_config, false, 0)
} }
pub fn new_with_local_indexer( pub fn new_with_local_indexer(
...@@ -133,6 +138,7 @@ impl KvEventPublisher { ...@@ -133,6 +138,7 @@ impl KvEventPublisher {
kv_block_size: u32, kv_block_size: u32,
source_config: Option<KvEventSourceConfig>, source_config: Option<KvEventSourceConfig>,
enable_local_indexer: bool, enable_local_indexer: bool,
dp_rank: DpRank,
) -> Result<Self> { ) -> Result<Self> {
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -152,6 +158,9 @@ impl KvEventPublisher { ...@@ -152,6 +158,9 @@ impl KvEventPublisher {
); );
} }
// Internal monotonic event ID counter - shared with ZMQ listener if any
let next_event_id = Arc::new(AtomicU64::new(0));
// Create our event source (if any) // Create our event source (if any)
let mut source = None; let mut source = None;
if let Some(config) = source_config { if let Some(config) = source_config {
...@@ -161,6 +170,7 @@ impl KvEventPublisher { ...@@ -161,6 +170,7 @@ impl KvEventPublisher {
config, config,
cancellation_token.clone(), cancellation_token.clone(),
tx.clone(), tx.clone(),
next_event_id.clone(),
)?); )?);
} }
...@@ -189,6 +199,7 @@ impl KvEventPublisher { ...@@ -189,6 +199,7 @@ impl KvEventPublisher {
.spawn(start_worker_kv_query_endpoint( .spawn(start_worker_kv_query_endpoint(
component, component,
worker_id, worker_id,
dp_rank,
local_indexer, local_indexer,
)) ))
}); });
...@@ -253,6 +264,7 @@ impl KvEventPublisher { ...@@ -253,6 +264,7 @@ impl KvEventPublisher {
source, source,
cancellation_token, cancellation_token,
tx, tx,
next_event_id,
}) })
} }
...@@ -260,6 +272,12 @@ impl KvEventPublisher { ...@@ -260,6 +272,12 @@ impl KvEventPublisher {
self.tx.send(event) self.tx.send(event)
} }
/// Get and increment the next event ID atomically.
/// Use this to assign monotonically increasing event IDs to events before publishing.
pub fn next_event_id(&self) -> u64 {
self.next_event_id.fetch_add(1, Ordering::SeqCst)
}
pub fn kv_block_size(&self) -> u32 { pub fn kv_block_size(&self) -> u32 {
self.kv_block_size self.kv_block_size
} }
...@@ -406,6 +424,7 @@ pub async fn start_zmq_listener( ...@@ -406,6 +424,7 @@ pub async fn start_zmq_listener(
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<KvCacheEvent>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
kv_block_size: u32, kv_block_size: u32,
next_event_id: Arc<AtomicU64>,
) { ) {
tracing::debug!( tracing::debug!(
"KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')", "KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')",
...@@ -496,7 +515,9 @@ pub async fn start_zmq_listener( ...@@ -496,7 +515,9 @@ pub async fn start_zmq_listener(
continue; continue;
} }
let seq = u64::from_be_bytes(seq_bytes.try_into().unwrap()); // Note: We extract the engine's sequence number for logging but use our own
// internal monotonic counter for event_id to ensure per-dp_rank monotonicity
let engine_seq = u64::from_be_bytes(seq_bytes.try_into().unwrap());
// Decode our batch of events. // Decode our batch of events.
let batch_result = rmps::from_slice::<KvEventBatch>(&payload); let batch_result = rmps::from_slice::<KvEventBatch>(&payload);
...@@ -507,16 +528,19 @@ pub async fn start_zmq_listener( ...@@ -507,16 +528,19 @@ pub async fn start_zmq_listener(
}; };
tracing::trace!( tracing::trace!(
"ZMQ listener on {} received batch with {} events (seq={}, dp_rank={})", "ZMQ listener on {} received batch with {} events (engine_seq={}, dp_rank={})",
zmq_endpoint, zmq_endpoint,
batch.events.len(), batch.events.len(),
seq, engine_seq,
batch.data_parallel_rank.unwrap_or(0) batch.data_parallel_rank.unwrap_or(0)
); );
let dp_rank = batch.data_parallel_rank.unwrap_or(0) as u32; let dp_rank = batch.data_parallel_rank.unwrap_or(0) as u32;
for raw_event in batch.events.into_iter() { for raw_event in batch.events.into_iter() {
let event = convert_event(raw_event, seq, kv_block_size, dp_rank, &warning_count); // Use shared monotonic event_id counter instead of engine's sequence number
let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
let event = convert_event(raw_event, event_id, kv_block_size, dp_rank, &warning_count);
if tx.send(event).is_err() { if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped"); tracing::warn!("Failed to send message to channel - receiver dropped");
exit_reason = "channel receiver dropped"; exit_reason = "channel receiver dropped";
...@@ -1558,11 +1582,13 @@ mod tests_startup_helpers { ...@@ -1558,11 +1582,13 @@ mod tests_startup_helpers {
// Cancellation token so we can stop the listener // Cancellation token so we can stop the listener
let token = dynamo_runtime::CancellationToken::new(); let token = dynamo_runtime::CancellationToken::new();
// Event ID counter for the test listener
let next_event_id = Arc::new(AtomicU64::new(0));
// Spawn async listener (connects to publisher bound above) // Spawn async listener (connects to publisher bound above)
let listener_handle = tokio::spawn({ let listener_handle = tokio::spawn({
let token = token.clone(); let token = token.clone();
start_zmq_listener(endpoint.to_string(), topic, tx, token, 4) start_zmq_listener(endpoint.to_string(), topic, tx, token, 4, next_event_id)
}); });
// Give time for the connection to establish // Give time for the connection to establish
......
...@@ -19,8 +19,8 @@ use tokio_util::sync::CancellationToken; ...@@ -19,8 +19,8 @@ use tokio_util::sync::CancellationToken;
use crate::kv_router::{ use crate::kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
indexer::{DumpRequest, GetWorkersRequest, WorkerKvQueryResponse}, indexer::{DumpRequest, GetWorkersRequest},
protocols::{RouterEvent, WorkerId}, protocols::{DpRank, RouterEvent, WorkerId},
router_discovery_query, router_discovery_query,
worker_query::WorkerQueryClient, worker_query::WorkerQueryClient,
}; };
...@@ -47,10 +47,6 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10; ...@@ -47,10 +47,6 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1); const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1);
const CHECK_INTERVAL_JITTER_MS: i64 = 100; const CHECK_INTERVAL_JITTER_MS: i64 = 100;
// Worker query retry configuration
const WORKER_QUERY_MAX_RETRIES: u32 = 8;
const WORKER_QUERY_INITIAL_BACKOFF_MS: u64 = 200;
// ============================================================================ // ============================================================================
// Discovery Helpers // Discovery Helpers
// ============================================================================ // ============================================================================
...@@ -79,205 +75,6 @@ async fn get_instance_discovery_stream( ...@@ -79,205 +75,6 @@ async fn get_instance_discovery_stream(
Ok(Box::pin(stream)) Ok(Box::pin(stream))
} }
// ============================================================================
// Local KvIndexer-based Recovery
// ============================================================================
/// Recover missed events from all workers with local indexers.
///
/// This function should be called on router startup to catch up on any events
/// that were missed while the router was offline.
///
/// # Arguments
///
/// * `worker_query_client` - Client for querying worker local indexers
/// * `last_received_event_ids` - Map of worker ID to last received event ID
/// * `worker_ids` - List of worker IDs to recover from
/// * `event_tx` - Channel to send recovered events to the indexer
///
/// # Returns
///
/// Total number of events recovered across all workers
pub async fn recover_from_all_workers(
worker_query_client: &WorkerQueryClient,
last_received_event_ids: &HashMap<WorkerId, u64>,
worker_ids: &Vec<WorkerId>,
event_tx: &mpsc::Sender<RouterEvent>,
) -> usize {
let mut total_recovered = 0;
let mut successful_workers = 0;
let mut failed_workers = 0;
for &worker_id in worker_ids {
// Skip workers without local indexer
if !worker_query_client.has_local_indexer(worker_id) {
tracing::debug!(
"Skipping recovery - worker {worker_id} does not have local indexer enabled"
);
continue;
}
// If we haven't seen any events from this worker, start from beginning (None)
// If we've seen events, start from last_known_id + 1
let start_event_id = last_received_event_ids
.get(&worker_id)
.map(|&last_id| last_id + 1);
match recover_from_worker(
worker_query_client,
worker_id,
start_event_id,
None, // Get all events after start_event_id
event_tx,
)
.await
{
Ok(count) => {
total_recovered += count;
if count > 0 {
successful_workers += 1;
}
}
Err(_) => {
failed_workers += 1;
}
}
}
// Log summary
if total_recovered > 0 || failed_workers > 0 {
tracing::info!(
"Startup recovery completed: {total_recovered} events recovered from {successful_workers} workers, {failed_workers} workers failed"
);
}
total_recovered
}
/// Recover missed KV events from a specific worker.
///
/// # Arguments
///
/// * `worker_query_client` - Client for querying worker local indexers
/// * `worker_id` - The worker to recover from
/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
/// * `event_tx` - Channel to send recovered events to the indexer
///
/// # Returns
///
/// Number of events recovered, or error if recovery failed
pub async fn recover_from_worker(
worker_query_client: &WorkerQueryClient,
worker_id: WorkerId,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
event_tx: &mpsc::Sender<RouterEvent>,
) -> Result<usize> {
if worker_query_client.has_local_indexer(worker_id) {
tracing::debug!(
"Attempting recovery from worker {worker_id}, start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
);
} else {
tracing::warn!("Worker {worker_id} does not have local indexer enabled, skipping recovery");
return Ok(0);
}
// Query worker for events in range, with retry logic for transient failures
// (e.g., worker's query service not yet re-subscribed after NATS restart)
let mut response = None;
let mut last_error = None;
for attempt in 0..WORKER_QUERY_MAX_RETRIES {
match worker_query_client
.query_worker(worker_id, start_event_id, end_event_id)
.await
{
Ok(resp) => {
if attempt > 0 {
tracing::info!("Worker {worker_id} query succeeded after retry {attempt}");
}
response = Some(resp);
break;
}
Err(e) => {
last_error = Some(e);
if attempt < WORKER_QUERY_MAX_RETRIES - 1 {
let backoff_ms = WORKER_QUERY_INITIAL_BACKOFF_MS * 2_u64.pow(attempt);
tracing::warn!(
"Worker {worker_id} query failed on attempt {attempt}, retrying after {backoff_ms}ms"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
}
}
}
}
let response = match response {
Some(r) => r,
None => return Err(last_error.unwrap_or_else(|| anyhow::anyhow!("No response"))),
};
// Handle response variants
let events = match response {
WorkerKvQueryResponse::Events(events) => {
tracing::debug!(
"Got {count} buffered events from worker {worker_id}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TreeDump(events) => {
tracing::info!(
"Got tree dump from worker {worker_id} (range too old or unspecified), count: {count}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TooNew {
requested_start,
requested_end,
newest_available,
} => {
tracing::warn!(
"Worker {worker_id} requested range is newer than available data: requested_start: {requested_start:?}, requested_end: {requested_end:?}, newest_available: {newest_available}"
);
return Ok(0);
}
WorkerKvQueryResponse::InvalidRange { start_id, end_id } => {
anyhow::bail!("Invalid range: end_id ({end_id}) < start_id ({start_id})");
}
WorkerKvQueryResponse::Error(message) => {
anyhow::bail!("Worker {worker_id} query failed: {message}");
}
};
let events_count = events.len();
if events_count == 0 {
tracing::debug!(
"No events to recover from worker {worker_id}, start_event_id: {start_event_id:?}"
);
return Ok(0);
}
tracing::info!(
"Recovered {events_count} events from worker {worker_id}, start_event_id: {start_event_id:?}"
);
// Apply recovered events to the indexer
for event in events {
if let Err(e) = event_tx.send(event).await {
tracing::error!(
"Failed to send recovered event to indexer for worker {worker_id}: {e}"
);
anyhow::bail!("Failed to send recovered event: {e}");
}
}
Ok(events_count)
}
// ============================================================================ // ============================================================================
// Snapshot Management // Snapshot Management
// ============================================================================ // ============================================================================
...@@ -712,25 +509,14 @@ async fn handle_worker_discovery( ...@@ -712,25 +509,14 @@ async fn handle_worker_discovery(
"DISCOVERY: Worker {worker_id} added, dumping local indexer into router" "DISCOVERY: Worker {worker_id} added, dumping local indexer into router"
); );
match recover_from_worker( let total_recovered = worker_query_client
worker_query_client, .recover_all_dp_ranks(worker_id, kv_events_tx)
worker_id, .await;
None, // Start from beginning
None, // Get all events if total_recovered > 0 {
kv_events_tx, tracing::info!(
) "DISCOVERY: Worker {worker_id} total recovered {total_recovered} events"
.await );
{
Ok(count) => {
tracing::info!(
"Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
);
}
Err(e) => {
tracing::warn!(
"Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
);
}
} }
} }
DiscoveryEvent::Removed(id) => { DiscoveryEvent::Removed(id) => {
...@@ -801,19 +587,12 @@ pub async fn start_kv_router_background_event_plane( ...@@ -801,19 +587,12 @@ pub async fn start_kv_router_background_event_plane(
ready_workers.len() ready_workers.len()
); );
// Recover initial state from all ready workers // Recover initial state from all ready workers (all dp_ranks)
for worker_id in &ready_workers { for worker_id in &ready_workers {
if worker_query_client.has_local_indexer(*worker_id) { if worker_query_client.has_local_indexer(*worker_id) {
match recover_from_worker(&worker_query_client, *worker_id, None, None, &kv_events_tx) worker_query_client
.await .recover_all_dp_ranks(*worker_id, &kv_events_tx)
{ .await;
Ok(count) => {
tracing::info!("Successfully recovered {count} events from worker {worker_id}");
}
Err(e) => {
tracing::warn!("Failed to recover from worker {worker_id}: {e}");
}
}
} }
} }
...@@ -822,8 +601,9 @@ pub async fn start_kv_router_background_event_plane( ...@@ -822,8 +601,9 @@ pub async fn start_kv_router_background_event_plane(
get_instance_discovery_stream(&component, &cancellation_token).await?; get_instance_discovery_stream(&component, &cancellation_token).await?;
tokio::spawn(async move { tokio::spawn(async move {
// Track last received event ID per worker for gap detection // Track last received event ID per (worker, dp_rank) for gap detection
let mut last_event_ids: HashMap<WorkerId, u64> = HashMap::new(); // Each dp_rank has its own monotonic event ID sequence
let mut last_event_ids: HashMap<(WorkerId, DpRank), u64> = HashMap::new();
loop { loop {
tokio::select! { tokio::select! {
...@@ -860,7 +640,9 @@ pub async fn start_kv_router_background_event_plane( ...@@ -860,7 +640,9 @@ pub async fn start_kv_router_background_event_plane(
}; };
let worker_id = event.worker_id; let worker_id = event.worker_id;
let dp_rank = event.event.dp_rank;
let event_id = event.event.event_id; let event_id = event.event.event_id;
let event_key = (worker_id, dp_rank);
// Use envelope metadata for additional debugging // Use envelope metadata for additional debugging
tracing::trace!( tracing::trace!(
...@@ -869,9 +651,9 @@ pub async fn start_kv_router_background_event_plane( ...@@ -869,9 +651,9 @@ pub async fn start_kv_router_background_event_plane(
envelope.sequence envelope.sequence
); );
// Gap detection: check if event ID is monotonically increasing per worker // Gap detection: check if event ID is monotonically increasing per (worker, dp_rank)
// Note: event_id <= last_id is duplicate/out-of-order, apply anyway (idempotent) // Note: event_id <= last_id is duplicate/out-of-order, apply anyway (idempotent)
if let Some(&last_id) = last_event_ids.get(&worker_id) if let Some(&last_id) = last_event_ids.get(&event_key)
&& event_id > last_id + 1 && event_id > last_id + 1
{ {
// Gap detected - recover missing events before processing current // Gap detected - recover missing events before processing current
...@@ -879,32 +661,29 @@ pub async fn start_kv_router_background_event_plane( ...@@ -879,32 +661,29 @@ pub async fn start_kv_router_background_event_plane(
let gap_end = event_id - 1; let gap_end = event_id - 1;
let gap_size = gap_end - gap_start + 1; let gap_size = gap_end - gap_start + 1;
tracing::warn!( tracing::warn!(
"Event ID gap detected for worker {worker_id}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}" "Event ID gap detected for worker {worker_id} dp_rank {dp_rank}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
); );
// Note: While recovering, new events may queue in the subscriber's // Note: While recovering, new events may queue in the subscriber's
// internal buffer. We don't explicitly buffer them here for simplicity. // internal buffer. We don't explicitly buffer them here for simplicity.
// The subscriber will process them in order after recovery completes. // The subscriber will process them in order after recovery completes.
if let Err(e) = recover_from_worker( if let Err(e) = worker_query_client
&worker_query_client, .recover_from_worker(worker_id, dp_rank, Some(gap_start), Some(gap_end), &kv_events_tx)
worker_id, .await
Some(gap_start), {
Some(gap_end),
&kv_events_tx,
).await {
tracing::error!( tracing::error!(
"Failed to recover gap events for worker {worker_id} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}" "Failed to recover gap events for worker {worker_id} dp_rank {dp_rank} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}"
); );
// Note: If recovery fails, we still apply the current event. // Note: If recovery fails, we still apply the current event.
// The tree will have a gap, but it's better than dropping the event. // The tree will have a gap, but it's better than dropping the event.
} }
} }
// First event from this worker is always valid - we accept whatever ID it has. // First event from this (worker, dp_rank) is always valid - we accept whatever ID it has.
// This handles initial startup and worker restarts without requiring event 0. // This handles initial startup and worker restarts without requiring event 0.
// Update last seen event ID (use max to handle out-of-order) // Update last seen event ID (use max to handle out-of-order)
last_event_ids last_event_ids
.entry(worker_id) .entry(event_key)
.and_modify(|id| *id = (*id).max(event_id)) .and_modify(|id| *id = (*id).max(event_id))
.or_insert(event_id); .or_insert(event_id);
......
...@@ -2,33 +2,43 @@ ...@@ -2,33 +2,43 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use dashmap::DashMap;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::pipeline::{ use dynamo_runtime::pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, PushRouter, ResponseStream, RouterMode, AsyncEngine, AsyncEngineContextProvider, ManyOut, PushRouter, ResponseStream, RouterMode,
SingleIn, async_trait, network::Ingress, SingleIn, async_trait, network::Ingress,
}; };
use dynamo_runtime::protocols::maybe_error::MaybeError; use dynamo_runtime::protocols::maybe_error::MaybeError;
use tokio::sync::OnceCell; use dynamo_runtime::stream;
use tokio::sync::mpsc;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::discovery::RuntimeConfigsSubscriber; use crate::discovery::RuntimeConfigsSubscriber;
use crate::kv_router::WORKER_KV_INDEXER_QUERY_ENDPOINT;
use crate::kv_router::indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse}; use crate::kv_router::indexer::{LocalKvIndexer, WorkerKvQueryRequest, WorkerKvQueryResponse};
use crate::kv_router::protocols::WorkerId; use crate::kv_router::protocols::{DpRank, RouterEvent, WorkerId};
use dynamo_runtime::stream; use crate::kv_router::worker_kv_indexer_query_endpoint;
// Recovery retry configuration
const RECOVERY_MAX_RETRIES: u32 = 8;
const RECOVERY_INITIAL_BACKOFF_MS: u64 = 200;
/// Router-side client for querying worker local KV indexers /// Router-side client for querying worker local KV indexers
/// ///
/// Performs request/reply communication with workers via request plane endpoint routing. /// Performs request/reply communication with workers via request plane endpoint routing.
/// (Only queries workers that have `enable_local_indexer=true` in their MDC user_data) /// (Only queries workers that have `enable_local_indexer=true` in their MDC user_data)
/// The client is spawned by KvRouter; it uses a subscriber from RuntimeConfigs. /// The client is spawned by KvRouter; it uses a subscriber from RuntimeConfigs.
///
/// Each dp_rank has its own LocalKvIndexer and query endpoint, so we maintain separate
/// routers per dp_rank to ensure queries go to the correct endpoint.
pub struct WorkerQueryClient { pub struct WorkerQueryClient {
component: Component, component: Component,
/// Subscriber for runtime configs (includes shared configs DashMap) /// Subscriber for runtime configs (includes shared configs DashMap)
subscriber: RuntimeConfigsSubscriber, subscriber: RuntimeConfigsSubscriber,
router: OnceCell<Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>, /// Routers keyed by dp_rank - each dp_rank has its own endpoint
routers: DashMap<DpRank, Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>>,
} }
impl WorkerQueryClient { impl WorkerQueryClient {
...@@ -37,7 +47,7 @@ impl WorkerQueryClient { ...@@ -37,7 +47,7 @@ impl WorkerQueryClient {
Self { Self {
component, component,
subscriber, subscriber,
router: OnceCell::new(), routers: DashMap::new(),
} }
} }
...@@ -56,11 +66,46 @@ impl WorkerQueryClient { ...@@ -56,11 +66,46 @@ impl WorkerQueryClient {
.unwrap_or(false) .unwrap_or(false)
} }
/// Query a specific worker's local KV indexer and return its buffered events. /// Get the data_parallel_size for a worker (defaults to 1 if not found)
pub fn get_data_parallel_size(&self, worker_id: WorkerId) -> u32 {
self.subscriber
.configs
.get(&worker_id)
.and_then(|entry| entry.value().as_ref().map(|c| c.data_parallel_size))
.unwrap_or(1)
}
/// Get or create a router for the specified dp_rank's endpoint
async fn get_router_for_dp_rank(
&self,
dp_rank: DpRank,
) -> Result<Arc<PushRouter<WorkerKvQueryRequest, WorkerKvQueryResponse>>> {
// Fast path: check if router already exists
if let Some(router) = self.routers.get(&dp_rank) {
return Ok(router.clone());
}
// Slow path: create new router
let endpoint_name = worker_kv_indexer_query_endpoint(dp_rank);
let endpoint = self.component.endpoint(&endpoint_name);
let client = endpoint.client().await?;
let router = Arc::new(PushRouter::from_client(client, RouterMode::RoundRobin).await?);
// Insert and return (if another thread inserted first, use theirs)
Ok(self
.routers
.entry(dp_rank)
.or_insert(router)
.value()
.clone())
}
/// Query a specific worker's local KV indexer for a specific dp_rank and return its buffered events.
/// Returns an error if the worker does not have enable_local_indexer=true. /// Returns an error if the worker does not have enable_local_indexer=true.
pub async fn query_worker( pub async fn query_worker(
&self, &self,
worker_id: WorkerId, worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>, start_event_id: Option<u64>,
end_event_id: Option<u64>, end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> { ) -> Result<WorkerKvQueryResponse> {
...@@ -71,15 +116,7 @@ impl WorkerQueryClient { ...@@ -71,15 +116,7 @@ impl WorkerQueryClient {
); );
} }
let router = self let router = self.get_router_for_dp_rank(dp_rank).await?;
.router
.get_or_try_init(|| async {
let endpoint = self.component.endpoint(WORKER_KV_INDEXER_QUERY_ENDPOINT);
let client = endpoint.client().await?;
let router = PushRouter::from_client(client, RouterMode::RoundRobin).await?;
Ok::<_, anyhow::Error>(Arc::new(router))
})
.await?;
let request = WorkerKvQueryRequest { let request = WorkerKvQueryRequest {
worker_id, worker_id,
...@@ -90,7 +127,7 @@ impl WorkerQueryClient { ...@@ -90,7 +127,7 @@ impl WorkerQueryClient {
.direct(SingleIn::new(request), worker_id) .direct(SingleIn::new(request), worker_id)
.await .await
.with_context(|| { .with_context(|| {
format!("Failed to send worker KV query request to worker {worker_id} via endpoint") format!("Failed to send worker KV query request to worker {worker_id} dp_rank {dp_rank} via endpoint")
})?; })?;
let response = stream let response = stream
...@@ -104,12 +141,170 @@ impl WorkerQueryClient { ...@@ -104,12 +141,170 @@ impl WorkerQueryClient {
Ok(response) Ok(response)
} }
/// Recover events from all dp_ranks of a single worker.
///
/// # Returns
/// Total number of events recovered across all dp_ranks
pub async fn recover_all_dp_ranks(
&self,
worker_id: WorkerId,
event_tx: &mpsc::Sender<RouterEvent>,
) -> usize {
let dp_size = self.get_data_parallel_size(worker_id);
let mut total_recovered = 0;
for dp_rank in 0..dp_size {
match self
.recover_from_worker(worker_id, dp_rank, None, None, event_tx)
.await
{
Ok(count) => {
total_recovered += count;
if count > 0 {
tracing::info!(
"Recovered {count} events from worker {worker_id} dp_rank {dp_rank}"
);
}
}
Err(e) => {
tracing::warn!(
"Failed to recover from worker {worker_id} dp_rank {dp_rank}: {e}"
);
}
}
}
total_recovered
}
/// Recover missed KV events from a specific worker's dp_rank with retry logic.
///
/// # Returns
/// Number of events recovered, or error if recovery failed after all retries
pub async fn recover_from_worker(
&self,
worker_id: WorkerId,
dp_rank: DpRank,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
event_tx: &mpsc::Sender<RouterEvent>,
) -> Result<usize> {
if !self.has_local_indexer(worker_id) {
tracing::debug!(
"Worker {worker_id} does not have local indexer enabled, skipping recovery"
);
return Ok(0);
}
tracing::debug!(
"Attempting recovery from worker {worker_id} dp_rank {dp_rank}, \
start_event_id: {start_event_id:?}, end_event_id: {end_event_id:?}"
);
// Query worker with retry logic for transient failures
let mut response = None;
let mut last_error = None;
for attempt in 0..RECOVERY_MAX_RETRIES {
match self
.query_worker(worker_id, dp_rank, start_event_id, end_event_id)
.await
{
Ok(resp) => {
if attempt > 0 {
tracing::info!(
"Worker {worker_id} dp_rank {dp_rank} query succeeded after retry {attempt}"
);
}
response = Some(resp);
break;
}
Err(e) => {
last_error = Some(e);
if attempt < RECOVERY_MAX_RETRIES - 1 {
let backoff_ms = RECOVERY_INITIAL_BACKOFF_MS * 2_u64.pow(attempt);
tracing::warn!(
"Worker {worker_id} dp_rank {dp_rank} query failed on attempt {attempt}, \
retrying after {backoff_ms}ms"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
}
}
}
}
let response = match response {
Some(r) => r,
None => return Err(last_error.unwrap_or_else(|| anyhow::anyhow!("No response"))),
};
// Handle response variants
let events = match response {
WorkerKvQueryResponse::Events(events) => {
tracing::debug!(
"Got {count} buffered events from worker {worker_id} dp_rank {dp_rank}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TreeDump(events) => {
tracing::info!(
"Got tree dump from worker {worker_id} dp_rank {dp_rank} \
(range too old or unspecified), count: {count}",
count = events.len()
);
events
}
WorkerKvQueryResponse::TooNew {
requested_start,
requested_end,
newest_available,
} => {
tracing::warn!(
"Requested range [{requested_start:?}, {requested_end:?}] is newer than \
available (newest: {newest_available}) for worker {worker_id} dp_rank {dp_rank}"
);
return Ok(0);
}
WorkerKvQueryResponse::InvalidRange { start_id, end_id } => {
anyhow::bail!(
"Invalid range for worker {worker_id} dp_rank {dp_rank}: \
end_id ({end_id}) < start_id ({start_id})"
);
}
WorkerKvQueryResponse::Error(msg) => {
anyhow::bail!("Worker {worker_id} dp_rank {dp_rank} query error: {msg}");
}
};
// Send recovered events to the indexer
let count = events.len();
if count == 0 {
tracing::debug!("No events to recover from worker {worker_id} dp_rank {dp_rank}");
return Ok(0);
}
tracing::info!("Recovered {count} events from worker {worker_id} dp_rank {dp_rank}");
for event in events {
if let Err(e) = event_tx.send(event).await {
tracing::error!(
"Failed to send recovered event to indexer for worker {worker_id} dp_rank {dp_rank}: {e}"
);
anyhow::bail!("Failed to send recovered event: {e}");
}
}
Ok(count)
}
} }
// Worker-side endpoint registration for Router -> LocalKvIndexer query service // Worker-side endpoint registration for Router -> LocalKvIndexer query service
pub(crate) async fn start_worker_kv_query_endpoint( pub(crate) async fn start_worker_kv_query_endpoint(
component: Component, component: Component,
worker_id: u64, worker_id: u64,
dp_rank: DpRank,
local_indexer: Arc<LocalKvIndexer>, local_indexer: Arc<LocalKvIndexer>,
) { ) {
let engine = Arc::new(WorkerKvQueryEngine { let engine = Arc::new(WorkerKvQueryEngine {
...@@ -121,26 +316,28 @@ pub(crate) async fn start_worker_kv_query_endpoint( ...@@ -121,26 +316,28 @@ pub(crate) async fn start_worker_kv_query_endpoint(
Ok(ingress) => ingress, Ok(ingress) => ingress,
Err(e) => { Err(e) => {
tracing::error!( tracing::error!(
"Failed to build WorkerKvQuery endpoint handler for worker {worker_id}: {e}" "Failed to build WorkerKvQuery endpoint handler for worker {worker_id} dp_rank {dp_rank}: {e}"
); );
return; return;
} }
}; };
let endpoint_name = worker_kv_indexer_query_endpoint(dp_rank);
tracing::info!( tracing::info!(
"WorkerKvQuery endpoint starting for worker {worker_id} on endpoint '{}'", "WorkerKvQuery endpoint starting for worker {worker_id} dp_rank {dp_rank} on endpoint '{endpoint_name}'"
WORKER_KV_INDEXER_QUERY_ENDPOINT
); );
if let Err(e) = component if let Err(e) = component
.endpoint(WORKER_KV_INDEXER_QUERY_ENDPOINT) .endpoint(&endpoint_name)
.endpoint_builder() .endpoint_builder()
.handler(ingress) .handler(ingress)
.graceful_shutdown(true) .graceful_shutdown(true)
.start() .start()
.await .await
{ {
tracing::error!("WorkerKvQuery endpoint failed for worker {worker_id}: {e}"); tracing::error!(
"WorkerKvQuery endpoint failed for worker {worker_id} dp_rank {dp_rank}: {e}"
);
} }
} }
......
...@@ -91,7 +91,7 @@ impl KvManager { ...@@ -91,7 +91,7 @@ impl KvManager {
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}, enable_local_indexer={enable_local_indexer}" "Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}, enable_local_indexer={enable_local_indexer}"
); );
Arc::new( Arc::new(
KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer) KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer, dp_rank)
.expect("Failed to create KV event publisher"), .expect("Failed to create KV event publisher"),
) )
}); });
......
...@@ -1940,24 +1940,8 @@ def _test_router_decisions( ...@@ -1940,24 +1940,8 @@ def _test_router_decisions(
# Use async to manage the test flow # Use async to manage the test flow
async def test_sync(): async def test_sync():
# Calculate expected number of instances # Workers register one instance per process (not per dp_rank)
# With data parallelism: expected_num_instances = engine_workers.num_workers
# - vLLM/SGLang: each DP rank registers as a separate instance
# - Mockers: all DP ranks share the same worker instance ID (instance_ids returns worker IDs)
if test_dp_rank:
if (
hasattr(engine_workers, "data_parallel_size")
and engine_workers.data_parallel_size is not None
):
# vLLM/SGLang: each DP rank registers as a separate instance
expected_num_instances = (
engine_workers.num_workers * engine_workers.data_parallel_size
)
else:
# Mockers with dp_size or no DP: instance_ids() returns worker IDs
expected_num_instances = engine_workers.num_workers
else:
expected_num_instances = engine_workers.num_workers
# Wait for workers to be ready and get their instance IDs # Wait for workers to be ready and get their instance IDs
worker_ids = await wait_for_workers_ready( worker_ids = await wait_for_workers_ready(
......
...@@ -187,6 +187,8 @@ class MockerProcess: ...@@ -187,6 +187,8 @@ class MockerProcess:
mocker_args = mocker_args or {} mocker_args = mocker_args or {}
# Store dp_size for DP-aware test functions # Store dp_size for DP-aware test functions
self.dp_size = mocker_args.get("dp_size") self.dp_size = mocker_args.get("dp_size")
# Alias for consistency with vLLM/SGLang workers
self.data_parallel_size = self.dp_size
command = _build_mocker_command( command = _build_mocker_command(
endpoint=self.endpoint, endpoint=self.endpoint,
...@@ -586,15 +588,17 @@ def test_indexers_sync( ...@@ -586,15 +588,17 @@ def test_indexers_sync(
nats_process, _etcd_process = runtime_services_dynamic_ports nats_process, _etcd_process = runtime_services_dynamic_ports
# Create mocker args dictionary # Create mocker args dictionary
# Use 2 DP ranks to test per-dp_rank event ID tracking and recovery
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core, "enable_local_indexer": use_nats_core,
"dp_size": 2,
} }
try: try:
# Start mocker instances # Start mocker instances (2 workers x 2 DP ranks = 4 independent event streams)
logger.info(f"Starting {NUM_MOCKERS} mocker instances") logger.info(f"Starting {NUM_MOCKERS} mocker instances with dp_size=2")
mockers = MockerProcess( mockers = MockerProcess(
request, request,
mocker_args=mocker_args, mocker_args=mocker_args,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment