"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "4c1bc4ee61b550c10436b82a7c5309efeb30fbd9"
Unverified Commit 94aa2a7b authored by Richard Huo's avatar Richard Huo Committed by GitHub
Browse files

refactor: kvbm modularity DIS-657 Eliminate ETCD from the leader-worker initialization (#3202)


Signed-off-by: default avatarrichardhuo-nv <rihuo@nvidia.com>
parent 5d90e530
...@@ -2136,6 +2136,7 @@ dependencies = [ ...@@ -2136,6 +2136,7 @@ dependencies = [
"async_zmq", "async_zmq",
"axum 0.8.4", "axum 0.8.4",
"axum-server", "axum-server",
"bincode",
"bitflags 2.9.4", "bitflags 2.9.4",
"blake3", "blake3",
"bs62", "bs62",
......
...@@ -54,10 +54,6 @@ spec: ...@@ -54,10 +54,6 @@ spec:
envs: envs:
- name: DYN_KVBM_CPU_CACHE_GB - name: DYN_KVBM_CPU_CACHE_GB
value: "100" value: "100"
- name: DYN_KVBM_BARRIER_ID_PREFIX
valueFrom:
fieldRef:
fieldPath: metadata.name
extraPodSpec: extraPodSpec:
mainContainer: mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
......
...@@ -58,10 +58,6 @@ spec: ...@@ -58,10 +58,6 @@ spec:
envs: envs:
- name: DYN_KVBM_CPU_CACHE_GB - name: DYN_KVBM_CPU_CACHE_GB
value: "100" value: "100"
- name: DYN_KVBM_BARRIER_ID_PREFIX
valueFrom:
fieldRef:
fieldPath: metadata.name
extraPodSpec: extraPodSpec:
mainContainer: mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:my-tag
......
...@@ -15,7 +15,6 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --connecto ...@@ -15,7 +15,6 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --connecto
# run prefill workers on GPU 2 and 3 with KVBM enabled using 20GB of CPU cache # run prefill workers on GPU 2 and 3 with KVBM enabled using 20GB of CPU cache
# NOTE: use different barrier id prefixes for each prefill worker to avoid conflicts # NOTE: use different barrier id prefixes for each prefill worker to avoid conflicts
# NOTE: remove --enforce-eager for production use # NOTE: remove --enforce-eager for production use
DYN_KVBM_BARRIER_ID_PREFIX=kvbm_0 \
DYN_KVBM_CPU_CACHE_GB=20 \ DYN_KVBM_CPU_CACHE_GB=20 \
CUDA_VISIBLE_DEVICES=2 \ CUDA_VISIBLE_DEVICES=2 \
python3 -m dynamo.vllm \ python3 -m dynamo.vllm \
...@@ -24,7 +23,8 @@ CUDA_VISIBLE_DEVICES=2 \ ...@@ -24,7 +23,8 @@ CUDA_VISIBLE_DEVICES=2 \
--connector kvbm nixl \ --connector kvbm nixl \
--enforce-eager & --enforce-eager &
DYN_KVBM_BARRIER_ID_PREFIX=kvbm_1 \ DYN_KVBM_LEADER_ZMQ_PUB_PORT=56003 \
DYN_KVBM_LEADER_ZMQ_ACK_PORT=56004 \
DYN_KVBM_CPU_CACHE_GB=20 \ DYN_KVBM_CPU_CACHE_GB=20 \
CUDA_VISIBLE_DEVICES=3 \ CUDA_VISIBLE_DEVICES=3 \
python3 -m dynamo.vllm \ python3 -m dynamo.vllm \
......
...@@ -1466,6 +1466,7 @@ dependencies = [ ...@@ -1466,6 +1466,7 @@ dependencies = [
"async_zmq", "async_zmq",
"axum", "axum",
"axum-server", "axum-server",
"bincode",
"bitflags 2.9.3", "bitflags 2.9.3",
"blake3", "blake3",
"bs62", "bs62",
......
...@@ -8,5 +8,5 @@ mod utils; ...@@ -8,5 +8,5 @@ mod utils;
mod worker; mod worker;
pub use leader::KvbmLeader; pub use leader::KvbmLeader;
pub use utils::get_barrier_id_prefix; pub use utils::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
pub use worker::{KvbmWorker, PyLayoutType, VllmTensor}; pub use worker::{KvbmWorker, PyLayoutType, VllmTensor};
...@@ -2,12 +2,11 @@ ...@@ -2,12 +2,11 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::*; use super::*;
use utils::get_barrier_id_prefix;
use derive_getters::Dissolve; use derive_getters::Dissolve;
use llm_rs::block_manager::distributed::{ use llm_rs::block_manager::distributed::{
KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig, KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig,
}; };
use utils::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
const CPU_CACHE: &str = "DYN_KVBM_CPU_CACHE_GB"; const CPU_CACHE: &str = "DYN_KVBM_CPU_CACHE_GB";
const CPU_CACHE_OVERRIDE: &str = "DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"; const CPU_CACHE_OVERRIDE: &str = "DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS";
...@@ -72,17 +71,16 @@ impl KvbmLeader { ...@@ -72,17 +71,16 @@ impl KvbmLeader {
#[new] #[new]
#[pyo3(signature = (world_size, drt))] #[pyo3(signature = (world_size, drt))]
fn new(world_size: usize, drt: DistributedRuntime) -> PyResult<Self> { fn new(world_size: usize, drt: DistributedRuntime) -> PyResult<Self> {
let barrier_id_prefix = get_barrier_id_prefix();
let leader_init_timeout_sec: u64 = let leader_init_timeout_sec: u64 =
get_leader_init_timeout_secs(LEADER_WORKER_INIT_TIMEOUT_SECS); get_leader_init_timeout_secs(LEADER_WORKER_INIT_TIMEOUT_SECS);
let config = KvbmLeaderConfig::builder() let config = KvbmLeaderConfig::builder()
.barrier_id_prefix(barrier_id_prefix)
.world_size(world_size) .world_size(world_size)
.leader_init_timeout_secs(leader_init_timeout_sec) .leader_init_timeout_secs(leader_init_timeout_sec)
.drt(drt.inner().clone())
.host_blocks_config(get_blocks_config(CPU_CACHE, CPU_CACHE_OVERRIDE)) .host_blocks_config(get_blocks_config(CPU_CACHE, CPU_CACHE_OVERRIDE))
.disk_blocks_config(get_blocks_config(DISK_CACHE, DISK_CACHE_OVERRIDE)) .disk_blocks_config(get_blocks_config(DISK_CACHE, DISK_CACHE_OVERRIDE))
.leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.build() .build()
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::env;
pub fn get_barrier_id_prefix() -> String { const DEFAULT_LEADER_ZMQ_HOST: &str = "127.0.0.1";
std::env::var("DYN_KVBM_BARRIER_ID_PREFIX") const DEFAULT_LEADER_ZMQ_PUB_PORT: u16 = 56001;
const DEFAULT_LEADER_ZMQ_ACK_PORT: u16 = 56002;
fn read_env_trimmed(key: &str) -> Option<String> {
env::var(key)
.ok() .ok()
.filter(|s| !s.trim().is_empty()) .map(|s| s.trim().to_string())
.unwrap_or_else(|| "kvbm".to_string()) .filter(|s| !s.is_empty())
}
fn parse_port_u16(s: &str) -> Option<u16> {
match s.parse::<u32>() {
Ok(v) if (1..=65535).contains(&v) => Some(v as u16),
_ => None,
}
}
fn validated_port_from_env(key: &str, default_port: u16) -> u16 {
if let Some(val) = read_env_trimmed(key) {
if let Some(p) = parse_port_u16(&val) {
if p < 1024 {
tracing::warn!("{key} is a privileged port ({p}); binding may require extra caps");
}
return p;
} else {
tracing::warn!("{key} invalid value '{val}', falling back to default {default_port}");
}
}
default_port
}
fn get_leader_zmq_host() -> String {
read_env_trimmed("DYN_KVBM_LEADER_ZMQ_HOST")
.unwrap_or_else(|| DEFAULT_LEADER_ZMQ_HOST.to_string())
}
fn get_leader_zmq_pub_port() -> String {
validated_port_from_env("DYN_KVBM_LEADER_ZMQ_PUB_PORT", DEFAULT_LEADER_ZMQ_PUB_PORT).to_string()
}
fn get_leader_zmq_ack_port() -> String {
validated_port_from_env("DYN_KVBM_LEADER_ZMQ_ACK_PORT", DEFAULT_LEADER_ZMQ_ACK_PORT).to_string()
}
pub fn get_leader_zmq_pub_url() -> String {
format!(
"tcp://{}:{}",
get_leader_zmq_host(),
get_leader_zmq_pub_port()
)
}
pub fn get_leader_zmq_ack_url() -> String {
format!(
"tcp://{}:{}",
get_leader_zmq_host(),
get_leader_zmq_ack_port()
)
} }
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use utils::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
use super::*; use super::*;
use std::sync::Arc; use std::sync::Arc;
use utils::get_barrier_id_prefix;
use llm_rs::block_manager::distributed::{ use llm_rs::block_manager::distributed::{
BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl, BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl,
...@@ -171,8 +172,6 @@ impl KvbmWorker { ...@@ -171,8 +172,6 @@ impl KvbmWorker {
vllm_tensors.push(Arc::new(vllm_tensor)); vllm_tensors.push(Arc::new(vllm_tensor));
} }
let barrier_id_prefix = get_barrier_id_prefix();
let config = KvbmWorkerConfig::builder() let config = KvbmWorkerConfig::builder()
.drt(drt) .drt(drt)
.num_device_blocks(num_device_blocks) .num_device_blocks(num_device_blocks)
...@@ -180,7 +179,6 @@ impl KvbmWorker { ...@@ -180,7 +179,6 @@ impl KvbmWorker {
.tensors(vllm_tensors) .tensors(vllm_tensors)
.device_id(device_id) .device_id(device_id)
.dtype_width_bytes(dtype_width_bytes) .dtype_width_bytes(dtype_width_bytes)
.barrier_id_prefix(barrier_id_prefix)
.device_layout_type( .device_layout_type(
device_layout_type device_layout_type
.map(|py_layout| py_layout.into()) .map(|py_layout| py_layout.into())
...@@ -196,6 +194,8 @@ impl KvbmWorker { ...@@ -196,6 +194,8 @@ impl KvbmWorker {
.map(|py_layout| py_layout.into()) .map(|py_layout| py_layout.into())
.unwrap_or(LayoutType::FullyContiguous), .unwrap_or(LayoutType::FullyContiguous),
) )
.leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.build() .build()
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -150,9 +150,6 @@ impl KvConnectorLeader { ...@@ -150,9 +150,6 @@ impl KvConnectorLeader {
let _ = slot_manager_cell.set(sm); let _ = slot_manager_cell.set(sm);
// another barrier sync to make sure worker init won't return before leader is ready
let _ = leader.run_leader_readiness_barrier_blocking(drt);
if leader_ready_tx.send("finished".to_string()).is_err() { if leader_ready_tx.send("finished".to_string()).is_err() {
tracing::error!("main routine receiver dropped before result was sent"); tracing::error!("main routine receiver dropped before result was sent");
} }
......
...@@ -166,9 +166,6 @@ impl KvConnectorLeaderRecorder { ...@@ -166,9 +166,6 @@ impl KvConnectorLeaderRecorder {
let _ = slot_manager_cell.set(sm); let _ = slot_manager_cell.set(sm);
// another barrier sync to make sure worker init won't return before leader is ready
leader.spawn_leader_readiness_barrier(drt);
if leader_ready_tx.send("finished".to_string()).is_err() { if leader_ready_tx.send("finished".to_string()).is_err() {
tracing::error!("main routine receiver dropped before result was sent"); tracing::error!("main routine receiver dropped before result was sent");
} }
......
...@@ -126,9 +126,6 @@ impl KvConnectorLeader { ...@@ -126,9 +126,6 @@ impl KvConnectorLeader {
let _ = slot_manager_cell.set(sm); let _ = slot_manager_cell.set(sm);
// another barrier sync to make sure worker init won't return before leader is ready
leader.spawn_leader_readiness_barrier(drt);
tracing::info!("KvConnectorLeader init complete."); tracing::info!("KvConnectorLeader init complete.");
}); });
} }
......
...@@ -10,7 +10,7 @@ use std::collections::HashSet; ...@@ -10,7 +10,7 @@ use std::collections::HashSet;
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
use super::*; use super::*;
use crate::llm::block_manager::distributed::get_barrier_id_prefix; use crate::llm::block_manager::distributed::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
use crate::llm::block_manager::vllm::connector::worker::event_sync_blocking; use crate::llm::block_manager::vllm::connector::worker::event_sync_blocking;
use crate::{ use crate::{
DistributedRuntime as PyDistributedRuntime, llm::block_manager::distributed::VllmTensor, DistributedRuntime as PyDistributedRuntime, llm::block_manager::distributed::VllmTensor,
...@@ -138,7 +138,8 @@ impl Worker for KvConnectorWorker { ...@@ -138,7 +138,8 @@ impl Worker for KvConnectorWorker {
.device_layout_type(LayoutType::FullyContiguous) .device_layout_type(LayoutType::FullyContiguous)
.host_layout_type(LayoutType::FullyContiguous) .host_layout_type(LayoutType::FullyContiguous)
.disk_layout_type(LayoutType::FullyContiguous) .disk_layout_type(LayoutType::FullyContiguous)
.barrier_id_prefix(get_barrier_id_prefix()) .leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.scheduler_client(Some(self.transfer_client.clone())) .scheduler_client(Some(self.transfer_client.clone()))
.build()?; .build()?;
......
...@@ -10,7 +10,7 @@ use std::collections::HashSet; ...@@ -10,7 +10,7 @@ use std::collections::HashSet;
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
use super::*; use super::*;
use crate::llm::block_manager::distributed::get_barrier_id_prefix; use crate::llm::block_manager::distributed::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
use crate::{ use crate::{
DistributedRuntime as PyDistributedRuntime, llm::block_manager::distributed::VllmTensor, DistributedRuntime as PyDistributedRuntime, llm::block_manager::distributed::VllmTensor,
to_pyerr, to_pyerr,
...@@ -200,7 +200,8 @@ impl Worker for KvConnectorWorker { ...@@ -200,7 +200,8 @@ impl Worker for KvConnectorWorker {
.tensors(vllm_tensors) .tensors(vllm_tensors)
.device_id(device_id) .device_id(device_id)
.dtype_width_bytes(dtype_width_bytes) .dtype_width_bytes(dtype_width_bytes)
.barrier_id_prefix(get_barrier_id_prefix()) .leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.scheduler_client(Some(self.transfer_client.clone())) .scheduler_client(Some(self.transfer_client.clone()))
.device_layout_type(detected_device_layout_type) .device_layout_type(detected_device_layout_type)
.host_layout_type(host_layout_type.unwrap_or(LayoutType::FullyContiguous)) .host_layout_type(host_layout_type.unwrap_or(LayoutType::FullyContiguous))
......
...@@ -85,6 +85,7 @@ offset-allocator = "0.2" ...@@ -85,6 +85,7 @@ offset-allocator = "0.2"
regex = "1" regex = "1"
rayon = "1" rayon = "1"
dashmap = { version = "5.5.3" } dashmap = { version = "5.5.3" }
bincode = "1"
# input/text # input/text
dialoguer = { version = "0.11", default-features = false, features = [ dialoguer = { version = "0.11", default-features = false, features = [
......
...@@ -123,14 +123,12 @@ mod tests { ...@@ -123,14 +123,12 @@ mod tests {
async fn build_leader_and_workers(num_workers: usize) -> Result<(KvbmLeader, Vec<KvbmWorker>)> { async fn build_leader_and_workers(num_workers: usize) -> Result<(KvbmLeader, Vec<KvbmWorker>)> {
let mut workers = Vec::new(); let mut workers = Vec::new();
let barrier_id = get_unique_barrier_id();
for i in 0..num_workers { for i in 0..num_workers {
let tensors: Vec<Arc<dyn TorchTensor>> = let tensors: Vec<Arc<dyn TorchTensor>> =
vec![Arc::new(MockTensor::new(vec![2, NUM_BLOCKS, 4096]))]; vec![Arc::new(MockTensor::new(vec![2, NUM_BLOCKS, 4096]))];
let config = KvbmWorkerConfig::builder() let config = KvbmWorkerConfig::builder()
.barrier_id_prefix(barrier_id.clone())
.num_device_blocks(NUM_BLOCKS) .num_device_blocks(NUM_BLOCKS)
.tensors(tensors) .tensors(tensors)
.device_id(i) .device_id(i)
...@@ -151,7 +149,6 @@ mod tests { ...@@ -151,7 +149,6 @@ mod tests {
}; };
let leader_config = KvbmLeaderConfig::builder() let leader_config = KvbmLeaderConfig::builder()
.barrier_id_prefix(barrier_id)
.world_size(num_workers) .world_size(num_workers)
.host_blocks_config(host_blocks) .host_blocks_config(host_blocks)
.disk_blocks_config(disk_blocks) .disk_blocks_config(disk_blocks)
......
...@@ -3,15 +3,10 @@ ...@@ -3,15 +3,10 @@
use super::*; use super::*;
use dynamo_runtime::DistributedRuntime;
use utils::*; use utils::*;
use zmq::*; use zmq::*;
use dynamo_runtime::utils::leader_worker_barrier::LeaderBarrier;
use anyhow::{Context, anyhow};
use derive_builder::Builder; use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration; use std::time::Duration;
...@@ -20,15 +15,6 @@ use tokio::sync::OnceCell; ...@@ -20,15 +15,6 @@ use tokio::sync::OnceCell;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio::time::sleep; use tokio::time::sleep;
/// Data that is sent to workers over ETCD to establish a ZMQ connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvbmLeaderData {
pub pub_url: String,
pub ack_url: String,
pub num_host_blocks: usize,
pub num_disk_blocks: usize,
}
#[derive(Builder, Clone, Debug, Default)] #[derive(Builder, Clone, Debug, Default)]
pub struct KvbmLeaderNumBlocksConfig { pub struct KvbmLeaderNumBlocksConfig {
#[builder(default = "0.0")] #[builder(default = "0.0")]
...@@ -51,10 +37,6 @@ fn compute_num_blocks( ...@@ -51,10 +37,6 @@ fn compute_num_blocks(
#[derive(Builder, Clone, Debug)] #[derive(Builder, Clone, Debug)]
pub struct KvbmLeaderConfig { pub struct KvbmLeaderConfig {
/// The barrier id to use for syncing with workers.
#[builder(default = "String::from(\"kvbm\")")]
barrier_id_prefix: String,
/// The world size. /// The world size.
#[builder(default = "1")] #[builder(default = "1")]
world_size: usize, world_size: usize,
...@@ -63,14 +45,17 @@ pub struct KvbmLeaderConfig { ...@@ -63,14 +45,17 @@ pub struct KvbmLeaderConfig {
#[builder(default = "120")] #[builder(default = "120")]
leader_init_timeout_secs: u64, leader_init_timeout_secs: u64,
#[builder(setter(strip_option))]
drt: Option<DistributedRuntime>,
#[builder(default = "KvbmLeaderNumBlocksConfig::default()")] #[builder(default = "KvbmLeaderNumBlocksConfig::default()")]
host_blocks_config: KvbmLeaderNumBlocksConfig, host_blocks_config: KvbmLeaderNumBlocksConfig,
#[builder(default = "KvbmLeaderNumBlocksConfig::default()")] #[builder(default = "KvbmLeaderNumBlocksConfig::default()")]
disk_blocks_config: KvbmLeaderNumBlocksConfig, disk_blocks_config: KvbmLeaderNumBlocksConfig,
#[builder(default = "String::from(\"tcp://127.0.0.1:56001\")")]
leader_pub_url: String,
#[builder(default = "String::from(\"tcp://127.0.0.1:56002\")")]
leader_ack_url: String,
} }
impl KvbmLeaderConfig { impl KvbmLeaderConfig {
...@@ -79,6 +64,11 @@ impl KvbmLeaderConfig { ...@@ -79,6 +64,11 @@ impl KvbmLeaderConfig {
} }
pub fn sanity_check(&self) -> anyhow::Result<()> { pub fn sanity_check(&self) -> anyhow::Result<()> {
if self.leader_pub_url == self.leader_ack_url {
anyhow::bail!(
"leader_pub_url and leader_ack_url must differ (same endpoint would fail to bind)."
);
}
let cpu = &self.host_blocks_config; let cpu = &self.host_blocks_config;
let disk = &self.disk_blocks_config; let disk = &self.disk_blocks_config;
let cpu_configured = cpu.num_blocks_overriden > 0 || cpu.cache_size_in_gb > 0.0; let cpu_configured = cpu.num_blocks_overriden > 0 || cpu.cache_size_in_gb > 0.0;
...@@ -121,166 +111,24 @@ pub struct KvbmLeader { ...@@ -121,166 +111,24 @@ pub struct KvbmLeader {
state: Arc<KvbmLeaderState>, state: Arc<KvbmLeaderState>,
zmq_leader: Arc<OnceCell<ZmqActiveMessageLeader>>, zmq_leader: Arc<OnceCell<ZmqActiveMessageLeader>>,
config: KvbmLeaderConfig, config: KvbmLeaderConfig,
//readiness flags
workers_sync_ready: Arc<AtomicBool>,
workers_sync_ready_notify: Arc<Notify>,
workers_sync_done: Arc<AtomicBool>,
} }
impl KvbmLeader { impl KvbmLeader {
pub async fn new(mut config: KvbmLeaderConfig) -> anyhow::Result<Self> { pub async fn new(config: KvbmLeaderConfig) -> anyhow::Result<Self> {
let drt = match config.drt.take() { let leader_sockets = new_leader_sockets(&config.leader_pub_url, &config.leader_ack_url)?;
Some(dtr) => dtr,
None => {
anyhow::bail!("No distributed runtime provided");
}
};
let leader_sockets = new_leader_sockets("tcp://127.0.0.1")?;
let leader = Self { let leader = Self {
state: Arc::new(KvbmLeaderState::default()), state: Arc::new(KvbmLeaderState::default()),
zmq_leader: Arc::new(tokio::sync::OnceCell::new()), zmq_leader: Arc::new(tokio::sync::OnceCell::new()),
config, config,
workers_sync_ready: Arc::new(AtomicBool::new(false)),
workers_sync_ready_notify: Arc::new(Notify::new()),
workers_sync_done: Arc::new(AtomicBool::new(false)),
}; };
let cancel_token = tokio_util::sync::CancellationToken::new(); let cancel_token = tokio_util::sync::CancellationToken::new();
// The leader_sockets struct cannot be cloned,
// so we use a tuple to "struct" the two urls
let leader_urls = (
leader_sockets.pub_url.clone(),
leader_sockets.ack_url.clone(),
);
leader.spawn_barrier_task(drt, leader_urls);
leader.spawn_zmq_task(leader_sockets, cancel_token); leader.spawn_zmq_task(leader_sockets, cancel_token);
Ok(leader) Ok(leader)
} }
fn spawn_barrier_task(&self, drt: DistributedRuntime, leader_urls: (String, String)) {
let state = self.state.clone();
let leader_config = self.config.clone();
let ready = Arc::clone(&self.workers_sync_ready);
let notify = Arc::clone(&self.workers_sync_ready_notify);
let done = Arc::clone(&self.workers_sync_done);
tokio::spawn(async move {
match KvbmLeader::run_barrier_sync(drt, leader_urls, leader_config).await {
Ok((num_device_blocks, num_host_blocks, num_disk_blocks)) => {
// write back results
state
.num_device_blocks
.store(num_device_blocks, Ordering::Release);
state
.num_host_blocks
.store(num_host_blocks, Ordering::Release);
state
.num_disk_blocks
.store(num_disk_blocks, Ordering::Release);
ready.store(true, Ordering::Release);
done.store(true, Ordering::Release);
notify.notify_waiters();
}
Err(e) => {
tracing::error!("Barrier sync failed: {e:?}");
done.store(true, Ordering::Release);
notify.notify_waiters();
}
}
});
}
async fn run_barrier_sync(
drt: DistributedRuntime,
leader_urls: (String, String),
leader_config: KvbmLeaderConfig,
) -> anyhow::Result<(usize, usize, usize)> {
let barrier_id_worker_to_leader =
format!("{}{}", leader_config.barrier_id_prefix, "-worker-to-leader");
tracing::info!(
"Syncing leader barrier with {} workers on barrier id {}",
leader_config.world_size,
barrier_id_worker_to_leader
);
// Build our leader barrier and publish the data.
// TODO: Use a separate timeout parameter from the ZMQ connection timeout
let worker_to_leader_barrier: LeaderBarrier<(), worker::KvbmWorkerData> =
LeaderBarrier::new(
barrier_id_worker_to_leader.clone(),
leader_config.world_size,
Some(Duration::from_secs(leader_config.leader_init_timeout_secs)),
);
let worker_data = worker_to_leader_barrier
.sync(&drt, &())
.await
.map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?;
let num_device_blocks = worker_data
.values()
.map(|data| data.num_device_blocks)
.min()
.unwrap();
// TODO: this works for TP, need to redefine bytes_per_block when we enable the DP/PP
let bytes_per_block: usize = worker_data.values().map(|d| d.bytes_per_block).sum();
assert!(
bytes_per_block > 0,
"bytes_per_block must be greater than 0"
);
tracing::info!(
"Worker to leader barrier synced with {} workers",
leader_config.world_size
);
tracing::debug!("Worker data: {:?}", worker_data);
let num_host_blocks =
compute_num_blocks(&leader_config.host_blocks_config, bytes_per_block);
let num_disk_blocks =
compute_num_blocks(&leader_config.disk_blocks_config, bytes_per_block);
// Start the second sync to transfer num_host_blocks and num_disk_blocks to worker
let barrier_id_leader_to_worker =
format!("{}{}", leader_config.barrier_id_prefix, "-leader-to-worker");
tracing::info!(
"Syncing leader barrier with {} workers on barrier id {}",
leader_config.world_size,
barrier_id_leader_to_worker
);
let (leader_pub_url, leader_ack_url) = leader_urls;
let zmq_data_leader_to_worker = Arc::new(KvbmLeaderData {
pub_url: leader_pub_url,
ack_url: leader_ack_url,
num_host_blocks,
num_disk_blocks,
});
let leader_to_worker_barrier: LeaderBarrier<KvbmLeaderData, ()> = LeaderBarrier::new(
barrier_id_leader_to_worker.clone(),
leader_config.world_size,
Some(Duration::from_secs(leader_config.leader_init_timeout_secs)),
);
let _worker_data = leader_to_worker_barrier
.sync(&drt, zmq_data_leader_to_worker.as_ref())
.await
.map_err(|e| anyhow::anyhow!("Failed to sync leader to worker barrier: {:?}", e))?;
tracing::info!(
"Worker to leader barrier synced with {} workers",
leader_config.world_size
);
Ok((num_device_blocks, num_host_blocks, num_disk_blocks))
}
fn spawn_zmq_task( fn spawn_zmq_task(
&self, &self,
leader_sockets: LeaderSockets, leader_sockets: LeaderSockets,
...@@ -290,133 +138,59 @@ impl KvbmLeader { ...@@ -290,133 +138,59 @@ impl KvbmLeader {
let state = self.state.clone(); let state = self.state.clone();
let world_size = self.config.world_size; let world_size = self.config.world_size;
let timeout = self.config.leader_init_timeout_secs; let timeout = self.config.leader_init_timeout_secs;
let host_cfg = self.config.host_blocks_config.clone();
let disk_cfg = self.config.disk_blocks_config.clone();
// capture num_device_blocks so we can set it inside the closure
let num_device_blocks_cell = state.num_device_blocks.clone();
let num_host_blocks_cell = state.num_host_blocks.clone();
let num_disk_blocks_cell = state.num_disk_blocks.clone();
tokio::spawn(async move { tokio::spawn(async move {
let res = ZmqActiveMessageLeader::new( let res = ZmqActiveMessageLeader::new_with_handshake(
leader_sockets, leader_sockets,
world_size, world_size,
std::time::Duration::from_secs(timeout), std::time::Duration::from_secs(timeout),
cancel, cancel.clone(),
move |workers: &[WorkerMetadata]| -> LeaderMetadata {
// Record device blocks: min across workers
if let Some(min_dev) = workers.iter().map(|w| w.num_device_blocks).min() {
num_device_blocks_cell.store(min_dev, Ordering::Release);
}
// For TP, sum bytes_per_block; adjust policy for DP/PP if needed.
let bytes_per_block: usize = workers.iter().map(|w| w.bytes_per_block).sum();
let num_host_blocks = compute_num_blocks(&host_cfg, bytes_per_block);
let num_disk_blocks = compute_num_blocks(&disk_cfg, bytes_per_block);
// store into leader state
num_host_blocks_cell.store(num_host_blocks, Ordering::Release);
num_disk_blocks_cell.store(num_disk_blocks, Ordering::Release);
LeaderMetadata {
num_host_blocks,
num_disk_blocks,
}
},
) )
.await; .await;
match res { match res {
Ok(zmq) => { Ok(zmq) => {
let _ = cell.set(zmq); let _ = cell.set(zmq);
// mark ready
state state
.workers_allocation_ready .workers_allocation_ready
.store(true, Ordering::Release); .store(true, Ordering::Release);
state.workers_ready_notify.notify_waiters(); state.workers_ready_notify.notify_waiters();
tracing::info!("ZMQ handshake complete; workers allocation ready");
} }
Err(e) => { Err(e) => {
tracing::error!("ZMQ init failed: {e:?}"); tracing::error!("ZMQ init/handshake failed: {e:?}");
}
}
});
}
// This is supposed to be used in non-blocking leader initialization
pub fn spawn_leader_readiness_barrier(&self, drt: DistributedRuntime) {
let timeout_secs = self.config.leader_init_timeout_secs;
let state = self.state.clone();
let leader_config = self.config.clone();
let handle = drt.runtime().primary();
handle.spawn(async move {
if !state.workers_allocation_ready.load(Ordering::Acquire) {
// Wait until ZMQ marks ready or we time out.
let waited = tokio::time::timeout(
Duration::from_secs(timeout_secs),
state.workers_ready_notify.notified(),
)
.await;
if waited.is_err() {
tracing::error!(
"leader readiness barrier wait timed out after {timeout_secs} seconds"
);
return;
}
// Double-check the flag (Acquire) after wakeup.
if !state.workers_allocation_ready.load(Ordering::Acquire) {
tracing::error!("leader readiness notify fired but flag not set; aborting");
return;
}
}
match KvbmLeader::run_leader_readiness(drt, leader_config).await {
Ok(()) => {
tracing::info!("leader readiness barrier synced!");
}
Err(e) => {
tracing::error!("leader readiness barrier failed: {e:?}");
} }
} }
}); });
} }
// This is supposed to be used in blocking leader initialization
pub fn run_leader_readiness_barrier_blocking(
&self,
drt: DistributedRuntime,
) -> anyhow::Result<()> {
let state = self.state.clone();
let timeout_secs = self.config.leader_init_timeout_secs;
let leader_config = self.config.clone();
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(async move {
// Create the future *before* checking the flag to avoid a lost-notify race.
let notified = state.workers_ready_notify.notified();
if !state.workers_allocation_ready.load(Ordering::Acquire) {
// Wait (with timeout) until ZMQ task marks ready.
tokio::time::timeout(Duration::from_secs(timeout_secs), notified)
.await
.map_err(|_| anyhow!("timed out waiting for workers_allocation_ready after {timeout_secs} seconds"))?;
// Double-check after wake to ensure the flag is actually set.
if !state.workers_allocation_ready.load(Ordering::Acquire) {
return Err(anyhow!(
"notified but workers_allocation_ready is still false"
));
}
}
KvbmLeader::run_leader_readiness(drt, leader_config).await
})
.context("leader readiness barrier failed")
})
}
async fn run_leader_readiness(
drt: DistributedRuntime,
leader_config: KvbmLeaderConfig,
) -> anyhow::Result<()> {
let barrier_id_leader_ready =
format!("{}{}", leader_config.barrier_id_prefix, "-leader-ready");
tracing::info!(
"Syncing leader readiness barrier with {} workers on barrier id {}",
leader_config.world_size,
barrier_id_leader_ready
);
let leader_readiness_barrier: LeaderBarrier<(), ()> = LeaderBarrier::new(
barrier_id_leader_ready.clone(),
leader_config.world_size,
Some(Duration::from_secs(leader_config.leader_init_timeout_secs)),
);
let _ = leader_readiness_barrier
.sync(&drt, &())
.await
.map_err(|e| {
anyhow::anyhow!("Failed to sync leader readiness barrier on leader: {:?}", e)
})?;
Ok(())
}
pub async fn transfer_blocks_request( pub async fn transfer_blocks_request(
&self, &self,
request: BlockTransferRequest, request: BlockTransferRequest,
...@@ -429,14 +203,6 @@ impl KvbmLeader { ...@@ -429,14 +203,6 @@ impl KvbmLeader {
zmq.broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data).await zmq.broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data).await
} }
pub fn is_worker_sync_ready(&self) -> bool {
self.workers_sync_ready.load(Ordering::Acquire)
}
pub fn is_worker_sync_done(&self) -> bool {
self.workers_sync_done.load(Ordering::Acquire)
}
pub fn num_device_blocks(&self) -> usize { pub fn num_device_blocks(&self) -> usize {
self.state.num_device_blocks.load(Ordering::Acquire) self.state.num_device_blocks.load(Ordering::Acquire)
} }
...@@ -450,26 +216,12 @@ impl KvbmLeader { ...@@ -450,26 +216,12 @@ impl KvbmLeader {
} }
pub async fn wait_worker_sync_ready(&self) -> bool { pub async fn wait_worker_sync_ready(&self) -> bool {
if self.is_worker_sync_ready() { if self.state.workers_allocation_ready.load(Ordering::Acquire) {
return true;
}
if self.is_worker_sync_done() {
return false;
}
let notified = self.workers_sync_ready_notify.notified();
if self.is_worker_sync_ready() {
return true; return true;
} }
if self.is_worker_sync_done() { let notified = self.state.workers_ready_notify.notified();
return false;
}
// bounded wait
tokio::select! { tokio::select! {
_ = notified => { _ = notified => true,
self.is_worker_sync_ready()
}
_ = sleep(Duration::from_secs(self.config.leader_init_timeout_secs)) => false, _ = sleep(Duration::from_secs(self.config.leader_init_timeout_secs)) => false,
} }
} }
......
...@@ -7,8 +7,22 @@ use serde::{Deserialize, Serialize}; ...@@ -7,8 +7,22 @@ use serde::{Deserialize, Serialize};
use crate::block_manager::connector::protocol::LeaderTransferRequest; use crate::block_manager::connector::protocol::LeaderTransferRequest;
pub const ZMQ_PING_MESSAGE: &str = "ping"; pub const ZMQ_PING_MESSAGE: &str = "ping";
pub const ZMQ_WORKER_METADATA_MESSAGE: &str = "worker_metadata";
pub const ZMQ_LEADER_METADATA_MESSAGE: &str = "leader_metadata";
pub const ZMQ_TRANSFER_BLOCKS_MESSAGE: &str = "transfer_blocks"; pub const ZMQ_TRANSFER_BLOCKS_MESSAGE: &str = "transfer_blocks";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetadata {
pub num_device_blocks: usize,
pub bytes_per_block: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderMetadata {
pub num_host_blocks: usize,
pub num_disk_blocks: usize,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy)]
pub enum BlockTransferPool { pub enum BlockTransferPool {
Device, Device,
......
...@@ -22,11 +22,17 @@ use tmq::{ ...@@ -22,11 +22,17 @@ use tmq::{
use tokio::sync::{Mutex, oneshot}; use tokio::sync::{Mutex, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use bincode;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use std::cmp::min;
struct PendingMessage { struct PendingMessage {
remaining_workers: usize, remaining_workers: usize,
completion_indicator: oneshot::Sender<()>, completion_indicator: Option<oneshot::Sender<()>>,
// If true, collect one payload (bytes) from each worker reply.
want_payload: bool,
// Collected raw payloads (one per worker), if want_payload == true
payloads: Option<Vec<Vec<u8>>>,
} }
pub struct LeaderSockets { pub struct LeaderSockets {
...@@ -36,18 +42,16 @@ pub struct LeaderSockets { ...@@ -36,18 +42,16 @@ pub struct LeaderSockets {
pub ack_url: String, pub ack_url: String,
} }
pub fn new_leader_sockets(url: &str) -> Result<LeaderSockets> { pub fn new_leader_sockets(pub_url: &str, ack_url: &str) -> Result<LeaderSockets> {
let url = format!("{}:0", url);
let context = Context::new(); let context = Context::new();
let pub_socket = publish(&context).bind(url.as_str())?; let pub_socket = publish(&context).bind(pub_url)?;
let pub_url = pub_socket let pub_url = pub_socket
.get_socket() .get_socket()
.get_last_endpoint() .get_last_endpoint()
.unwrap() .unwrap()
.unwrap(); .unwrap();
let ack_socket = pull(&context).bind(url.as_str())?; let ack_socket = pull(&context).bind(ack_url)?;
let ack_url = ack_socket let ack_url = ack_socket
.get_socket() .get_socket()
.get_last_endpoint() .get_last_endpoint()
...@@ -78,12 +82,18 @@ pub struct ZmqActiveMessageLeader { ...@@ -78,12 +82,18 @@ pub struct ZmqActiveMessageLeader {
} }
impl ZmqActiveMessageLeader { impl ZmqActiveMessageLeader {
pub async fn new( /// Handshake-first constructor: collects WorkerMetaData, broadcasts LeaderMetadata,
/// waits for allocation ACKs, then runs the final ping loop.
pub async fn new_with_handshake<F>(
leader_sockets: LeaderSockets, leader_sockets: LeaderSockets,
num_workers: usize, num_workers: usize,
timeout: Duration, overall_timeout: Duration,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Result<Self> { make_leader_meta: F,
) -> Result<Self>
where
F: Fn(&[WorkerMetadata]) -> LeaderMetadata + Send + Sync + 'static,
{
let pub_socket = Arc::new(Mutex::new(leader_sockets.pub_socket)); let pub_socket = Arc::new(Mutex::new(leader_sockets.pub_socket));
let pull_socket = leader_sockets.ack_socket; let pull_socket = leader_sockets.ack_socket;
...@@ -94,48 +104,128 @@ impl ZmqActiveMessageLeader { ...@@ -94,48 +104,128 @@ impl ZmqActiveMessageLeader {
); );
let pending_messages = Arc::new(Mutex::new(HashMap::new())); let pending_messages = Arc::new(Mutex::new(HashMap::new()));
let pending_messages_clone = pending_messages.clone(); let pending_messages_clone = pending_messages.clone();
CriticalTaskExecutionHandle::new( CriticalTaskExecutionHandle::new(
|cancel_token| Self::pull_worker(pull_socket, pending_messages_clone, cancel_token), |ct| Self::pull_worker(pull_socket, pending_messages_clone, ct),
cancel_token, cancel_token.clone(),
"ZmqActiveMessageLeader: Pull worker", "ZmqActiveMessageLeader: Pull worker",
)? )?
.detach(); .detach();
let self_ = Self { let this = Self {
pub_socket, pub_socket,
message_id: Arc::new(Mutex::new(0)), message_id: Arc::new(Mutex::new(0)),
pending_messages, pending_messages,
num_workers: Arc::new(num_workers), num_workers: Arc::new(num_workers),
}; };
// Ping our workers. let deadline = Instant::now() + overall_timeout;
let start = Instant::now();
loop { // 1) Collect KvbmWorkerData from ALL workers in a single round.
if start.elapsed() > timeout { // Keep rebroadcasting until we get exactly `num_workers` replies to the SAME broadcast.
return Err(anyhow::anyhow!("Timed out waiting for workers.")); let workers_payloads: Vec<Vec<u8>> = loop {
if Instant::now() >= deadline {
return Err(anyhow::anyhow!(
"Handshake timed out (device-config collection)."
));
} }
let remain = deadline.saturating_duration_since(Instant::now());
let round_to = min(Duration::from_secs(2), remain);
tracing::info!("Handshake: requesting worker device configs...");
match this
.broadcast_collect(
ZMQ_WORKER_METADATA_MESSAGE,
&[],
/* want_payload */ true,
round_to,
)
.await
{
Ok(payloads) if payloads.len() == num_workers => {
tracing::info!(
"Handshake: received {} worker metadata replies in this round.",
payloads.len()
);
break payloads;
}
Ok(payloads) => {
tracing::warn!(
"Handshake: got {} / {} worker metadata replies; rebroadcasting...",
payloads.len(),
num_workers
);
continue;
}
Err(e) => {
tracing::debug!(
"Handshake: worker metadata round timed out/failed: {e}; retrying..."
);
continue;
}
}
};
// Try to send a ping to all workers. let workers: Vec<WorkerMetadata> = workers_payloads
tracing::info!("ZmqActiveMessageLeader: Pinging workers..."); .into_iter()
let ping_receiver = self_.broadcast(ZMQ_PING_MESSAGE, vec![]).await?; .map(|b| bincode::deserialize::<WorkerMetadata>(&b))
.collect::<std::result::Result<_, _>>()?;
tokio::select! { // 2) Compute & broadcast LeaderMetadata; wait for ALL acks in the SAME round.
// If we receive an ACK from every worker, we're done. let leader_meta = make_leader_meta(&workers);
_ = ping_receiver => { let leader_meta_bytes = bincode::serialize(&leader_meta)?;
tracing::info!("ZmqActiveMessageLeader: Worker ping successful. Startup complete.");
loop {
if Instant::now() >= deadline {
return Err(anyhow::anyhow!(
"Handshake timed out (allocation-config broadcast)."
));
}
let remain = deadline.saturating_duration_since(Instant::now());
let round_to = min(Duration::from_secs(2), remain);
tracing::info!("Handshake: broadcasting allocation config to workers...");
match this
.broadcast_collect(
ZMQ_LEADER_METADATA_MESSAGE,
std::slice::from_ref(&leader_meta_bytes),
/* want_payload */ false,
round_to,
)
.await
{
Ok(_) => {
// Success: all workers acked in this round.
tracing::info!("Handshake: all workers acked allocation config.");
break; break;
} }
// Wait for 1 second before pinging again. Err(e) => {
_ = tokio::time::sleep(Duration::from_millis(1000)) => { tracing::warn!(
tracing::info!("ZmqActiveMessageLeader: Ping timed out. Retrying..."); "Handshake: allocation-config round incomplete: {e}; rebroadcasting..."
);
continue; continue;
} }
} }
} }
Ok(self_) // 3) Final readiness ping loop (workers only ACK after allocation ready)
let ping_deadline = deadline;
loop {
if Instant::now() >= ping_deadline {
return Err(anyhow::anyhow!(
"Timed out waiting for ping readiness after handshake."
));
}
tracing::info!("Handshake: final readiness ping...");
let ping = this.broadcast(ZMQ_PING_MESSAGE, vec![]).await?;
tokio::select! {
_ = ping => break,
_ = tokio::time::sleep(Duration::from_millis(500)) => continue,
_ = cancel_token.cancelled() => return Err(anyhow::anyhow!("Startup canceled")),
}
}
Ok(this)
} }
/// Broadcast a message to all workers. /// Broadcast a message to all workers.
...@@ -157,7 +247,9 @@ impl ZmqActiveMessageLeader { ...@@ -157,7 +247,9 @@ impl ZmqActiveMessageLeader {
let pending_message = PendingMessage { let pending_message = PendingMessage {
// We start with the number of workers we're waiting for. // We start with the number of workers we're waiting for.
remaining_workers: *self.num_workers, remaining_workers: *self.num_workers,
completion_indicator, completion_indicator: Some(completion_indicator),
want_payload: false,
payloads: None,
}; };
// Add the message to the pending messages map. // Add the message to the pending messages map.
...@@ -187,7 +279,68 @@ impl ZmqActiveMessageLeader { ...@@ -187,7 +279,68 @@ impl ZmqActiveMessageLeader {
Ok(completion_receiver) Ok(completion_receiver)
} }
/// Pull worker is responsible for receiving ACKs from workers. /// Generic broadcast that can collect one reply payload from each worker.
/// - `function`: handler name on workers
/// - `data_frames`: optional extra frames after [id, function]
/// - `want_payload`: if true, expects replies shaped as [id, function, payload]
/// Returns payloads (empty if want_payload == false).
pub async fn broadcast_collect(
&self,
function: &str,
data_frames: &[Vec<u8>],
want_payload: bool,
timeout: Duration,
) -> Result<Vec<Vec<u8>>> {
// Generate a unique id.
let id = {
let mut id = self.message_id.lock().await;
*id += 1;
*id
};
let (completion_indicator, completion_receiver) = oneshot::channel();
let pending_message = PendingMessage {
remaining_workers: *self.num_workers,
completion_indicator: Some(completion_indicator),
want_payload,
payloads: want_payload.then(|| Vec::with_capacity(*self.num_workers)),
};
self.pending_messages
.lock()
.await
.insert(id, pending_message);
// Build message: [id, function, ...data]
let mut message: VecDeque<Message> = VecDeque::with_capacity(2 + data_frames.len());
message.push_back(id.to_be_bytes().as_slice().into());
message.push_back(function.into());
for df in data_frames {
message.push_back(df.clone().into());
}
self.pub_socket
.lock()
.await
.send(Multipart(message))
.await?;
// Await all replies or timeout.
tokio::select! {
_ = completion_receiver => { /* done */ }
_ = tokio::time::sleep(timeout) => {
let mut map = self.pending_messages.lock().await;
map.remove(&id);
return Err(anyhow::anyhow!("Timed out waiting for '{}' responses", function));
}
}
// Extract payloads (if any).
let mut map = self.pending_messages.lock().await;
let entry = map
.remove(&id)
.ok_or_else(|| anyhow::anyhow!("pending entry missing"))?;
Ok(entry.payloads.unwrap_or_default())
}
async fn pull_worker( async fn pull_worker(
mut pull_socket: Pull, mut pull_socket: Pull,
pending_messages: Arc<Mutex<HashMap<usize, PendingMessage>>>, pending_messages: Arc<Mutex<HashMap<usize, PendingMessage>>>,
...@@ -196,65 +349,42 @@ impl ZmqActiveMessageLeader { ...@@ -196,65 +349,42 @@ impl ZmqActiveMessageLeader {
loop { loop {
tokio::select! { tokio::select! {
Some(Ok(message)) = pull_socket.next() => { Some(Ok(message)) = pull_socket.next() => {
// The leader should only ever receive ACKs. if message.is_empty() {
// ACKs have no data. tracing::error!("Leader PULL: empty message");
if message.len() != 1 { continue;
tracing::error!( }
"Received message with unexpected length: {:?}", let arr: [u8; std::mem::size_of::<usize>()] = (*message[0]).try_into()?;
message.len() let id = usize::from_be_bytes(arr);
);
continue; let mut map = pending_messages.lock().await;
}
// TODO: This looks ugly. if let Some(pm) = map.get_mut(&id) {
let arr: [u8; std::mem::size_of::<usize>()] = (*message[0]).try_into()?; // payload reply or pure ACK?
let id = usize::from_be_bytes(arr); if message.len() == 1 {
if pm.remaining_workers > 0 { pm.remaining_workers -= 1; }
let mut pending_messages = pending_messages.lock().await; } else {
// TODO: Should we error if we can't find the pending message? if pm.want_payload && message.len() >= 3
// if let std::collections::hash_map::Entry::Occupied(mut entry) = && let Some(bufs) = pm.payloads.as_mut() {
// pending_messages.entry(id) bufs.push((*message[2]).to_vec());
// {
// entry.get_mut().remaining_workers -= 1;
// tracing::debug!(
// "ZmqActiveMessageLeader: Received ACK for message with id: {}. There are {} remaining workers.",
// id,
// entry.get().remaining_workers
// );
// // If all workers have ACKed, notify the completion indicator.
// if entry.get().remaining_workers == 0 {
// let e = entry.remove();
// tracing::debug!(
// "ZmqActiveMessageLeader: Message with id: {} completed.",
// id
// );
// // It's possible that the receiver has already been dropped,
// // so ignore any send error here.
// let _ = e.completion_indicator.send(());
// }
// }
match pending_messages.entry(id) {
std::collections::hash_map::Entry::Occupied(mut entry) => {
let pending_message = entry.get_mut();
debug_assert!(pending_message.remaining_workers > 0);
pending_message.remaining_workers -= 1;
tracing::debug!(
"ZmqActiveMessageLeader: Received ACK for message with id: {}. There are {} remaining workers.",
id,
pending_message.remaining_workers
);
if pending_message.remaining_workers == 0 {
let e = entry.remove();
tracing::debug!("ZmqActiveMessageLeader: Message with id: {} completed.", id);
let _ = e.completion_indicator.send(());
} }
} if pm.remaining_workers > 0 { pm.remaining_workers -= 1; }
std::collections::hash_map::Entry::Vacant(_) => {
tracing::error!("Received ACK for unknown message with id: {}", id);
}
} }
tracing::debug!(
"Leader PULL: got {} for id {} (remaining={})",
if message.len()==1 { "ACK" } else { "REPLY" }, id, pm.remaining_workers
);
// IMPORTANT: do NOT remove here; just notify completion.
if pm.remaining_workers == 0
&& let Some(tx) = pm.completion_indicator.take() {
let _ = tx.send(());
}
} else {
// Late reply for a round we've already collected/removed.
tracing::debug!("Leader PULL: late/unknown id {}", id);
} }
}
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
tracing::info!("ZmqActiveMessageLeader: Pull worker cancelled."); tracing::info!("ZmqActiveMessageLeader: Pull worker cancelled.");
break; break;
...@@ -269,10 +399,10 @@ impl ZmqActiveMessageLeader { ...@@ -269,10 +399,10 @@ impl ZmqActiveMessageLeader {
/// A message handle is used to track a message. /// A message handle is used to track a message.
/// It contains a way to ACK the message, as well as the data. /// It contains a way to ACK the message, as well as the data.
pub struct MessageHandle { pub struct MessageHandle {
message_id: usize, pub message_id: usize,
function: String, function: String,
pub data: Vec<Vec<u8>>, pub data: Vec<Vec<u8>>,
push_handle: Arc<Mutex<Push>>, pub push_handle: Arc<Mutex<Push>>,
acked: bool, acked: bool,
} }
...@@ -321,6 +451,36 @@ impl MessageHandle { ...@@ -321,6 +451,36 @@ impl MessageHandle {
tracing::debug!("ZmqActiveMessageWorker: ACKed message with id: {}", id); tracing::debug!("ZmqActiveMessageWorker: ACKed message with id: {}", id);
Ok(()) Ok(())
} }
/// Reply to the leader with arbitrary payload frames and mark as acked.
/// Frames shape: [id, function, payload_0, payload_1, ...]
pub async fn reply(
&mut self,
function: &str,
payload_frames: &[Vec<u8>],
) -> anyhow::Result<()> {
let mut frames: std::collections::VecDeque<tmq::Message> =
std::collections::VecDeque::with_capacity(2 + payload_frames.len());
frames.push_back(self.message_id.to_be_bytes().as_slice().into());
frames.push_back(function.into());
for p in payload_frames {
frames.push_back(p.clone().into());
}
self.push_handle
.lock()
.await
.send(tmq::Multipart(frames))
.await?;
// Mark as acked so Drop won't panic; leader treats the reply as the "ack".
self.acked = true;
Ok(())
}
/// Mark this message as handled locally without sending an ACK/reply.
/// Use when intentionally ignoring a message (e.g. ping before readiness).
pub fn mark_handled(&mut self) {
self.acked = true;
}
} }
/// We must always ACK a message. /// We must always ACK a message.
...@@ -340,21 +500,6 @@ pub trait Handler: Send + Sync { ...@@ -340,21 +500,6 @@ pub trait Handler: Send + Sync {
async fn handle(&self, message: MessageHandle) -> Result<()>; async fn handle(&self, message: MessageHandle) -> Result<()>;
} }
/// A super simple handler that responds to a ping.
/// This is used in the startup sequence to check worker liveness.
struct Ping;
#[async_trait]
impl Handler for Ping {
async fn handle(&self, mut message: MessageHandle) -> Result<()> {
if !message.data.is_empty() {
return Err(anyhow::anyhow!("Ping message should not have data."));
}
message.ack().await?;
Ok(())
}
}
type MessageHandlers = HashMap<String, Arc<dyn Handler>>; type MessageHandlers = HashMap<String, Arc<dyn Handler>>;
/// The ActiveMessageWorker receives commands from the leader, and ACKs them. /// The ActiveMessageWorker receives commands from the leader, and ACKs them.
...@@ -364,7 +509,7 @@ impl ZmqActiveMessageWorker { ...@@ -364,7 +509,7 @@ impl ZmqActiveMessageWorker {
pub fn new( pub fn new(
sub_url: &str, sub_url: &str,
push_url: &str, push_url: &str,
mut message_handlers: MessageHandlers, message_handlers: MessageHandlers,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Result<Self> { ) -> Result<Self> {
let context = Context::new(); let context = Context::new();
...@@ -380,8 +525,6 @@ impl ZmqActiveMessageWorker { ...@@ -380,8 +525,6 @@ impl ZmqActiveMessageWorker {
push_url push_url
); );
// Add our ping handler.
message_handlers.insert(ZMQ_PING_MESSAGE.to_string(), Arc::new(Ping));
let message_handlers = Arc::new(message_handlers); let message_handlers = Arc::new(message_handlers);
CriticalTaskExecutionHandle::new( CriticalTaskExecutionHandle::new(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment