// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use super::*; use utils::*; use zmq::*; use derive_builder::Builder; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::time::Duration; use tokio::sync::Notify; use tokio::sync::OnceCell; use tokio::sync::oneshot; use tokio::time::sleep; #[derive(Builder, Clone, Debug, Default)] pub struct KvbmLeaderNumBlocksConfig { #[builder(default = "0.0")] pub cache_size_in_gb: f64, #[builder(default = "0")] pub num_blocks_overriden: usize, } fn compute_num_blocks( num_blocks_config: &KvbmLeaderNumBlocksConfig, bytes_per_block: usize, ) -> usize { if num_blocks_config.num_blocks_overriden > 0 { num_blocks_config.num_blocks_overriden } else { ((num_blocks_config.cache_size_in_gb * 1_000_000_000.0) / bytes_per_block as f64) as usize } } #[derive(Builder, Clone, Debug)] pub struct KvbmLeaderConfig { /// The world size. #[builder(default = "1")] world_size: usize, /// The leader-worker init connection timeout seconds. #[builder(default = "120")] leader_init_timeout_secs: u64, #[builder(default = "KvbmLeaderNumBlocksConfig::default()")] host_blocks_config: KvbmLeaderNumBlocksConfig, #[builder(default = "KvbmLeaderNumBlocksConfig::default()")] disk_blocks_config: KvbmLeaderNumBlocksConfig, #[builder(default = "String::from(\"tcp://127.0.0.1:56001\")")] leader_pub_url: String, #[builder(default = "String::from(\"tcp://127.0.0.1:56002\")")] leader_ack_url: String, } impl KvbmLeaderConfig { pub fn builder() -> KvbmLeaderConfigBuilder { KvbmLeaderConfigBuilder::default() } pub fn sanity_check(&self) -> anyhow::Result<()> { if self.leader_pub_url == self.leader_ack_url { anyhow::bail!( "leader_pub_url and leader_ack_url must differ (same endpoint would fail to bind)." ); } let cpu = &self.host_blocks_config; let disk = &self.disk_blocks_config; let cpu_configured = cpu.num_blocks_overriden > 0 || cpu.cache_size_in_gb > 0.0; let disk_configured = disk.num_blocks_overriden > 0 || disk.cache_size_in_gb > 0.0; if !cpu_configured && !disk_configured { panic!( "KVBM Configuration Error: At least one cache tier must be configured.\n\ \n\ Configure CPU cache (G2) for CPU memory offloading:\n\ • DYN_KVBM_CPU_CACHE_GB= (e.g., DYN_KVBM_CPU_CACHE_GB=4)\n\ • DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS= (e.g., DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS=1000)\n\ \n\ OR configure disk cache (G3) for direct GPU->Disk offloading:\n\ • DYN_KVBM_DISK_CACHE_GB= (e.g., DYN_KVBM_DISK_CACHE_GB=8)\n\ • DYN_KVBM_DISK_CACHE_OVERRIDE_NUM_BLOCKS=\n\ \n\ Note: If only disk cache is configured, KVBM will offload directly from GPU (G1) to Disk (G3), bypassing CPU memory (G2)." ); } Ok(()) } } #[derive(Debug, Default)] pub struct KvbmLeaderState { pub num_device_blocks: Arc, pub num_host_blocks: Arc, pub num_disk_blocks: Arc, pub workers_allocation_ready: Arc, pub workers_ready_notify: Arc, } /// The leader of the KVBM. /// /// This is responsible for: /// - Establishing a ZMQ connection with workers. /// - Syncing the leader barrier with workers. /// - Sending messages to workers. pub struct KvbmLeader { state: Arc, zmq_leader: Arc>, config: KvbmLeaderConfig, } impl KvbmLeader { pub async fn new(config: KvbmLeaderConfig) -> anyhow::Result { let leader_sockets = new_leader_sockets(&config.leader_pub_url, &config.leader_ack_url)?; let leader = Self { state: Arc::new(KvbmLeaderState::default()), zmq_leader: Arc::new(tokio::sync::OnceCell::new()), config, }; let cancel_token = tokio_util::sync::CancellationToken::new(); leader.spawn_zmq_task(leader_sockets, cancel_token); Ok(leader) } fn spawn_zmq_task( &self, leader_sockets: LeaderSockets, cancel: tokio_util::sync::CancellationToken, ) { let cell = self.zmq_leader.clone(); let state = self.state.clone(); let world_size = self.config.world_size; let timeout = self.config.leader_init_timeout_secs; let host_cfg = self.config.host_blocks_config.clone(); let disk_cfg = self.config.disk_blocks_config.clone(); // capture num_device_blocks so we can set it inside the closure let num_device_blocks_cell = state.num_device_blocks.clone(); let num_host_blocks_cell = state.num_host_blocks.clone(); let num_disk_blocks_cell = state.num_disk_blocks.clone(); tokio::spawn(async move { let res = ZmqActiveMessageLeader::new_with_handshake( leader_sockets, world_size, std::time::Duration::from_secs(timeout), cancel.clone(), move |workers: &[WorkerMetadata]| -> LeaderMetadata { // Record device blocks: min across workers if let Some(min_dev) = workers.iter().map(|w| w.num_device_blocks).min() { num_device_blocks_cell.store(min_dev, Ordering::Release); } // For TP, sum bytes_per_block; adjust policy for DP/PP if needed. let bytes_per_block: usize = workers.iter().map(|w| w.bytes_per_block).sum(); let num_host_blocks = compute_num_blocks(&host_cfg, bytes_per_block); let num_disk_blocks = compute_num_blocks(&disk_cfg, bytes_per_block); // store into leader state num_host_blocks_cell.store(num_host_blocks, Ordering::Release); num_disk_blocks_cell.store(num_disk_blocks, Ordering::Release); LeaderMetadata { num_host_blocks, num_disk_blocks, } }, ) .await; match res { Ok(zmq) => { let _ = cell.set(zmq); state .workers_allocation_ready .store(true, Ordering::Release); state.workers_ready_notify.notify_waiters(); tracing::info!("ZMQ handshake complete; workers allocation ready"); } Err(e) => { tracing::error!("ZMQ init/handshake failed: {e:?}"); } } }); } pub async fn transfer_blocks_request( &self, request: BlockTransferRequest, ) -> anyhow::Result> { let zmq = self .zmq_leader .get() .ok_or_else(|| anyhow::anyhow!("ZMQ leader not ready"))?; let data = vec![serde_json::to_vec(&request)?]; zmq.broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data).await } pub fn num_device_blocks(&self) -> usize { self.state.num_device_blocks.load(Ordering::Acquire) } pub fn num_host_blocks(&self) -> usize { self.state.num_host_blocks.load(Ordering::Acquire) } pub fn num_disk_blocks(&self) -> usize { self.state.num_disk_blocks.load(Ordering::Acquire) } pub async fn wait_worker_sync_ready(&self) -> bool { if self.state.workers_allocation_ready.load(Ordering::Acquire) { return true; } let notified = self.state.workers_ready_notify.notified(); tokio::select! { _ = notified => true, _ = sleep(Duration::from_secs(self.config.leader_init_timeout_secs)) => false, } } }