Unverified Commit bb07b2f4 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat(kv-router): ZMQ gap detection + replay for standalone indexer [LLM-126] (#7209)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 1e668608
......@@ -1887,6 +1887,7 @@ dependencies = [
"anyhow",
"async-trait",
"axum 0.8.4",
"bytes",
"clap 4.5.60",
"dashmap 6.1.0",
"derive-getters",
......
......@@ -385,6 +385,15 @@ def parse_args() -> argparse.Namespace:
"Each worker's DP ranks bind on base_port + dp_rank. A KvEventPublisher relay "
"subscribes and forwards events to NATS. (default: None, disabled)",
)
parser.add_argument(
"--zmq-replay-ports",
type=str,
default=None,
help="Comma-separated list of ZMQ ROUTER base ports for KV event replay. "
"One port per worker (must match --num-workers). "
"Each worker's DP ranks bind on base_port + dp_rank. "
"Used alongside --zmq-kv-events-ports for gap recovery. (default: None, disabled)",
)
parser.add_argument(
"--bootstrap-ports",
type=str,
......@@ -479,6 +488,17 @@ def parse_args() -> argparse.Namespace:
f"got {len(args.zmq_kv_events_ports_list)}: {args.zmq_kv_events_ports_list}"
)
# Parse and validate zmq_replay_ports
args.zmq_replay_ports_list = parse_bootstrap_ports(args.zmq_replay_ports)
if args.zmq_replay_ports_list:
if not args.zmq_kv_events_ports_list:
raise ValueError("--zmq-replay-ports requires --zmq-kv-events-ports")
if len(args.zmq_replay_ports_list) != args.num_workers:
raise ValueError(
f"--zmq-replay-ports must have exactly --num-workers ({args.num_workers}) ports, "
f"got {len(args.zmq_replay_ports_list)}: {args.zmq_replay_ports_list}"
)
# Set endpoint default based on worker type if not explicitly provided
if args.endpoint is None:
if args.is_prefill_worker:
......
......@@ -218,7 +218,9 @@ async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path)
base_engine_args = json.load(f)
needs_per_worker_args = bool(
args.bootstrap_ports_list or args.zmq_kv_events_ports_list
args.bootstrap_ports_list
or args.zmq_kv_events_ports_list
or args.zmq_replay_ports_list
)
for worker_id in range(args.num_workers):
......@@ -242,6 +244,8 @@ async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path)
worker_args["zmq_kv_events_port"] = args.zmq_kv_events_ports_list[
worker_id
]
if args.zmq_replay_ports_list:
worker_args["zmq_replay_port"] = args.zmq_replay_ports_list[worker_id]
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp:
......
......@@ -135,6 +135,7 @@ curl -X POST http://localhost:8090/register \
| `block_size` | yes | — | KV cache block size (must match the engine) |
| `tenant_id` | no | `"default"` | Tenant identifier for isolation |
| `dp_rank` | no | `0` | Data parallel rank |
| `replay_endpoint` | no | — | ZMQ ROUTER address for gap replay (e.g. `tcp://host:5560`) |
### `POST /unregister` — Deregister an instance
......@@ -270,6 +271,24 @@ Returns:
["http://peer:8091"]
```
## DP Rank Handling
When a worker registers with the standalone KV indexer (`/register`), it provides an `instance_id`, a ZMQ `endpoint`, and an optional `dp_rank` (defaults to 0). The service spawns one ZMQ listener per registration.
Each incoming `KvEventBatch` may carry an optional `data_parallel_rank` field. If present, it **overrides** the statically-registered `dp_rank` for that batch. This allows a single ZMQ port to multiplex events from multiple DP ranks.
**Caveat**: the registry only tracks dp_ranks from explicit `/register` calls. If an engine dynamically emits batches with a dp_rank that was never registered, the indexer will store those blocks correctly (under the dynamic `WorkerWithDpRank` key), but per-dp_rank deregistration (`/unregister` with `dp_rank`) will not find them. Full-instance deregistration (`/unregister` without `dp_rank`) still cleans up all dp_ranks for a given `worker_id` in the tree via `remove_worker`.
## Gap Detection and Replay
ZMQ PUB/SUB is lossy — messages can be dropped under backpressure or brief disconnects. The indexer detects gaps by tracking the sequence number of each batch: if `seq > last_seq + 1`, a gap is detected.
When a `replay_endpoint` is provided during `/register`, the indexer connects a DEALER socket to the engine's ROUTER socket and requests the missing batches by sequence number. The engine streams back buffered `(seq, payload)` pairs from its ring buffer until an empty-payload sentinel.
If no `replay_endpoint` is configured, gaps are logged as warnings but not recovered.
The sequence counter (`last_seq`) persists across unregister/register cycles, so re-registering a worker after a gap will trigger replay on the first batch received by the new listener.
## Limitations
- **ZMQ only**: Workers must publish KV events via ZMQ PUB sockets. The standalone indexer does not subscribe to NATS event streams.
......
......@@ -14,7 +14,8 @@ repository.workspace = true
default = []
metrics = []
bench = ["dep:clap", "dep:indicatif", "dep:serde_json", "dep:plotters"]
indexer-bin = ["metrics", "dep:axum", "dep:clap", "dep:zeromq", "dep:tracing-subscriber", "dep:serde_json", "dep:reqwest"]
indexer-bin = ["metrics", "dep:axum", "dep:bytes", "dep:clap", "dep:zeromq", "dep:tracing-subscriber", "dep:serde_json", "dep:reqwest"]
test-endpoints = ["indexer-bin"]
[dependencies]
# repo
......@@ -52,6 +53,7 @@ rustc-hash = "2.1.1"
# indexer-bin (optional)
axum = { workspace = true, optional = true }
bytes = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }
zeromq = { version = "0.4.1", optional = true }
tracing-subscriber = { workspace = true, optional = true }
......
......@@ -2,13 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
use bytes::Bytes;
use rmp_serde as rmps;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket};
use zeromq::{Socket, SocketRecv, SocketSend, SubSocket};
use dynamo_kv_router::protocols::{RouterEvent, WorkerId};
use dynamo_kv_router::zmq_wire::{KvEventBatch, convert_event};
......@@ -27,27 +28,113 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
)
}
// TODO: Gap detection for missed ZMQ messages
//
// ZMQ PUB/SUB is lossy — if the subscriber is slow or disconnects briefly,
// messages can be dropped. The `zeromq` 0.4 crate uses bounded internal
// channels between the PUB and SUB sockets (via `try_send` with a noop
// waker), so messages are silently dropped when the write buffer is full
// (per ZMQ spec RFC 29).
//
// For P2P recovery, the ready signal delays `recv()` only briefly (the
// duration of the HTTP dump fetch), which is well within the crate's
// internal channel capacity. For longer delays or high-throughput scenarios,
// messages could be lost.
//
// Easy win: hook up the vLLM replay endpoint — workers already expose
// `LocalKvIndexer` with event buffering and range queries (see
// `lib/llm/src/kv_router/worker_query.rs`), just need to query it from
// the standalone indexer on gap detection.
//
// Alternative future approach: switch to an explicit `mpsc` channel as the
// buffer (unbounded, no drops) instead of relying on ZMQ's internal buffer.
/// Sentinel value for `watermark`: indicates no batch has been processed yet.
const WATERMARK_UNSET: u64 = u64::MAX;
/// Replay missed batches from the engine's ROUTER socket.
///
/// Uses a DEALER socket (no send/recv lockstep) to send one request and
/// receive multiple response frames. Each response is `[empty, seq, payload]`;
/// an empty payload signals end of replay.
#[expect(clippy::too_many_arguments)]
async fn replay_gap(
replay_socket: &mut zeromq::DealerSocket,
start_seq: u64,
end_seq: u64,
worker_id: WorkerId,
dp_rank: u32,
block_size: u32,
indexer: &Indexer,
warning_count: &Arc<AtomicU32>,
watermark: &Arc<AtomicU64>,
) -> u64 {
tracing::info!(
worker_id,
dp_rank,
start_seq,
end_seq,
"Requesting replay from engine"
);
// DEALER must manually prepend the empty delimiter that REQ adds automatically.
let req_frames = vec![Bytes::new(), Bytes::from(start_seq.to_be_bytes().to_vec())];
let Ok(req_msg) = zeromq::ZmqMessage::try_from(req_frames) else {
tracing::error!(worker_id, dp_rank, "Failed to build replay request");
return 0;
};
if let Err(e) = replay_socket.send(req_msg).await {
tracing::error!(worker_id, dp_rank, error = %e, "Failed to send replay request");
return 0;
}
let mut replayed = 0u64;
loop {
let Ok(msg) = replay_socket.recv().await else {
tracing::error!(worker_id, dp_rank, "Replay recv error");
break;
};
// ROUTER sends [identity, empty, seq, payload]; DEALER strips identity,
// so we receive [empty, seq, payload].
if msg.len() < 3 {
tracing::warn!(
worker_id,
dp_rank,
"Unexpected replay frame count: {}",
msg.len()
);
break;
}
let payload = msg.get(2).unwrap();
if payload.is_empty() {
break;
}
let seq_bytes = msg.get(1).unwrap();
if seq_bytes.len() != 8 {
tracing::warn!(
worker_id,
dp_rank,
"Invalid replay seq length: {}",
seq_bytes.len()
);
break;
}
let seq = u64::from_be_bytes(seq_bytes[..8].try_into().unwrap());
let Ok(batch) = rmps::from_slice::<KvEventBatch>(payload) else {
tracing::warn!(worker_id, dp_rank, seq, "Failed to decode replayed batch");
continue;
};
let effective_dp_rank = batch
.data_parallel_rank
.map_or(dp_rank, |r| r.cast_unsigned());
for raw_event in batch.events {
let kv_event =
convert_event(raw_event, seq, block_size, effective_dp_rank, warning_count);
let router_event = RouterEvent::new(worker_id, kv_event);
indexer.apply_event(router_event).await;
}
watermark.store(seq, Ordering::Release);
replayed += 1;
}
tracing::info!(worker_id, dp_rank, replayed, "Replay complete");
replayed
}
// TODO: assumes one dp_rank per ZMQ socket. Seq counter is per-socket so gap
// detection works regardless, but replay semantics may differ if a single
// socket multiplexes dp_ranks.
/// Connect the ZMQ SUB socket, then spawn a background task that waits for
/// the ready signal before entering the recv loop.
///
/// Returns once the SUB socket is connected (subscription handshake begins
/// immediately in the background). The ready gate and recv loop run in a
/// spawned task so `register()` is never blocked waiting for `signal_ready()`.
#[expect(clippy::too_many_arguments)]
pub async fn run_zmq_listener(
worker_id: WorkerId,
dp_rank: u32,
......@@ -55,7 +142,9 @@ pub async fn run_zmq_listener(
block_size: u32,
indexer: Indexer,
cancel: CancellationToken,
mut ready: watch::Receiver<bool>,
ready: watch::Receiver<bool>,
replay_endpoint: Option<String>,
watermark: Arc<AtomicU64>,
) {
tracing::info!(worker_id, dp_rank, zmq_address, "ZMQ listener starting");
......@@ -71,6 +160,35 @@ pub async fn run_zmq_listener(
return;
}
// Spawn the ready-wait + recv loop so the caller returns immediately.
// The ZMQ subscription handshake proceeds in the background while P2P
// recovery runs; once signal_ready() fires the recv loop starts draining
// any buffered messages.
tokio::spawn(zmq_wait_ready_then_recv(
worker_id,
dp_rank,
block_size,
indexer,
cancel,
ready,
socket,
replay_endpoint,
watermark,
));
}
#[expect(clippy::too_many_arguments)]
async fn zmq_wait_ready_then_recv(
worker_id: WorkerId,
dp_rank: u32,
block_size: u32,
indexer: Indexer,
cancel: CancellationToken,
mut ready: watch::Receiver<bool>,
socket: SubSocket,
replay_endpoint: Option<String>,
watermark: Arc<AtomicU64>,
) {
// Wait for the ready signal before entering the recv loop.
// During P2P recovery, this delay lets the recovery code fetch the dump
// from a peer while ZMQ subscription handshakes complete in the background.
......@@ -90,7 +208,48 @@ pub async fn run_zmq_listener(
tracing::info!(worker_id, dp_rank, "ZMQ listener ready, starting recv loop");
let mut next_event_id = 0u64;
// Connect DEALER socket once if replay_endpoint is configured.
// DEALER (not REQ) because we send one request and receive multiple responses.
let mut replay_socket = None;
if let Some(ref ep) = replay_endpoint {
let mut sock = zeromq::DealerSocket::new();
if let Err(e) = sock.connect(ep).await {
tracing::error!(worker_id, dp_rank, error = %e, "Failed to connect replay socket to {ep}");
} else {
tracing::info!(
worker_id,
dp_rank,
replay_endpoint = ep,
"Replay socket connected"
);
replay_socket = Some(sock);
}
}
zmq_recv_loop(
worker_id,
dp_rank,
block_size,
indexer,
cancel,
socket,
replay_socket,
watermark,
)
.await;
}
#[expect(clippy::too_many_arguments)]
async fn zmq_recv_loop(
worker_id: WorkerId,
dp_rank: u32,
block_size: u32,
indexer: Indexer,
cancel: CancellationToken,
mut socket: SubSocket,
mut replay_socket: Option<zeromq::DealerSocket>,
watermark: Arc<AtomicU64>,
) {
let warning_count = Arc::new(AtomicU32::new(0));
let mut consecutive_errors = 0u32;
#[expect(unused_assignments)]
......@@ -147,6 +306,41 @@ pub async fn run_zmq_listener(
continue;
}
let seq = u64::from_be_bytes(seq_bytes[..8].try_into().unwrap());
// Gap detection
let prev = watermark.load(Ordering::Acquire);
if prev != WATERMARK_UNSET && seq > prev + 1 {
let gap_start = prev + 1;
tracing::warn!(
worker_id, dp_rank,
expected = gap_start, got = seq,
"Gap detected: expected seq {gap_start}, got {seq}"
);
match replay_socket.as_mut() {
Some(sock) => {
replay_gap(
sock, gap_start, seq, worker_id, dp_rank,
block_size, &indexer, &warning_count, &watermark,
).await;
}
None => tracing::warn!(
worker_id, dp_rank,
gap_size = seq - gap_start,
"No replay endpoint configured, {gap_size} batches lost",
gap_size = seq - gap_start,
),
}
}
// After replay, watermark may have advanced past the current
// batch — skip to avoid double-apply. Exclude the sentinel
// (WATERMARK_UNSET) so the very first message is not skipped.
let current_wm = watermark.load(Ordering::Acquire);
if current_wm != WATERMARK_UNSET && current_wm >= seq {
continue;
}
let payload = msg.get(2).unwrap();
let batch_result = rmps::from_slice::<KvEventBatch>(payload);
let Ok(batch) = batch_result else {
......@@ -155,14 +349,15 @@ pub async fn run_zmq_listener(
};
let effective_dp_rank = batch.data_parallel_rank.map_or(dp_rank, |r| r.cast_unsigned());
// Use the engine's ZMQ sequence number as event_id so downstream
// consumers can detect gaps and request replay.
for raw_event in batch.events {
let event_id = next_event_id;
next_event_id += 1;
let kv_event = convert_event(raw_event, event_id, block_size, effective_dp_rank, &warning_count);
let kv_event = convert_event(raw_event, seq, block_size, effective_dp_rank, &warning_count);
let router_event = RouterEvent::new(worker_id, kv_event);
indexer.apply_event(router_event).await;
messages_processed += 1;
}
watermark.store(seq, Ordering::Release);
}
}
}
......
......@@ -97,23 +97,26 @@ async fn main() -> anyhow::Result<()> {
let registry = WorkerRegistry::new(cli.threads);
// Register initial workers — connects ZMQ sockets but listeners wait
// for the ready signal. This ensures ZMQ subscription handshakes begin
// before P2P recovery fetches the dump from a peer.
// Register initial workers — connects ZMQ SUB sockets (subscription
// handshakes begin immediately) and spawns listener tasks that wait for
// the ready signal. register() returns as soon as the socket is connected.
if let Some(ref workers_str) = cli.workers {
let block_size = cli.block_size.ok_or_else(|| {
anyhow::anyhow!("--block-size is required when --workers is specified")
})?;
for (instance_id, dp_rank, endpoint) in parse_workers(workers_str) {
tracing::info!(instance_id, dp_rank, endpoint, "Registering initial worker");
registry.register(
instance_id,
endpoint,
dp_rank,
cli.model_name.clone(),
cli.tenant_id.clone(),
block_size,
)?;
registry
.register(
instance_id,
endpoint,
dp_rank,
cli.model_name.clone(),
cli.tenant_id.clone(),
block_size,
None,
)
.await?;
}
}
......
......@@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use anyhow::{Result, bail};
use dashmap::DashMap;
......@@ -27,13 +29,27 @@ pub struct IndexerEntry {
pub struct WorkerEntry {
pub endpoints: HashMap<u32, String>,
pub replay_endpoints: HashMap<u32, String>,
cancels: HashMap<u32, CancellationToken>,
}
/// State needed to restart a paused ZMQ listener.
struct ListenerState {
endpoint: String,
replay_endpoint: Option<String>,
block_size: u32,
indexer: Indexer,
watermark: Arc<AtomicU64>,
}
pub struct WorkerRegistry {
workers: DashMap<WorkerId, WorkerEntry>,
indexers: DashMap<IndexerKey, IndexerEntry>,
peers: DashMap<String, ()>,
/// Persists across unregister/register cycles so gap detection works after re-registration.
watermarks: DashMap<(WorkerId, u32), Arc<AtomicU64>>,
/// Saved listener state for pause/resume. Populated on register, kept on pause.
listener_states: DashMap<(WorkerId, u32), ListenerState>,
num_threads: usize,
ready_tx: watch::Sender<bool>,
ready_rx: watch::Receiver<bool>,
......@@ -46,6 +62,8 @@ impl WorkerRegistry {
workers: DashMap::new(),
indexers: DashMap::new(),
peers: DashMap::new(),
watermarks: DashMap::new(),
listener_states: DashMap::new(),
num_threads,
ready_tx,
ready_rx,
......@@ -72,7 +90,8 @@ impl WorkerRegistry {
self.peers.iter().map(|entry| entry.key().clone()).collect()
}
pub fn register(
#[expect(clippy::too_many_arguments)]
pub async fn register(
&self,
instance_id: WorkerId,
endpoint: String,
......@@ -80,6 +99,7 @@ impl WorkerRegistry {
model_name: String,
tenant_id: String,
block_size: u32,
replay_endpoint: Option<String>,
) -> Result<()> {
let key = IndexerKey {
model_name,
......@@ -115,27 +135,68 @@ impl WorkerRegistry {
let bs = indexer_entry.block_size;
drop(indexer_entry);
let mut entry = self
.workers
.entry(instance_id)
.or_insert_with(|| WorkerEntry {
endpoints: HashMap::new(),
cancels: HashMap::new(),
});
if entry.endpoints.contains_key(&dp_rank) {
bail!("instance {instance_id} dp_rank {dp_rank} already registered");
// Check for duplicate and insert replay endpoint while holding the lock briefly.
{
let mut entry = self
.workers
.entry(instance_id)
.or_insert_with(|| WorkerEntry {
endpoints: HashMap::new(),
replay_endpoints: HashMap::new(),
cancels: HashMap::new(),
});
if entry.endpoints.contains_key(&dp_rank) {
bail!("instance {instance_id} dp_rank {dp_rank} already registered");
}
if let Some(rep) = &replay_endpoint {
entry.replay_endpoints.insert(dp_rank, rep.clone());
}
}
// Reuse watermark if it survived a previous unregister (preserves gap detection).
let watermark = self
.watermarks
.entry((instance_id, dp_rank))
.or_insert_with(|| Arc::new(AtomicU64::new(u64::MAX)))
.clone();
self.listener_states.insert(
(instance_id, dp_rank),
ListenerState {
endpoint: endpoint.clone(),
replay_endpoint: replay_endpoint.clone(),
block_size: bs,
indexer: indexer.clone(),
watermark: watermark.clone(),
},
);
let cancel = CancellationToken::new();
let child_cancel = cancel.child_token();
let addr = endpoint.clone();
let ready = self.ready_rx();
tokio::spawn(async move {
run_zmq_listener(instance_id, dp_rank, addr, bs, indexer, child_cancel, ready).await;
});
// Connect the ZMQ socket and spawn the listener task (non-blocking).
run_zmq_listener(
instance_id,
dp_rank,
addr,
bs,
indexer,
child_cancel,
ready,
replay_endpoint,
watermark,
)
.await;
// Re-acquire to store the endpoint and cancel token.
let mut entry = self
.workers
.get_mut(&instance_id)
.expect("worker entry disappeared during listener setup");
entry.endpoints.insert(dp_rank, endpoint);
entry.cancels.insert(dp_rank, cancel);
Ok(())
......@@ -251,6 +312,71 @@ impl WorkerRegistry {
Ok(())
}
#[expect(dead_code)]
pub fn pause_listener(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
let mut entry = self
.workers
.get_mut(&instance_id)
.ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;
let cancel = entry.cancels.remove(&dp_rank).ok_or_else(|| {
anyhow::anyhow!("instance {instance_id} dp_rank {dp_rank} not active")
})?;
cancel.cancel();
tracing::info!(instance_id, dp_rank, "Paused ZMQ listener");
Ok(())
}
#[expect(dead_code)]
pub async fn resume_listener(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
{
let entry = self
.workers
.get(&instance_id)
.ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;
if entry.cancels.contains_key(&dp_rank) {
bail!("instance {instance_id} dp_rank {dp_rank} already running");
}
}
let state = self
.listener_states
.get(&(instance_id, dp_rank))
.ok_or_else(|| anyhow::anyhow!("no saved state for {instance_id} dp_rank {dp_rank}"))?;
let cancel = CancellationToken::new();
let child_cancel = cancel.child_token();
let ready = self.ready_rx();
let addr = state.endpoint.clone();
let bs = state.block_size;
let indexer = state.indexer.clone();
let replay_ep = state.replay_endpoint.clone();
let watermark = state.watermark.clone();
drop(state);
run_zmq_listener(
instance_id,
dp_rank,
addr,
bs,
indexer,
child_cancel,
ready,
replay_ep,
watermark,
)
.await;
let mut entry = self
.workers
.get_mut(&instance_id)
.expect("worker entry disappeared during listener resume");
entry.cancels.insert(dp_rank, cancel);
Ok(())
}
pub fn list(&self) -> Vec<(WorkerId, HashMap<u32, String>)> {
self.workers
.iter()
......
......@@ -33,6 +33,8 @@ pub struct RegisterRequest {
pub block_size: u32,
#[serde(default)]
pub dp_rank: Option<u32>,
#[serde(default)]
pub replay_endpoint: Option<String>,
}
#[derive(Deserialize)]
......@@ -85,14 +87,19 @@ async fn register(
State(state): State<Arc<AppState>>,
Json(req): Json<RegisterRequest>,
) -> impl IntoResponse {
match state.registry.register(
req.instance_id,
req.endpoint,
req.dp_rank.unwrap_or(0),
req.model_name,
req.tenant_id,
req.block_size,
) {
match state
.registry
.register(
req.instance_id,
req.endpoint,
req.dp_rank.unwrap_or(0),
req.model_name,
req.tenant_id,
req.block_size,
req.replay_endpoint,
)
.await
{
Ok(()) => (
StatusCode::CREATED,
Json(serde_json::json!({"status": "ok"})),
......@@ -248,6 +255,49 @@ async fn query_by_hash(
}
}
#[cfg(feature = "test-endpoints")]
#[derive(Deserialize)]
struct ListenerControlRequest {
instance_id: WorkerId,
#[serde(default)]
dp_rank: Option<u32>,
}
#[cfg(feature = "test-endpoints")]
async fn test_pause_listener(
State(state): State<Arc<AppState>>,
Json(req): Json<ListenerControlRequest>,
) -> impl IntoResponse {
match state
.registry
.pause_listener(req.instance_id, req.dp_rank.unwrap_or(0))
{
Ok(()) => (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))),
Err(e) => (
StatusCode::NOT_FOUND,
Json(serde_json::json!({"error": e.to_string()})),
),
}
}
#[cfg(feature = "test-endpoints")]
async fn test_resume_listener(
State(state): State<Arc<AppState>>,
Json(req): Json<ListenerControlRequest>,
) -> impl IntoResponse {
match state
.registry
.resume_listener(req.instance_id, req.dp_rank.unwrap_or(0))
.await
{
Ok(()) => (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))),
Err(e) => (
StatusCode::CONFLICT,
Json(serde_json::json!({"error": e.to_string()})),
),
}
}
#[derive(Deserialize)]
struct PeerRequest {
url: String,
......@@ -319,7 +369,7 @@ async fn dump_events(State(state): State<Arc<AppState>>) -> impl IntoResponse {
}
pub fn create_router(state: Arc<AppState>) -> Router {
Router::new()
let router = Router::new()
.route("/register", post(register))
.route("/unregister", post(unregister))
.route("/workers", get(list_workers))
......@@ -328,6 +378,12 @@ pub fn create_router(state: Arc<AppState>) -> Router {
.route("/dump", get(dump_events))
.route("/register_peer", post(register_peer))
.route("/deregister_peer", post(deregister_peer))
.route("/peers", get(list_peers))
.with_state(state)
.route("/peers", get(list_peers));
#[cfg(feature = "test-endpoints")]
let router = router
.route("/test/pause_listener", post(test_pause_listener))
.route("/test/resume_listener", post(test_resume_listener));
router.with_state(state)
}
......@@ -6,6 +6,7 @@
//! The core mocker logic lives in the `dynamo-mocker` crate.
//! This module provides the runtime-dependent engine wrapper.
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
......@@ -38,7 +39,7 @@ use tokio::sync::{Notify, OnceCell, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use zeromq::{Socket, SocketSend};
use zeromq::{Socket, SocketRecv, SocketSend};
pub const MOCKER_COMPONENT: &str = "mocker";
......@@ -84,8 +85,16 @@ struct ZmqKvEventSink {
tx: mpsc::UnboundedSender<ZmqKvEventMsg>,
}
/// Maximum number of entries in the replay ring buffer.
const REPLAY_BUFFER_CAPACITY: usize = 10_000;
impl ZmqKvEventSink {
async fn new(port: u16, dp_rank: u32, block_size: u32) -> Result<Self> {
async fn new(
port: u16,
replay_port: Option<u16>,
dp_rank: u32,
block_size: u32,
) -> Result<Self> {
let (tx, mut rx) = mpsc::unbounded_channel::<ZmqKvEventMsg>();
// Bind the PUB socket before returning so that any SUB connect()
......@@ -98,44 +107,139 @@ impl ZmqKvEventSink {
.map_err(|e| anyhow::anyhow!("ZMQ PUB bind to {endpoint} failed: {e}"))?;
tracing::info!("ZmqKvEventSink bound to {endpoint} for dp_rank {dp_rank}");
// Optionally bind ROUTER socket for replay
let mut router_socket = if let Some(rp) = replay_port {
let mut sock = zeromq::RouterSocket::new();
let replay_ep = format!("tcp://0.0.0.0:{rp}");
sock.bind(&replay_ep)
.await
.map_err(|e| anyhow::anyhow!("ZMQ ROUTER bind to {replay_ep} failed: {e}"))?;
tracing::info!(
"ZmqKvEventSink replay ROUTER bound to {replay_ep} for dp_rank {dp_rank}"
);
Some(sock)
} else {
None
};
tokio::spawn(async move {
let mut seq_num: u64 = 0;
// Store Bytes (ref-counted) to avoid memcpy on both PUB and buffer paths.
let mut ring_buffer: VecDeque<(u64, Bytes)> = VecDeque::new();
while let Some(msg) = rx.recv().await {
let events =
convert_to_zmq_events(&msg.event, msg.block_token_ids.as_deref(), block_size);
if events.is_empty() {
continue;
}
loop {
tokio::select! {
biased;
// Replay requests are rare but latency-sensitive — poll first
// to prevent starvation under sustained KV event load.
replay_result = async {
match router_socket.as_mut() {
Some(sock) => sock.recv().await,
None => std::future::pending().await,
}
} => {
let Ok(req_msg) = replay_result else {
tracing::warn!("Replay ROUTER recv error");
continue;
};
if req_msg.len() < 3 {
tracing::warn!("Unexpected replay request frame count: {}", req_msg.len());
continue;
}
let identity: Bytes = Bytes::copy_from_slice(req_msg.get(0).unwrap());
let start_seq_bytes = req_msg.get(2).unwrap();
if start_seq_bytes.len() != 8 {
tracing::warn!("Invalid replay start_seq length: {}", start_seq_bytes.len());
continue;
}
let start_seq = u64::from_be_bytes(start_seq_bytes[..8].try_into().unwrap());
tracing::debug!(dp_rank, start_seq, buffer_len = ring_buffer.len(), "Replay request received");
// Compute start index directly — sequences are monotonic.
let start_idx = ring_buffer.front()
.map(|(first_seq, _)| start_seq.saturating_sub(*first_seq) as usize)
.unwrap_or(0)
.min(ring_buffer.len());
let sock = router_socket.as_mut().unwrap();
for (seq, payload) in ring_buffer.iter().skip(start_idx) {
let frames = vec![
identity.clone(),
Bytes::new(),
Bytes::from(seq.to_be_bytes().to_vec()),
payload.clone(), // ref-count bump
];
let reply = zeromq::ZmqMessage::try_from(frames)
.expect("replay frame");
if let Err(e) = sock.send(reply).await {
tracing::warn!("Replay send error: {e}");
break;
}
}
// Sentinel: empty payload signals end of replay
let sentinel_frames = vec![
identity,
Bytes::new(),
Bytes::from((-1i64).to_be_bytes().to_vec()),
Bytes::new(),
];
let sentinel = zeromq::ZmqMessage::try_from(sentinel_frames)
.expect("sentinel frame");
let _ = sock.send(sentinel).await;
}
msg_opt = rx.recv() => {
let Some(msg) = msg_opt else { break };
let events = convert_to_zmq_events(
&msg.event,
msg.block_token_ids.as_deref(),
block_size,
);
if events.is_empty() {
continue;
}
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64();
let batch: (f64, Vec<ZmqRawKvEvent>, Option<i32>) =
(timestamp, events, Some(dp_rank as i32));
let payload: Bytes = match rmp_serde::to_vec(&batch) {
Ok(p) => p.into(),
Err(e) => {
tracing::warn!("Failed to serialize ZMQ KV event: {e}");
continue;
}
};
let frames = vec![
Bytes::from(""),
Bytes::from(seq_num.to_be_bytes().to_vec()),
payload.clone(), // ref-count bump, not memcpy
];
let zmq_msg = zeromq::ZmqMessage::try_from(frames)
.expect("Failed to create ZMQ multipart message");
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64();
let batch: (f64, Vec<ZmqRawKvEvent>, Option<i32>) =
(timestamp, events, Some(dp_rank as i32));
let payload = match rmp_serde::to_vec(&batch) {
Ok(p) => p,
Err(e) => {
tracing::warn!("Failed to serialize ZMQ KV event: {e}");
continue;
if let Err(e) = pub_socket.send(zmq_msg).await {
tracing::warn!("Failed to send ZMQ KV event: {e}");
}
if router_socket.is_some() {
if ring_buffer.len() >= REPLAY_BUFFER_CAPACITY {
ring_buffer.pop_front();
}
ring_buffer.push_back((seq_num, payload));
}
seq_num += 1;
}
};
let frames = vec![
Bytes::from(""),
Bytes::from(seq_num.to_be_bytes().to_vec()),
Bytes::from(payload),
];
let zmq_msg = zeromq::ZmqMessage::try_from(frames)
.expect("Failed to create ZMQ multipart message");
if let Err(e) = pub_socket.send(zmq_msg).await {
tracing::warn!("Failed to send ZMQ KV event: {e}");
}
seq_num += 1;
}
});
......@@ -312,7 +416,15 @@ impl MockVllmEngine {
) = match component {
Some(comp) if args.zmq_kv_events_port.is_some() => {
let zmq_port = args.zmq_kv_events_port.unwrap() + dp_rank as u16;
match ZmqKvEventSink::new(zmq_port, dp_rank, args.block_size as u32).await {
let replay_port = args.zmq_replay_port.map(|p| p + dp_rank as u16);
match ZmqKvEventSink::new(
zmq_port,
replay_port,
dp_rank,
args.block_size as u32,
)
.await
{
Ok(sink) => {
let source_config = Some(KvEventSourceConfig::Zmq {
endpoint: format!("tcp://127.0.0.1:{zmq_port}"),
......
......@@ -212,6 +212,13 @@ pub struct MockEngineArgs {
#[builder(default = "None")]
pub zmq_kv_events_port: Option<u16>,
/// ZMQ ROUTER port for replay of buffered KV event batches.
/// When set alongside `zmq_kv_events_port`, the mocker binds a ROUTER socket
/// that streams back buffered batches by sequence number on request.
/// Port is offset by dp_rank (replay_port + dp_rank).
#[builder(default = "None")]
pub zmq_replay_port: Option<u16>,
/// Preemption mode for decode eviction under memory pressure.
/// Lifo (default) evicts the newest request; Fifo evicts the oldest.
#[builder(default)]
......@@ -271,6 +278,7 @@ impl MockEngineArgs {
"kv_transfer_bandwidth",
"reasoning",
"zmq_kv_events_port",
"zmq_replay_port",
"preemption_mode",
]
.iter()
......@@ -383,6 +391,12 @@ impl MockEngineArgs {
builder = builder.zmq_kv_events_port(Some(port as u16));
}
if let Some(value) = extra_args.get("zmq_replay_port")
&& let Some(port) = value.as_u64()
{
builder = builder.zmq_replay_port(Some(port as u16));
}
if let Some(value) = extra_args.get("preemption_mode")
&& let Some(mode_str) = value.as_str()
{
......
......@@ -701,6 +701,61 @@ def _test_router_overload_503(
logger.info("Successfully verified 503 response when all workers are busy")
async def _zmq_replay_cycle(
phase: int,
router,
router_name: str,
endpoint,
indexer_url: str,
engine_workers,
send_requests_to_router,
):
"""Pause indexer listeners → send gap requests → resume → send to trigger replay."""
await asyncio.sleep(1)
worker_ids = list(engine_workers.worker_id_to_zmq_ports.keys())
dp_size = getattr(engine_workers, "dp_size", None) or 1
logger.info(f"=== ZMQ REPLAY TEST: Phase {phase} ({router_name}) ===")
async with aiohttp.ClientSession() as session:
for wid in worker_ids:
for dp_rank in range(dp_size):
async with session.post(
f"{indexer_url}/test/pause_listener",
json={"instance_id": wid, "dp_rank": dp_rank},
) as resp:
assert (
resp.status == 200
), f"Pause {wid}:{dp_rank} failed: {await resp.text()}"
logger.info("Sending 10 requests while indexer listeners are paused")
successful_gap = await send_requests_to_router(
router, 10, f"{router_name} (indexer paused)", endpoint
)
assert (
successful_gap == 10
), f"Expected 10 requests while paused, got {successful_gap}"
async with aiohttp.ClientSession() as session:
for wid in worker_ids:
for dp_rank in range(dp_size):
async with session.post(
f"{indexer_url}/test/resume_listener",
json={"instance_id": wid, "dp_rank": dp_rank},
) as resp:
assert (
resp.status == 200
), f"Resume {wid}:{dp_rank} failed: {await resp.text()}"
logger.info("Sending 5 requests after resume (triggers gap detection + replay)")
successful_post = await send_requests_to_router(
router, 5, f"{router_name} (post-resume)", endpoint
)
assert (
successful_post == 5
), f"Expected 5 requests post-resume, got {successful_post}"
await asyncio.sleep(2)
def _test_router_indexers_sync(
engine_workers,
block_size: int,
......@@ -714,6 +769,7 @@ def _test_router_indexers_sync(
router_event_threads: int = 4,
standalone_indexer_url: Optional[str] = None,
standalone_indexer_b_url: Optional[str] = None,
test_zmq_replay: bool = False,
):
"""Test that two KV routers have synchronized indexer states after processing requests.
......@@ -854,6 +910,17 @@ def _test_router_indexers_sync(
await asyncio.sleep(5)
if test_zmq_replay and standalone_indexer_url:
await _zmq_replay_cycle(
1,
kv_router1,
"Router 1",
endpoint1,
standalone_indexer_url,
engine_workers,
send_requests_to_router,
)
# Wait for snapshot to be available before creating second router.
# In JetStream mode, the background task may purge acknowledged messages
# from the stream before the snapshot upload completes. Poll the object
......@@ -945,6 +1012,17 @@ def _test_router_indexers_sync(
successful_recovery == 5
), f"Expected 5 successful requests post-recovery, got {successful_recovery}"
if test_zmq_replay and standalone_indexer_url:
await _zmq_replay_cycle(
2,
kv_router2,
"Router 2",
endpoint2,
standalone_indexer_url,
engine_workers,
send_requests_to_router,
)
# Wait for internal synchronization and ZMQ event propagation
logger.info("Waiting for final synchronization")
await asyncio.sleep(2)
......
......@@ -168,6 +168,8 @@ def _build_mocker_command(
command.extend(["--bootstrap-ports", mocker_args["bootstrap_ports"]])
if "zmq_kv_events_ports" in mocker_args:
command.extend(["--zmq-kv-events-ports", mocker_args["zmq_kv_events_ports"]])
if "zmq_replay_ports" in mocker_args:
command.extend(["--zmq-replay-ports", mocker_args["zmq_replay_ports"]])
return command
......@@ -190,6 +192,7 @@ class MockerProcess:
zmq_kv_events: bool = False,
standalone_indexer: bool = False,
model_name: str = "mocker",
zmq_replay: bool = False,
):
namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}"
......@@ -198,6 +201,7 @@ class MockerProcess:
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_mockers
self._zmq_kv_events_ports: list[int] = []
self._zmq_replay_ports: list[int] = []
self._standalone_indexer = standalone_indexer
self._standalone_indexer_port: Optional[int] = None
self._standalone_indexer_b_port: Optional[int] = None
......@@ -233,6 +237,22 @@ class MockerProcess:
f"(bases: {bases}) for {num_mockers} workers"
)
# Allocate ZMQ replay ports (same layout as event ports)
if zmq_replay and zmq_kv_events:
dp_size = mocker_args.get("dp_size", 1)
self._zmq_replay_ports = allocate_ports(
num_mockers * dp_size, BASE_PORT_ZMQ + 1000
)
replay_bases = [
self._zmq_replay_ports[i * dp_size] for i in range(num_mockers)
]
if not standalone_indexer:
mocker_args["zmq_replay_ports"] = ",".join(str(p) for p in replay_bases)
logger.info(
f"Allocated ZMQ replay ports {self._zmq_replay_ports} "
f"(bases: {replay_bases}) for {num_mockers} workers"
)
if standalone_indexer:
# Allocate ports for standalone indexer A and B (P2P recovery peer)
indexer_ports = allocate_ports(2, BASE_PORT)
......@@ -289,7 +309,7 @@ class MockerProcess:
"-p",
"dynamo-kv-router",
"--features",
"indexer-bin",
"indexer-bin,test-endpoints",
"--bin",
"dynamo-kv-indexer",
"--",
......@@ -338,6 +358,9 @@ class MockerProcess:
mocker_args = self._mocker_args_orig.copy()
base_port = self._zmq_kv_events_ports[i * dp_size]
mocker_args["zmq_kv_events_ports"] = str(base_port)
if self._zmq_replay_ports:
replay_base = self._zmq_replay_ports[i * dp_size]
mocker_args["zmq_replay_ports"] = str(replay_base)
command = _build_mocker_command(
endpoint=self.endpoint,
......@@ -398,6 +421,9 @@ class MockerProcess:
"block_size", BLOCK_SIZE
),
}
if self._zmq_replay_ports:
replay_port = self._zmq_replay_ports[i * dp_size + dp_rank]
payload["replay_endpoint"] = f"tcp://127.0.0.1:{replay_port}"
async with session.post(register_url, json=payload) as resp:
if resp.status != 201:
body = await resp.text()
......@@ -444,7 +470,7 @@ class MockerProcess:
"-p",
"dynamo-kv-router",
"--features",
"indexer-bin",
"indexer-bin,test-endpoints",
"--bin",
"dynamo-kv-indexer",
"--",
......@@ -505,6 +531,10 @@ class MockerProcess:
deallocate_ports(self._zmq_kv_events_ports)
logger.info(f"Deallocated ZMQ KV event ports {self._zmq_kv_events_ports}")
self._zmq_kv_events_ports = []
if self._zmq_replay_ports:
deallocate_ports(self._zmq_replay_ports)
logger.info(f"Deallocated ZMQ replay ports {self._zmq_replay_ports}")
self._zmq_replay_ports = []
class DisaggMockerProcess:
......@@ -862,6 +892,7 @@ def test_indexers_sync(
store_backend=store_backend,
request_plane=request_plane,
zmq_kv_events=True,
zmq_replay=True,
standalone_indexer=True,
model_name=MODEL_NAME,
) as mockers:
......@@ -884,6 +915,7 @@ def test_indexers_sync(
durable_kv_events=durable_kv_events,
standalone_indexer_url=mockers.standalone_indexer_url,
standalone_indexer_b_url=mockers.standalone_indexer_b_url,
test_zmq_replay=True,
)
logger.info("Indexers sync test completed successfully")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment