Unverified commit bb07b2f4, authored by Yan Ru Pei, committed by GitHub
Browse files

feat(kv-router): ZMQ gap detection + replay for standalone indexer [LLM-126] (#7209)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 1e668608
...@@ -1887,6 +1887,7 @@ dependencies = [ ...@@ -1887,6 +1887,7 @@ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"axum 0.8.4", "axum 0.8.4",
"bytes",
"clap 4.5.60", "clap 4.5.60",
"dashmap 6.1.0", "dashmap 6.1.0",
"derive-getters", "derive-getters",
......
...@@ -385,6 +385,15 @@ def parse_args() -> argparse.Namespace: ...@@ -385,6 +385,15 @@ def parse_args() -> argparse.Namespace:
"Each worker's DP ranks bind on base_port + dp_rank. A KvEventPublisher relay " "Each worker's DP ranks bind on base_port + dp_rank. A KvEventPublisher relay "
"subscribes and forwards events to NATS. (default: None, disabled)", "subscribes and forwards events to NATS. (default: None, disabled)",
) )
parser.add_argument(
"--zmq-replay-ports",
type=str,
default=None,
help="Comma-separated list of ZMQ ROUTER base ports for KV event replay. "
"One port per worker (must match --num-workers). "
"Each worker's DP ranks bind on base_port + dp_rank. "
"Used alongside --zmq-kv-events-ports for gap recovery. (default: None, disabled)",
)
parser.add_argument( parser.add_argument(
"--bootstrap-ports", "--bootstrap-ports",
type=str, type=str,
...@@ -479,6 +488,17 @@ def parse_args() -> argparse.Namespace: ...@@ -479,6 +488,17 @@ def parse_args() -> argparse.Namespace:
f"got {len(args.zmq_kv_events_ports_list)}: {args.zmq_kv_events_ports_list}" f"got {len(args.zmq_kv_events_ports_list)}: {args.zmq_kv_events_ports_list}"
) )
# Parse and validate zmq_replay_ports
args.zmq_replay_ports_list = parse_bootstrap_ports(args.zmq_replay_ports)
if args.zmq_replay_ports_list:
if not args.zmq_kv_events_ports_list:
raise ValueError("--zmq-replay-ports requires --zmq-kv-events-ports")
if len(args.zmq_replay_ports_list) != args.num_workers:
raise ValueError(
f"--zmq-replay-ports must have exactly --num-workers ({args.num_workers}) ports, "
f"got {len(args.zmq_replay_ports_list)}: {args.zmq_replay_ports_list}"
)
# Set endpoint default based on worker type if not explicitly provided # Set endpoint default based on worker type if not explicitly provided
if args.endpoint is None: if args.endpoint is None:
if args.is_prefill_worker: if args.is_prefill_worker:
......
...@@ -218,7 +218,9 @@ async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path) ...@@ -218,7 +218,9 @@ async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path)
base_engine_args = json.load(f) base_engine_args = json.load(f)
needs_per_worker_args = bool( needs_per_worker_args = bool(
args.bootstrap_ports_list or args.zmq_kv_events_ports_list args.bootstrap_ports_list
or args.zmq_kv_events_ports_list
or args.zmq_replay_ports_list
) )
for worker_id in range(args.num_workers): for worker_id in range(args.num_workers):
...@@ -242,6 +244,8 @@ async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path) ...@@ -242,6 +244,8 @@ async def launch_workers(args: argparse.Namespace, extra_engine_args_path: Path)
worker_args["zmq_kv_events_port"] = args.zmq_kv_events_ports_list[ worker_args["zmq_kv_events_port"] = args.zmq_kv_events_ports_list[
worker_id worker_id
] ]
if args.zmq_replay_ports_list:
worker_args["zmq_replay_port"] = args.zmq_replay_ports_list[worker_id]
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False mode="w", suffix=".json", delete=False
) as tmp: ) as tmp:
......
...@@ -135,6 +135,7 @@ curl -X POST http://localhost:8090/register \ ...@@ -135,6 +135,7 @@ curl -X POST http://localhost:8090/register \
| `block_size` | yes | — | KV cache block size (must match the engine) | | `block_size` | yes | — | KV cache block size (must match the engine) |
| `tenant_id` | no | `"default"` | Tenant identifier for isolation | | `tenant_id` | no | `"default"` | Tenant identifier for isolation |
| `dp_rank` | no | `0` | Data parallel rank | | `dp_rank` | no | `0` | Data parallel rank |
| `replay_endpoint` | no | — | ZMQ ROUTER address for gap replay (e.g. `tcp://host:5560`) |
### `POST /unregister` — Deregister an instance ### `POST /unregister` — Deregister an instance
...@@ -270,6 +271,24 @@ Returns: ...@@ -270,6 +271,24 @@ Returns:
["http://peer:8091"] ["http://peer:8091"]
``` ```
## DP Rank Handling
When a worker registers with the standalone KV indexer (`/register`), it provides an `instance_id`, a ZMQ `endpoint`, and an optional `dp_rank` (defaults to 0). The service spawns one ZMQ listener per registration.
Each incoming `KvEventBatch` may carry an optional `data_parallel_rank` field. If present, it **overrides** the statically-registered `dp_rank` for that batch. This allows a single ZMQ port to multiplex events from multiple DP ranks.
**Caveat**: the registry only tracks dp_ranks from explicit `/register` calls. If an engine dynamically emits batches with a dp_rank that was never registered, the indexer will store those blocks correctly (under the dynamic `WorkerWithDpRank` key), but per-dp_rank deregistration (`/unregister` with `dp_rank`) will not find them. Full-instance deregistration (`/unregister` without `dp_rank`) still cleans up all dp_ranks for a given `worker_id` in the tree via `remove_worker`.
## Gap Detection and Replay
ZMQ PUB/SUB is lossy — messages can be dropped under backpressure or brief disconnects. The indexer detects gaps by tracking the sequence number of each batch: if `seq > last_seq + 1`, a gap is detected.
When a `replay_endpoint` is provided during `/register`, the indexer connects a DEALER socket to the engine's ROUTER socket and, on detecting a gap, sends a single request carrying the first missing sequence number. The engine streams back buffered `(seq, payload)` pairs from its ring buffer, ending with an empty-payload sentinel.
If no `replay_endpoint` is configured, gaps are logged as warnings but not recovered.
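The sequence counter (`last_seq`), tracked per `(instance_id, dp_rank)`, persists across unregister/register cycles. If batches were missed while a worker was unregistered, the first batch received by the new listener after re-registration is detected as a gap and triggers replay (when a `replay_endpoint` is configured).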
## Limitations ## Limitations
- **ZMQ only**: Workers must publish KV events via ZMQ PUB sockets. The standalone indexer does not subscribe to NATS event streams. - **ZMQ only**: Workers must publish KV events via ZMQ PUB sockets. The standalone indexer does not subscribe to NATS event streams.
......
...@@ -14,7 +14,8 @@ repository.workspace = true ...@@ -14,7 +14,8 @@ repository.workspace = true
default = [] default = []
metrics = [] metrics = []
bench = ["dep:clap", "dep:indicatif", "dep:serde_json", "dep:plotters"] bench = ["dep:clap", "dep:indicatif", "dep:serde_json", "dep:plotters"]
indexer-bin = ["metrics", "dep:axum", "dep:clap", "dep:zeromq", "dep:tracing-subscriber", "dep:serde_json", "dep:reqwest"] indexer-bin = ["metrics", "dep:axum", "dep:bytes", "dep:clap", "dep:zeromq", "dep:tracing-subscriber", "dep:serde_json", "dep:reqwest"]
test-endpoints = ["indexer-bin"]
[dependencies] [dependencies]
# repo # repo
...@@ -52,6 +53,7 @@ rustc-hash = "2.1.1" ...@@ -52,6 +53,7 @@ rustc-hash = "2.1.1"
# indexer-bin (optional) # indexer-bin (optional)
axum = { workspace = true, optional = true } axum = { workspace = true, optional = true }
bytes = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true } reqwest = { workspace = true, optional = true }
zeromq = { version = "0.4.1", optional = true } zeromq = { version = "0.4.1", optional = true }
tracing-subscriber = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true }
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::AtomicU32; use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration; use std::time::Duration;
use bytes::Bytes;
use rmp_serde as rmps; use rmp_serde as rmps;
use tokio::sync::watch; use tokio::sync::watch;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket}; use zeromq::{Socket, SocketRecv, SocketSend, SubSocket};
use dynamo_kv_router::protocols::{RouterEvent, WorkerId}; use dynamo_kv_router::protocols::{RouterEvent, WorkerId};
use dynamo_kv_router::zmq_wire::{KvEventBatch, convert_event}; use dynamo_kv_router::zmq_wire::{KvEventBatch, convert_event};
...@@ -27,27 +28,113 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 { ...@@ -27,27 +28,113 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
) )
} }
// TODO: Gap detection for missed ZMQ messages /// Sentinel value for `watermark`: indicates no batch has been processed yet.
// const WATERMARK_UNSET: u64 = u64::MAX;
// ZMQ PUB/SUB is lossy — if the subscriber is slow or disconnects briefly,
// messages can be dropped. The `zeromq` 0.4 crate uses bounded internal /// Replay missed batches from the engine's ROUTER socket.
// channels between the PUB and SUB sockets (via `try_send` with a noop ///
// waker), so messages are silently dropped when the write buffer is full /// Uses a DEALER socket (no send/recv lockstep) to send one request and
// (per ZMQ spec RFC 29). /// receive multiple response frames. Each response is `[empty, seq, payload]`;
// /// an empty payload signals end of replay.
// For P2P recovery, the ready signal delays `recv()` only briefly (the #[expect(clippy::too_many_arguments)]
// duration of the HTTP dump fetch), which is well within the crate's async fn replay_gap(
// internal channel capacity. For longer delays or high-throughput scenarios, replay_socket: &mut zeromq::DealerSocket,
// messages could be lost. start_seq: u64,
// end_seq: u64,
// Easy win: hook up the vLLM replay endpoint — workers already expose worker_id: WorkerId,
// `LocalKvIndexer` with event buffering and range queries (see dp_rank: u32,
// `lib/llm/src/kv_router/worker_query.rs`), just need to query it from block_size: u32,
// the standalone indexer on gap detection. indexer: &Indexer,
// warning_count: &Arc<AtomicU32>,
// Alternative future approach: switch to an explicit `mpsc` channel as the watermark: &Arc<AtomicU64>,
// buffer (unbounded, no drops) instead of relying on ZMQ's internal buffer. ) -> u64 {
tracing::info!(
worker_id,
dp_rank,
start_seq,
end_seq,
"Requesting replay from engine"
);
// DEALER must manually prepend the empty delimiter that REQ adds automatically.
let req_frames = vec![Bytes::new(), Bytes::from(start_seq.to_be_bytes().to_vec())];
let Ok(req_msg) = zeromq::ZmqMessage::try_from(req_frames) else {
tracing::error!(worker_id, dp_rank, "Failed to build replay request");
return 0;
};
if let Err(e) = replay_socket.send(req_msg).await {
tracing::error!(worker_id, dp_rank, error = %e, "Failed to send replay request");
return 0;
}
let mut replayed = 0u64;
loop {
let Ok(msg) = replay_socket.recv().await else {
tracing::error!(worker_id, dp_rank, "Replay recv error");
break;
};
// ROUTER sends [identity, empty, seq, payload]; DEALER strips identity,
// so we receive [empty, seq, payload].
if msg.len() < 3 {
tracing::warn!(
worker_id,
dp_rank,
"Unexpected replay frame count: {}",
msg.len()
);
break;
}
let payload = msg.get(2).unwrap();
if payload.is_empty() {
break;
}
let seq_bytes = msg.get(1).unwrap();
if seq_bytes.len() != 8 {
tracing::warn!(
worker_id,
dp_rank,
"Invalid replay seq length: {}",
seq_bytes.len()
);
break;
}
let seq = u64::from_be_bytes(seq_bytes[..8].try_into().unwrap());
let Ok(batch) = rmps::from_slice::<KvEventBatch>(payload) else {
tracing::warn!(worker_id, dp_rank, seq, "Failed to decode replayed batch");
continue;
};
let effective_dp_rank = batch
.data_parallel_rank
.map_or(dp_rank, |r| r.cast_unsigned());
for raw_event in batch.events {
let kv_event =
convert_event(raw_event, seq, block_size, effective_dp_rank, warning_count);
let router_event = RouterEvent::new(worker_id, kv_event);
indexer.apply_event(router_event).await;
}
watermark.store(seq, Ordering::Release);
replayed += 1;
}
tracing::info!(worker_id, dp_rank, replayed, "Replay complete");
replayed
}
// TODO: assumes one dp_rank per ZMQ socket. Seq counter is per-socket so gap
// detection works regardless, but replay semantics may differ if a single
// socket multiplexes dp_ranks.
/// Connect the ZMQ SUB socket, then spawn a background task that waits for
/// the ready signal before entering the recv loop.
///
/// Returns once the SUB socket is connected (subscription handshake begins
/// immediately in the background). The ready gate and recv loop run in a
/// spawned task so `register()` is never blocked waiting for `signal_ready()`.
#[expect(clippy::too_many_arguments)]
pub async fn run_zmq_listener( pub async fn run_zmq_listener(
worker_id: WorkerId, worker_id: WorkerId,
dp_rank: u32, dp_rank: u32,
...@@ -55,7 +142,9 @@ pub async fn run_zmq_listener( ...@@ -55,7 +142,9 @@ pub async fn run_zmq_listener(
block_size: u32, block_size: u32,
indexer: Indexer, indexer: Indexer,
cancel: CancellationToken, cancel: CancellationToken,
mut ready: watch::Receiver<bool>, ready: watch::Receiver<bool>,
replay_endpoint: Option<String>,
watermark: Arc<AtomicU64>,
) { ) {
tracing::info!(worker_id, dp_rank, zmq_address, "ZMQ listener starting"); tracing::info!(worker_id, dp_rank, zmq_address, "ZMQ listener starting");
...@@ -71,6 +160,35 @@ pub async fn run_zmq_listener( ...@@ -71,6 +160,35 @@ pub async fn run_zmq_listener(
return; return;
} }
// Spawn the ready-wait + recv loop so the caller returns immediately.
// The ZMQ subscription handshake proceeds in the background while P2P
// recovery runs; once signal_ready() fires the recv loop starts draining
// any buffered messages.
tokio::spawn(zmq_wait_ready_then_recv(
worker_id,
dp_rank,
block_size,
indexer,
cancel,
ready,
socket,
replay_endpoint,
watermark,
));
}
#[expect(clippy::too_many_arguments)]
async fn zmq_wait_ready_then_recv(
worker_id: WorkerId,
dp_rank: u32,
block_size: u32,
indexer: Indexer,
cancel: CancellationToken,
mut ready: watch::Receiver<bool>,
socket: SubSocket,
replay_endpoint: Option<String>,
watermark: Arc<AtomicU64>,
) {
// Wait for the ready signal before entering the recv loop. // Wait for the ready signal before entering the recv loop.
// During P2P recovery, this delay lets the recovery code fetch the dump // During P2P recovery, this delay lets the recovery code fetch the dump
// from a peer while ZMQ subscription handshakes complete in the background. // from a peer while ZMQ subscription handshakes complete in the background.
...@@ -90,7 +208,48 @@ pub async fn run_zmq_listener( ...@@ -90,7 +208,48 @@ pub async fn run_zmq_listener(
tracing::info!(worker_id, dp_rank, "ZMQ listener ready, starting recv loop"); tracing::info!(worker_id, dp_rank, "ZMQ listener ready, starting recv loop");
let mut next_event_id = 0u64; // Connect DEALER socket once if replay_endpoint is configured.
// DEALER (not REQ) because we send one request and receive multiple responses.
let mut replay_socket = None;
if let Some(ref ep) = replay_endpoint {
let mut sock = zeromq::DealerSocket::new();
if let Err(e) = sock.connect(ep).await {
tracing::error!(worker_id, dp_rank, error = %e, "Failed to connect replay socket to {ep}");
} else {
tracing::info!(
worker_id,
dp_rank,
replay_endpoint = ep,
"Replay socket connected"
);
replay_socket = Some(sock);
}
}
zmq_recv_loop(
worker_id,
dp_rank,
block_size,
indexer,
cancel,
socket,
replay_socket,
watermark,
)
.await;
}
#[expect(clippy::too_many_arguments)]
async fn zmq_recv_loop(
worker_id: WorkerId,
dp_rank: u32,
block_size: u32,
indexer: Indexer,
cancel: CancellationToken,
mut socket: SubSocket,
mut replay_socket: Option<zeromq::DealerSocket>,
watermark: Arc<AtomicU64>,
) {
let warning_count = Arc::new(AtomicU32::new(0)); let warning_count = Arc::new(AtomicU32::new(0));
let mut consecutive_errors = 0u32; let mut consecutive_errors = 0u32;
#[expect(unused_assignments)] #[expect(unused_assignments)]
...@@ -147,6 +306,41 @@ pub async fn run_zmq_listener( ...@@ -147,6 +306,41 @@ pub async fn run_zmq_listener(
continue; continue;
} }
let seq = u64::from_be_bytes(seq_bytes[..8].try_into().unwrap());
// Gap detection
let prev = watermark.load(Ordering::Acquire);
if prev != WATERMARK_UNSET && seq > prev + 1 {
let gap_start = prev + 1;
tracing::warn!(
worker_id, dp_rank,
expected = gap_start, got = seq,
"Gap detected: expected seq {gap_start}, got {seq}"
);
match replay_socket.as_mut() {
Some(sock) => {
replay_gap(
sock, gap_start, seq, worker_id, dp_rank,
block_size, &indexer, &warning_count, &watermark,
).await;
}
None => tracing::warn!(
worker_id, dp_rank,
gap_size = seq - gap_start,
"No replay endpoint configured, {gap_size} batches lost",
gap_size = seq - gap_start,
),
}
}
// After replay, watermark may have advanced past the current
// batch — skip to avoid double-apply. Exclude the sentinel
// (WATERMARK_UNSET) so the very first message is not skipped.
let current_wm = watermark.load(Ordering::Acquire);
if current_wm != WATERMARK_UNSET && current_wm >= seq {
continue;
}
let payload = msg.get(2).unwrap(); let payload = msg.get(2).unwrap();
let batch_result = rmps::from_slice::<KvEventBatch>(payload); let batch_result = rmps::from_slice::<KvEventBatch>(payload);
let Ok(batch) = batch_result else { let Ok(batch) = batch_result else {
...@@ -155,14 +349,15 @@ pub async fn run_zmq_listener( ...@@ -155,14 +349,15 @@ pub async fn run_zmq_listener(
}; };
let effective_dp_rank = batch.data_parallel_rank.map_or(dp_rank, |r| r.cast_unsigned()); let effective_dp_rank = batch.data_parallel_rank.map_or(dp_rank, |r| r.cast_unsigned());
// Use the engine's ZMQ sequence number as event_id so downstream
// consumers can detect gaps and request replay.
for raw_event in batch.events { for raw_event in batch.events {
let event_id = next_event_id; let kv_event = convert_event(raw_event, seq, block_size, effective_dp_rank, &warning_count);
next_event_id += 1;
let kv_event = convert_event(raw_event, event_id, block_size, effective_dp_rank, &warning_count);
let router_event = RouterEvent::new(worker_id, kv_event); let router_event = RouterEvent::new(worker_id, kv_event);
indexer.apply_event(router_event).await; indexer.apply_event(router_event).await;
messages_processed += 1; messages_processed += 1;
} }
watermark.store(seq, Ordering::Release);
} }
} }
} }
......
...@@ -97,23 +97,26 @@ async fn main() -> anyhow::Result<()> { ...@@ -97,23 +97,26 @@ async fn main() -> anyhow::Result<()> {
let registry = WorkerRegistry::new(cli.threads); let registry = WorkerRegistry::new(cli.threads);
// Register initial workers — connects ZMQ sockets but listeners wait // Register initial workers — connects ZMQ SUB sockets (subscription
// for the ready signal. This ensures ZMQ subscription handshakes begin // handshakes begin immediately) and spawns listener tasks that wait for
// before P2P recovery fetches the dump from a peer. // the ready signal. register() returns as soon as the socket is connected.
if let Some(ref workers_str) = cli.workers { if let Some(ref workers_str) = cli.workers {
let block_size = cli.block_size.ok_or_else(|| { let block_size = cli.block_size.ok_or_else(|| {
anyhow::anyhow!("--block-size is required when --workers is specified") anyhow::anyhow!("--block-size is required when --workers is specified")
})?; })?;
for (instance_id, dp_rank, endpoint) in parse_workers(workers_str) { for (instance_id, dp_rank, endpoint) in parse_workers(workers_str) {
tracing::info!(instance_id, dp_rank, endpoint, "Registering initial worker"); tracing::info!(instance_id, dp_rank, endpoint, "Registering initial worker");
registry.register( registry
instance_id, .register(
endpoint, instance_id,
dp_rank, endpoint,
cli.model_name.clone(), dp_rank,
cli.tenant_id.clone(), cli.model_name.clone(),
block_size, cli.tenant_id.clone(),
)?; block_size,
None,
)
.await?;
} }
} }
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use anyhow::{Result, bail}; use anyhow::{Result, bail};
use dashmap::DashMap; use dashmap::DashMap;
...@@ -27,13 +29,27 @@ pub struct IndexerEntry { ...@@ -27,13 +29,27 @@ pub struct IndexerEntry {
pub struct WorkerEntry { pub struct WorkerEntry {
pub endpoints: HashMap<u32, String>, pub endpoints: HashMap<u32, String>,
pub replay_endpoints: HashMap<u32, String>,
cancels: HashMap<u32, CancellationToken>, cancels: HashMap<u32, CancellationToken>,
} }
/// State needed to restart a paused ZMQ listener.
///
/// One value is stored per `(WorkerId, dp_rank)` when a worker registers and
/// is read back by `resume_listener` to call `run_zmq_listener` again with the
/// same arguments.
struct ListenerState {
/// ZMQ address handed to `run_zmq_listener` as the address to connect to.
endpoint: String,
/// Optional replay address forwarded to the listener; `None` means the
/// listener is started without a replay socket.
replay_endpoint: Option<String>,
/// Block size forwarded to the listener for event conversion.
block_size: u32,
/// Indexer handle that the restarted listener applies events to.
indexer: Indexer,
/// Shared sequence watermark; the same `Arc` is handed to every listener
/// started for this key, so a restarted listener sees the last stored value.
watermark: Arc<AtomicU64>,
}
pub struct WorkerRegistry { pub struct WorkerRegistry {
workers: DashMap<WorkerId, WorkerEntry>, workers: DashMap<WorkerId, WorkerEntry>,
indexers: DashMap<IndexerKey, IndexerEntry>, indexers: DashMap<IndexerKey, IndexerEntry>,
peers: DashMap<String, ()>, peers: DashMap<String, ()>,
/// Persists across unregister/register cycles so gap detection works after re-registration.
watermarks: DashMap<(WorkerId, u32), Arc<AtomicU64>>,
/// Saved listener state for pause/resume. Populated on register, kept on pause.
listener_states: DashMap<(WorkerId, u32), ListenerState>,
num_threads: usize, num_threads: usize,
ready_tx: watch::Sender<bool>, ready_tx: watch::Sender<bool>,
ready_rx: watch::Receiver<bool>, ready_rx: watch::Receiver<bool>,
...@@ -46,6 +62,8 @@ impl WorkerRegistry { ...@@ -46,6 +62,8 @@ impl WorkerRegistry {
workers: DashMap::new(), workers: DashMap::new(),
indexers: DashMap::new(), indexers: DashMap::new(),
peers: DashMap::new(), peers: DashMap::new(),
watermarks: DashMap::new(),
listener_states: DashMap::new(),
num_threads, num_threads,
ready_tx, ready_tx,
ready_rx, ready_rx,
...@@ -72,7 +90,8 @@ impl WorkerRegistry { ...@@ -72,7 +90,8 @@ impl WorkerRegistry {
self.peers.iter().map(|entry| entry.key().clone()).collect() self.peers.iter().map(|entry| entry.key().clone()).collect()
} }
pub fn register( #[expect(clippy::too_many_arguments)]
pub async fn register(
&self, &self,
instance_id: WorkerId, instance_id: WorkerId,
endpoint: String, endpoint: String,
...@@ -80,6 +99,7 @@ impl WorkerRegistry { ...@@ -80,6 +99,7 @@ impl WorkerRegistry {
model_name: String, model_name: String,
tenant_id: String, tenant_id: String,
block_size: u32, block_size: u32,
replay_endpoint: Option<String>,
) -> Result<()> { ) -> Result<()> {
let key = IndexerKey { let key = IndexerKey {
model_name, model_name,
...@@ -115,27 +135,68 @@ impl WorkerRegistry { ...@@ -115,27 +135,68 @@ impl WorkerRegistry {
let bs = indexer_entry.block_size; let bs = indexer_entry.block_size;
drop(indexer_entry); drop(indexer_entry);
let mut entry = self // Check for duplicate and insert replay endpoint while holding the lock briefly.
.workers {
.entry(instance_id) let mut entry = self
.or_insert_with(|| WorkerEntry { .workers
endpoints: HashMap::new(), .entry(instance_id)
cancels: HashMap::new(), .or_insert_with(|| WorkerEntry {
}); endpoints: HashMap::new(),
replay_endpoints: HashMap::new(),
if entry.endpoints.contains_key(&dp_rank) { cancels: HashMap::new(),
bail!("instance {instance_id} dp_rank {dp_rank} already registered"); });
if entry.endpoints.contains_key(&dp_rank) {
bail!("instance {instance_id} dp_rank {dp_rank} already registered");
}
if let Some(rep) = &replay_endpoint {
entry.replay_endpoints.insert(dp_rank, rep.clone());
}
} }
// Reuse watermark if it survived a previous unregister (preserves gap detection).
let watermark = self
.watermarks
.entry((instance_id, dp_rank))
.or_insert_with(|| Arc::new(AtomicU64::new(u64::MAX)))
.clone();
self.listener_states.insert(
(instance_id, dp_rank),
ListenerState {
endpoint: endpoint.clone(),
replay_endpoint: replay_endpoint.clone(),
block_size: bs,
indexer: indexer.clone(),
watermark: watermark.clone(),
},
);
let cancel = CancellationToken::new(); let cancel = CancellationToken::new();
let child_cancel = cancel.child_token(); let child_cancel = cancel.child_token();
let addr = endpoint.clone(); let addr = endpoint.clone();
let ready = self.ready_rx(); let ready = self.ready_rx();
tokio::spawn(async move { // Connect the ZMQ socket and spawn the listener task (non-blocking).
run_zmq_listener(instance_id, dp_rank, addr, bs, indexer, child_cancel, ready).await; run_zmq_listener(
}); instance_id,
dp_rank,
addr,
bs,
indexer,
child_cancel,
ready,
replay_endpoint,
watermark,
)
.await;
// Re-acquire to store the endpoint and cancel token.
let mut entry = self
.workers
.get_mut(&instance_id)
.expect("worker entry disappeared during listener setup");
entry.endpoints.insert(dp_rank, endpoint); entry.endpoints.insert(dp_rank, endpoint);
entry.cancels.insert(dp_rank, cancel); entry.cancels.insert(dp_rank, cancel);
Ok(()) Ok(())
...@@ -251,6 +312,71 @@ impl WorkerRegistry { ...@@ -251,6 +312,71 @@ impl WorkerRegistry {
Ok(()) Ok(())
} }
/// Stop the ZMQ listener for `(instance_id, dp_rank)` without unregistering it.
///
/// Removes the listener's `CancellationToken` from the worker entry and
/// cancels it. This function does not touch the worker's `endpoints` map or
/// the saved `listener_states`, so `resume_listener` can start a new listener
/// for the same key later.
///
/// # Errors
/// Returns an error if `instance_id` has no worker entry, or if no token is
/// stored for `dp_rank` (never started, or already paused).
// Only the `test-endpoints` HTTP handlers call this. An unconditional
// `#[expect(dead_code)]` would become an unfulfilled expectation (a warning)
// whenever that feature is enabled, so the expectation is gated on the
// feature being off.
#[cfg_attr(not(feature = "test-endpoints"), expect(dead_code))]
pub fn pause_listener(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
    let mut entry = self
        .workers
        .get_mut(&instance_id)
        .ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;

    // Taking the token out of the map is what marks the listener as paused:
    // `resume_listener` refuses to start while a token is still present.
    let cancel = entry.cancels.remove(&dp_rank).ok_or_else(|| {
        anyhow::anyhow!("instance {instance_id} dp_rank {dp_rank} not active")
    })?;

    cancel.cancel();
    tracing::info!(instance_id, dp_rank, "Paused ZMQ listener");
    Ok(())
}
#[expect(dead_code)]
pub async fn resume_listener(&self, instance_id: WorkerId, dp_rank: u32) -> Result<()> {
{
let entry = self
.workers
.get(&instance_id)
.ok_or_else(|| anyhow::anyhow!("instance {instance_id} not found"))?;
if entry.cancels.contains_key(&dp_rank) {
bail!("instance {instance_id} dp_rank {dp_rank} already running");
}
}
let state = self
.listener_states
.get(&(instance_id, dp_rank))
.ok_or_else(|| anyhow::anyhow!("no saved state for {instance_id} dp_rank {dp_rank}"))?;
let cancel = CancellationToken::new();
let child_cancel = cancel.child_token();
let ready = self.ready_rx();
let addr = state.endpoint.clone();
let bs = state.block_size;
let indexer = state.indexer.clone();
let replay_ep = state.replay_endpoint.clone();
let watermark = state.watermark.clone();
drop(state);
run_zmq_listener(
instance_id,
dp_rank,
addr,
bs,
indexer,
child_cancel,
ready,
replay_ep,
watermark,
)
.await;
let mut entry = self
.workers
.get_mut(&instance_id)
.expect("worker entry disappeared during listener resume");
entry.cancels.insert(dp_rank, cancel);
Ok(())
}
pub fn list(&self) -> Vec<(WorkerId, HashMap<u32, String>)> { pub fn list(&self) -> Vec<(WorkerId, HashMap<u32, String>)> {
self.workers self.workers
.iter() .iter()
......
...@@ -33,6 +33,8 @@ pub struct RegisterRequest { ...@@ -33,6 +33,8 @@ pub struct RegisterRequest {
pub block_size: u32, pub block_size: u32,
#[serde(default)] #[serde(default)]
pub dp_rank: Option<u32>, pub dp_rank: Option<u32>,
#[serde(default)]
pub replay_endpoint: Option<String>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
...@@ -85,14 +87,19 @@ async fn register( ...@@ -85,14 +87,19 @@ async fn register(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(req): Json<RegisterRequest>, Json(req): Json<RegisterRequest>,
) -> impl IntoResponse { ) -> impl IntoResponse {
match state.registry.register( match state
req.instance_id, .registry
req.endpoint, .register(
req.dp_rank.unwrap_or(0), req.instance_id,
req.model_name, req.endpoint,
req.tenant_id, req.dp_rank.unwrap_or(0),
req.block_size, req.model_name,
) { req.tenant_id,
req.block_size,
req.replay_endpoint,
)
.await
{
Ok(()) => ( Ok(()) => (
StatusCode::CREATED, StatusCode::CREATED,
Json(serde_json::json!({"status": "ok"})), Json(serde_json::json!({"status": "ok"})),
...@@ -248,6 +255,49 @@ async fn query_by_hash( ...@@ -248,6 +255,49 @@ async fn query_by_hash(
} }
} }
/// JSON body accepted by the `/test/pause_listener` and
/// `/test/resume_listener` routes (compiled only with the `test-endpoints`
/// feature).
#[cfg(feature = "test-endpoints")]
#[derive(Deserialize)]
struct ListenerControlRequest {
/// Worker instance whose listener should be paused or resumed.
instance_id: WorkerId,
/// Data-parallel rank of the listener; the handlers fall back to `0` when
/// the field is omitted.
#[serde(default)]
dp_rank: Option<u32>,
}
/// Test-only handler: pause the ZMQ listener named in the request body.
///
/// Responds `200 {"status": "ok"}` on success and `404 {"error": ...}` when
/// the registry rejects the request. A missing `dp_rank` is treated as `0`.
#[cfg(feature = "test-endpoints")]
async fn test_pause_listener(
    State(state): State<Arc<AppState>>,
    Json(req): Json<ListenerControlRequest>,
) -> impl IntoResponse {
    let dp_rank = req.dp_rank.unwrap_or(0);
    let outcome = state.registry.pause_listener(req.instance_id, dp_rank);

    let (status, body) = match outcome {
        Ok(()) => (StatusCode::OK, serde_json::json!({"status": "ok"})),
        Err(e) => (
            StatusCode::NOT_FOUND,
            serde_json::json!({"error": e.to_string()}),
        ),
    };
    (status, Json(body))
}
/// Test-only handler: resume the ZMQ listener named in the request body.
///
/// Responds `200 {"status": "ok"}` on success and `409 {"error": ...}` when
/// the registry rejects the request. A missing `dp_rank` is treated as `0`.
#[cfg(feature = "test-endpoints")]
async fn test_resume_listener(
    State(state): State<Arc<AppState>>,
    Json(req): Json<ListenerControlRequest>,
) -> impl IntoResponse {
    let dp_rank = req.dp_rank.unwrap_or(0);
    let outcome = state
        .registry
        .resume_listener(req.instance_id, dp_rank)
        .await;

    let (status, body) = match outcome {
        Ok(()) => (StatusCode::OK, serde_json::json!({"status": "ok"})),
        Err(e) => (
            StatusCode::CONFLICT,
            serde_json::json!({"error": e.to_string()}),
        ),
    };
    (status, Json(body))
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct PeerRequest { struct PeerRequest {
url: String, url: String,
...@@ -319,7 +369,7 @@ async fn dump_events(State(state): State<Arc<AppState>>) -> impl IntoResponse { ...@@ -319,7 +369,7 @@ async fn dump_events(State(state): State<Arc<AppState>>) -> impl IntoResponse {
} }
pub fn create_router(state: Arc<AppState>) -> Router { pub fn create_router(state: Arc<AppState>) -> Router {
Router::new() let router = Router::new()
.route("/register", post(register)) .route("/register", post(register))
.route("/unregister", post(unregister)) .route("/unregister", post(unregister))
.route("/workers", get(list_workers)) .route("/workers", get(list_workers))
...@@ -328,6 +378,12 @@ pub fn create_router(state: Arc<AppState>) -> Router { ...@@ -328,6 +378,12 @@ pub fn create_router(state: Arc<AppState>) -> Router {
.route("/dump", get(dump_events)) .route("/dump", get(dump_events))
.route("/register_peer", post(register_peer)) .route("/register_peer", post(register_peer))
.route("/deregister_peer", post(deregister_peer)) .route("/deregister_peer", post(deregister_peer))
.route("/peers", get(list_peers)) .route("/peers", get(list_peers));
.with_state(state)
#[cfg(feature = "test-endpoints")]
let router = router
.route("/test/pause_listener", post(test_pause_listener))
.route("/test/resume_listener", post(test_resume_listener));
router.with_state(state)
} }
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
//! The core mocker logic lives in the `dynamo-mocker` crate. //! The core mocker logic lives in the `dynamo-mocker` crate.
//! This module provides the runtime-dependent engine wrapper. //! This module provides the runtime-dependent engine wrapper.
use std::collections::VecDeque;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
...@@ -38,7 +39,7 @@ use tokio::sync::{Notify, OnceCell, mpsc}; ...@@ -38,7 +39,7 @@ use tokio::sync::{Notify, OnceCell, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use uuid::Uuid; use uuid::Uuid;
use zeromq::{Socket, SocketSend}; use zeromq::{Socket, SocketRecv, SocketSend};
pub const MOCKER_COMPONENT: &str = "mocker"; pub const MOCKER_COMPONENT: &str = "mocker";
...@@ -84,8 +85,16 @@ struct ZmqKvEventSink { ...@@ -84,8 +85,16 @@ struct ZmqKvEventSink {
tx: mpsc::UnboundedSender<ZmqKvEventMsg>, tx: mpsc::UnboundedSender<ZmqKvEventMsg>,
} }
/// Maximum number of entries in the replay ring buffer.
const REPLAY_BUFFER_CAPACITY: usize = 10_000;
impl ZmqKvEventSink { impl ZmqKvEventSink {
async fn new(port: u16, dp_rank: u32, block_size: u32) -> Result<Self> { async fn new(
port: u16,
replay_port: Option<u16>,
dp_rank: u32,
block_size: u32,
) -> Result<Self> {
let (tx, mut rx) = mpsc::unbounded_channel::<ZmqKvEventMsg>(); let (tx, mut rx) = mpsc::unbounded_channel::<ZmqKvEventMsg>();
// Bind the PUB socket before returning so that any SUB connect() // Bind the PUB socket before returning so that any SUB connect()
...@@ -98,44 +107,139 @@ impl ZmqKvEventSink { ...@@ -98,44 +107,139 @@ impl ZmqKvEventSink {
.map_err(|e| anyhow::anyhow!("ZMQ PUB bind to {endpoint} failed: {e}"))?; .map_err(|e| anyhow::anyhow!("ZMQ PUB bind to {endpoint} failed: {e}"))?;
tracing::info!("ZmqKvEventSink bound to {endpoint} for dp_rank {dp_rank}"); tracing::info!("ZmqKvEventSink bound to {endpoint} for dp_rank {dp_rank}");
// Optionally bind ROUTER socket for replay
let mut router_socket = if let Some(rp) = replay_port {
let mut sock = zeromq::RouterSocket::new();
let replay_ep = format!("tcp://0.0.0.0:{rp}");
sock.bind(&replay_ep)
.await
.map_err(|e| anyhow::anyhow!("ZMQ ROUTER bind to {replay_ep} failed: {e}"))?;
tracing::info!(
"ZmqKvEventSink replay ROUTER bound to {replay_ep} for dp_rank {dp_rank}"
);
Some(sock)
} else {
None
};
tokio::spawn(async move { tokio::spawn(async move {
let mut seq_num: u64 = 0; let mut seq_num: u64 = 0;
// Store Bytes (ref-counted) to avoid memcpy on both PUB and buffer paths.
let mut ring_buffer: VecDeque<(u64, Bytes)> = VecDeque::new();
while let Some(msg) = rx.recv().await { loop {
let events = tokio::select! {
convert_to_zmq_events(&msg.event, msg.block_token_ids.as_deref(), block_size); biased;
if events.is_empty() {
continue; // Replay requests are rare but latency-sensitive — poll first
} // to prevent starvation under sustained KV event load.
replay_result = async {
match router_socket.as_mut() {
Some(sock) => sock.recv().await,
None => std::future::pending().await,
}
} => {
let Ok(req_msg) = replay_result else {
tracing::warn!("Replay ROUTER recv error");
continue;
};
if req_msg.len() < 3 {
tracing::warn!("Unexpected replay request frame count: {}", req_msg.len());
continue;
}
let identity: Bytes = Bytes::copy_from_slice(req_msg.get(0).unwrap());
let start_seq_bytes = req_msg.get(2).unwrap();
if start_seq_bytes.len() != 8 {
tracing::warn!("Invalid replay start_seq length: {}", start_seq_bytes.len());
continue;
}
let start_seq = u64::from_be_bytes(start_seq_bytes[..8].try_into().unwrap());
tracing::debug!(dp_rank, start_seq, buffer_len = ring_buffer.len(), "Replay request received");
// Compute start index directly — sequences are monotonic.
let start_idx = ring_buffer.front()
.map(|(first_seq, _)| start_seq.saturating_sub(*first_seq) as usize)
.unwrap_or(0)
.min(ring_buffer.len());
let sock = router_socket.as_mut().unwrap();
for (seq, payload) in ring_buffer.iter().skip(start_idx) {
let frames = vec![
identity.clone(),
Bytes::new(),
Bytes::from(seq.to_be_bytes().to_vec()),
payload.clone(), // ref-count bump
];
let reply = zeromq::ZmqMessage::try_from(frames)
.expect("replay frame");
if let Err(e) = sock.send(reply).await {
tracing::warn!("Replay send error: {e}");
break;
}
}
// Sentinel: empty payload signals end of replay
let sentinel_frames = vec![
identity,
Bytes::new(),
Bytes::from((-1i64).to_be_bytes().to_vec()),
Bytes::new(),
];
let sentinel = zeromq::ZmqMessage::try_from(sentinel_frames)
.expect("sentinel frame");
let _ = sock.send(sentinel).await;
}
msg_opt = rx.recv() => {
let Some(msg) = msg_opt else { break };
let events = convert_to_zmq_events(
&msg.event,
msg.block_token_ids.as_deref(),
block_size,
);
if events.is_empty() {
continue;
}
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64();
let batch: (f64, Vec<ZmqRawKvEvent>, Option<i32>) =
(timestamp, events, Some(dp_rank as i32));
let payload: Bytes = match rmp_serde::to_vec(&batch) {
Ok(p) => p.into(),
Err(e) => {
tracing::warn!("Failed to serialize ZMQ KV event: {e}");
continue;
}
};
let frames = vec![
Bytes::from(""),
Bytes::from(seq_num.to_be_bytes().to_vec()),
payload.clone(), // ref-count bump, not memcpy
];
let zmq_msg = zeromq::ZmqMessage::try_from(frames)
.expect("Failed to create ZMQ multipart message");
let timestamp = SystemTime::now() if let Err(e) = pub_socket.send(zmq_msg).await {
.duration_since(UNIX_EPOCH) tracing::warn!("Failed to send ZMQ KV event: {e}");
.unwrap_or_default() }
.as_secs_f64();
if router_socket.is_some() {
let batch: (f64, Vec<ZmqRawKvEvent>, Option<i32>) = if ring_buffer.len() >= REPLAY_BUFFER_CAPACITY {
(timestamp, events, Some(dp_rank as i32)); ring_buffer.pop_front();
let payload = match rmp_serde::to_vec(&batch) { }
Ok(p) => p, ring_buffer.push_back((seq_num, payload));
Err(e) => { }
tracing::warn!("Failed to serialize ZMQ KV event: {e}");
continue; seq_num += 1;
} }
};
let frames = vec![
Bytes::from(""),
Bytes::from(seq_num.to_be_bytes().to_vec()),
Bytes::from(payload),
];
let zmq_msg = zeromq::ZmqMessage::try_from(frames)
.expect("Failed to create ZMQ multipart message");
if let Err(e) = pub_socket.send(zmq_msg).await {
tracing::warn!("Failed to send ZMQ KV event: {e}");
} }
seq_num += 1;
} }
}); });
...@@ -312,7 +416,15 @@ impl MockVllmEngine { ...@@ -312,7 +416,15 @@ impl MockVllmEngine {
) = match component { ) = match component {
Some(comp) if args.zmq_kv_events_port.is_some() => { Some(comp) if args.zmq_kv_events_port.is_some() => {
let zmq_port = args.zmq_kv_events_port.unwrap() + dp_rank as u16; let zmq_port = args.zmq_kv_events_port.unwrap() + dp_rank as u16;
match ZmqKvEventSink::new(zmq_port, dp_rank, args.block_size as u32).await { let replay_port = args.zmq_replay_port.map(|p| p + dp_rank as u16);
match ZmqKvEventSink::new(
zmq_port,
replay_port,
dp_rank,
args.block_size as u32,
)
.await
{
Ok(sink) => { Ok(sink) => {
let source_config = Some(KvEventSourceConfig::Zmq { let source_config = Some(KvEventSourceConfig::Zmq {
endpoint: format!("tcp://127.0.0.1:{zmq_port}"), endpoint: format!("tcp://127.0.0.1:{zmq_port}"),
......
...@@ -212,6 +212,13 @@ pub struct MockEngineArgs { ...@@ -212,6 +212,13 @@ pub struct MockEngineArgs {
#[builder(default = "None")] #[builder(default = "None")]
pub zmq_kv_events_port: Option<u16>, pub zmq_kv_events_port: Option<u16>,
/// ZMQ ROUTER port for replay of buffered KV event batches.
/// When set alongside `zmq_kv_events_port`, the mocker binds a ROUTER socket
/// that streams back buffered batches by sequence number on request.
/// Port is offset by dp_rank (replay_port + dp_rank).
#[builder(default = "None")]
pub zmq_replay_port: Option<u16>,
/// Preemption mode for decode eviction under memory pressure. /// Preemption mode for decode eviction under memory pressure.
/// Lifo (default) evicts the newest request; Fifo evicts the oldest. /// Lifo (default) evicts the newest request; Fifo evicts the oldest.
#[builder(default)] #[builder(default)]
...@@ -271,6 +278,7 @@ impl MockEngineArgs { ...@@ -271,6 +278,7 @@ impl MockEngineArgs {
"kv_transfer_bandwidth", "kv_transfer_bandwidth",
"reasoning", "reasoning",
"zmq_kv_events_port", "zmq_kv_events_port",
"zmq_replay_port",
"preemption_mode", "preemption_mode",
] ]
.iter() .iter()
...@@ -383,6 +391,12 @@ impl MockEngineArgs { ...@@ -383,6 +391,12 @@ impl MockEngineArgs {
builder = builder.zmq_kv_events_port(Some(port as u16)); builder = builder.zmq_kv_events_port(Some(port as u16));
} }
if let Some(value) = extra_args.get("zmq_replay_port")
&& let Some(port) = value.as_u64()
{
builder = builder.zmq_replay_port(Some(port as u16));
}
if let Some(value) = extra_args.get("preemption_mode") if let Some(value) = extra_args.get("preemption_mode")
&& let Some(mode_str) = value.as_str() && let Some(mode_str) = value.as_str()
{ {
......
...@@ -701,6 +701,61 @@ def _test_router_overload_503( ...@@ -701,6 +701,61 @@ def _test_router_overload_503(
logger.info("Successfully verified 503 response when all workers are busy") logger.info("Successfully verified 503 response when all workers are busy")
async def _zmq_replay_cycle(
    phase: int,
    router,
    router_name: str,
    endpoint,
    indexer_url: str,
    engine_workers,
    send_requests_to_router,
):
    """Run one pause / gap / resume cycle against the standalone indexer.

    Steps, in order:
      1. Sleep 1 second.
      2. POST ``/test/pause_listener`` on ``indexer_url`` for every
         (worker id, dp_rank) pair known to ``engine_workers``.
      3. Send 10 requests through ``router`` and require all 10 to succeed.
      4. POST ``/test/resume_listener`` for the same pairs.
      5. Send 5 more requests and require all 5 to succeed; per the log
         message, these are meant to trigger gap detection and replay.
      6. Sleep 2 seconds.

    Args:
        phase: Number used only in the banner log line.
        router: Router object forwarded to ``send_requests_to_router``.
        router_name: Human-readable router label used in logs and request tags.
        endpoint: Endpoint forwarded to ``send_requests_to_router``.
        indexer_url: Base URL of the indexer exposing the ``/test/*`` routes.
        engine_workers: Object providing ``worker_id_to_zmq_ports`` and,
            optionally, ``dp_size`` (a missing or falsy value is treated as 1).
        send_requests_to_router: Awaitable callable invoked as
            ``(router, count, label, endpoint)``; its result is compared
            against the requested count.

    Raises:
        AssertionError: If any pause/resume call does not return HTTP 200, or
            if either batch does not report the expected number of successes.
    """
    await asyncio.sleep(1)

    instance_ids = list(engine_workers.worker_id_to_zmq_ports.keys())
    ranks_per_worker = getattr(engine_workers, "dp_size", None) or 1

    logger.info(f"=== ZMQ REPLAY TEST: Phase {phase} ({router_name}) ===")

    # Every listener is addressed by (instance id, dp_rank), worker-major order.
    targets = [
        (instance_id, rank)
        for instance_id in instance_ids
        for rank in range(ranks_per_worker)
    ]

    async def _switch_listeners(action: str, label: str) -> None:
        # Pause and resume differ only in the route name and the word used in
        # the failure message; each call opens its own short-lived session.
        url = f"{indexer_url}/test/{action}_listener"
        async with aiohttp.ClientSession() as http:
            for instance_id, rank in targets:
                body = {"instance_id": instance_id, "dp_rank": rank}
                async with http.post(url, json=body) as resp:
                    assert (
                        resp.status == 200
                    ), f"{label} {instance_id}:{rank} failed: {await resp.text()}"

    await _switch_listeners("pause", "Pause")

    logger.info("Sending 10 requests while indexer listeners are paused")
    paused_ok = await send_requests_to_router(
        router, 10, f"{router_name} (indexer paused)", endpoint
    )
    assert paused_ok == 10, f"Expected 10 requests while paused, got {paused_ok}"

    await _switch_listeners("resume", "Resume")

    logger.info("Sending 5 requests after resume (triggers gap detection + replay)")
    resumed_ok = await send_requests_to_router(
        router, 5, f"{router_name} (post-resume)", endpoint
    )
    assert resumed_ok == 5, f"Expected 5 requests post-resume, got {resumed_ok}"

    # Fixed settle time before the caller continues.
    await asyncio.sleep(2)
def _test_router_indexers_sync( def _test_router_indexers_sync(
engine_workers, engine_workers,
block_size: int, block_size: int,
...@@ -714,6 +769,7 @@ def _test_router_indexers_sync( ...@@ -714,6 +769,7 @@ def _test_router_indexers_sync(
router_event_threads: int = 4, router_event_threads: int = 4,
standalone_indexer_url: Optional[str] = None, standalone_indexer_url: Optional[str] = None,
standalone_indexer_b_url: Optional[str] = None, standalone_indexer_b_url: Optional[str] = None,
test_zmq_replay: bool = False,
): ):
"""Test that two KV routers have synchronized indexer states after processing requests. """Test that two KV routers have synchronized indexer states after processing requests.
...@@ -854,6 +910,17 @@ def _test_router_indexers_sync( ...@@ -854,6 +910,17 @@ def _test_router_indexers_sync(
await asyncio.sleep(5) await asyncio.sleep(5)
if test_zmq_replay and standalone_indexer_url:
await _zmq_replay_cycle(
1,
kv_router1,
"Router 1",
endpoint1,
standalone_indexer_url,
engine_workers,
send_requests_to_router,
)
# Wait for snapshot to be available before creating second router. # Wait for snapshot to be available before creating second router.
# In JetStream mode, the background task may purge acknowledged messages # In JetStream mode, the background task may purge acknowledged messages
# from the stream before the snapshot upload completes. Poll the object # from the stream before the snapshot upload completes. Poll the object
...@@ -945,6 +1012,17 @@ def _test_router_indexers_sync( ...@@ -945,6 +1012,17 @@ def _test_router_indexers_sync(
successful_recovery == 5 successful_recovery == 5
), f"Expected 5 successful requests post-recovery, got {successful_recovery}" ), f"Expected 5 successful requests post-recovery, got {successful_recovery}"
if test_zmq_replay and standalone_indexer_url:
await _zmq_replay_cycle(
2,
kv_router2,
"Router 2",
endpoint2,
standalone_indexer_url,
engine_workers,
send_requests_to_router,
)
# Wait for internal synchronization and ZMQ event propagation # Wait for internal synchronization and ZMQ event propagation
logger.info("Waiting for final synchronization") logger.info("Waiting for final synchronization")
await asyncio.sleep(2) await asyncio.sleep(2)
......
...@@ -168,6 +168,8 @@ def _build_mocker_command( ...@@ -168,6 +168,8 @@ def _build_mocker_command(
command.extend(["--bootstrap-ports", mocker_args["bootstrap_ports"]]) command.extend(["--bootstrap-ports", mocker_args["bootstrap_ports"]])
if "zmq_kv_events_ports" in mocker_args: if "zmq_kv_events_ports" in mocker_args:
command.extend(["--zmq-kv-events-ports", mocker_args["zmq_kv_events_ports"]]) command.extend(["--zmq-kv-events-ports", mocker_args["zmq_kv_events_ports"]])
if "zmq_replay_ports" in mocker_args:
command.extend(["--zmq-replay-ports", mocker_args["zmq_replay_ports"]])
return command return command
...@@ -190,6 +192,7 @@ class MockerProcess: ...@@ -190,6 +192,7 @@ class MockerProcess:
zmq_kv_events: bool = False, zmq_kv_events: bool = False,
standalone_indexer: bool = False, standalone_indexer: bool = False,
model_name: str = "mocker", model_name: str = "mocker",
zmq_replay: bool = False,
): ):
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}" self.namespace = f"test-namespace-{namespace_suffix}"
...@@ -198,6 +201,7 @@ class MockerProcess: ...@@ -198,6 +201,7 @@ class MockerProcess:
self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate"
self.num_workers = num_mockers self.num_workers = num_mockers
self._zmq_kv_events_ports: list[int] = [] self._zmq_kv_events_ports: list[int] = []
self._zmq_replay_ports: list[int] = []
self._standalone_indexer = standalone_indexer self._standalone_indexer = standalone_indexer
self._standalone_indexer_port: Optional[int] = None self._standalone_indexer_port: Optional[int] = None
self._standalone_indexer_b_port: Optional[int] = None self._standalone_indexer_b_port: Optional[int] = None
...@@ -233,6 +237,22 @@ class MockerProcess: ...@@ -233,6 +237,22 @@ class MockerProcess:
f"(bases: {bases}) for {num_mockers} workers" f"(bases: {bases}) for {num_mockers} workers"
) )
# Allocate ZMQ replay ports (same layout as event ports)
if zmq_replay and zmq_kv_events:
dp_size = mocker_args.get("dp_size", 1)
self._zmq_replay_ports = allocate_ports(
num_mockers * dp_size, BASE_PORT_ZMQ + 1000
)
replay_bases = [
self._zmq_replay_ports[i * dp_size] for i in range(num_mockers)
]
if not standalone_indexer:
mocker_args["zmq_replay_ports"] = ",".join(str(p) for p in replay_bases)
logger.info(
f"Allocated ZMQ replay ports {self._zmq_replay_ports} "
f"(bases: {replay_bases}) for {num_mockers} workers"
)
if standalone_indexer: if standalone_indexer:
# Allocate ports for standalone indexer A and B (P2P recovery peer) # Allocate ports for standalone indexer A and B (P2P recovery peer)
indexer_ports = allocate_ports(2, BASE_PORT) indexer_ports = allocate_ports(2, BASE_PORT)
...@@ -289,7 +309,7 @@ class MockerProcess: ...@@ -289,7 +309,7 @@ class MockerProcess:
"-p", "-p",
"dynamo-kv-router", "dynamo-kv-router",
"--features", "--features",
"indexer-bin", "indexer-bin,test-endpoints",
"--bin", "--bin",
"dynamo-kv-indexer", "dynamo-kv-indexer",
"--", "--",
...@@ -338,6 +358,9 @@ class MockerProcess: ...@@ -338,6 +358,9 @@ class MockerProcess:
mocker_args = self._mocker_args_orig.copy() mocker_args = self._mocker_args_orig.copy()
base_port = self._zmq_kv_events_ports[i * dp_size] base_port = self._zmq_kv_events_ports[i * dp_size]
mocker_args["zmq_kv_events_ports"] = str(base_port) mocker_args["zmq_kv_events_ports"] = str(base_port)
if self._zmq_replay_ports:
replay_base = self._zmq_replay_ports[i * dp_size]
mocker_args["zmq_replay_ports"] = str(replay_base)
command = _build_mocker_command( command = _build_mocker_command(
endpoint=self.endpoint, endpoint=self.endpoint,
...@@ -398,6 +421,9 @@ class MockerProcess: ...@@ -398,6 +421,9 @@ class MockerProcess:
"block_size", BLOCK_SIZE "block_size", BLOCK_SIZE
), ),
} }
if self._zmq_replay_ports:
replay_port = self._zmq_replay_ports[i * dp_size + dp_rank]
payload["replay_endpoint"] = f"tcp://127.0.0.1:{replay_port}"
async with session.post(register_url, json=payload) as resp: async with session.post(register_url, json=payload) as resp:
if resp.status != 201: if resp.status != 201:
body = await resp.text() body = await resp.text()
...@@ -444,7 +470,7 @@ class MockerProcess: ...@@ -444,7 +470,7 @@ class MockerProcess:
"-p", "-p",
"dynamo-kv-router", "dynamo-kv-router",
"--features", "--features",
"indexer-bin", "indexer-bin,test-endpoints",
"--bin", "--bin",
"dynamo-kv-indexer", "dynamo-kv-indexer",
"--", "--",
...@@ -505,6 +531,10 @@ class MockerProcess: ...@@ -505,6 +531,10 @@ class MockerProcess:
deallocate_ports(self._zmq_kv_events_ports) deallocate_ports(self._zmq_kv_events_ports)
logger.info(f"Deallocated ZMQ KV event ports {self._zmq_kv_events_ports}") logger.info(f"Deallocated ZMQ KV event ports {self._zmq_kv_events_ports}")
self._zmq_kv_events_ports = [] self._zmq_kv_events_ports = []
if self._zmq_replay_ports:
deallocate_ports(self._zmq_replay_ports)
logger.info(f"Deallocated ZMQ replay ports {self._zmq_replay_ports}")
self._zmq_replay_ports = []
class DisaggMockerProcess: class DisaggMockerProcess:
...@@ -862,6 +892,7 @@ def test_indexers_sync( ...@@ -862,6 +892,7 @@ def test_indexers_sync(
store_backend=store_backend, store_backend=store_backend,
request_plane=request_plane, request_plane=request_plane,
zmq_kv_events=True, zmq_kv_events=True,
zmq_replay=True,
standalone_indexer=True, standalone_indexer=True,
model_name=MODEL_NAME, model_name=MODEL_NAME,
) as mockers: ) as mockers:
...@@ -884,6 +915,7 @@ def test_indexers_sync( ...@@ -884,6 +915,7 @@ def test_indexers_sync(
durable_kv_events=durable_kv_events, durable_kv_events=durable_kv_events,
standalone_indexer_url=mockers.standalone_indexer_url, standalone_indexer_url=mockers.standalone_indexer_url,
standalone_indexer_b_url=mockers.standalone_indexer_b_url, standalone_indexer_b_url=mockers.standalone_indexer_b_url,
test_zmq_replay=True,
) )
logger.info("Indexers sync test completed successfully") logger.info("Indexers sync test completed successfully")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment