Unverified Commit 008683d6 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: adding kvbm-engine (#6773)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent cf79c4fc
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! BlockManager testing utilities.
//!
//! Contains core manager/registry builders, population helpers, and the
//! `MultiInstancePopulator` which bridges logical and physical layers.
//!
//! Note: Due to version boundaries between workspace crates and git-sourced
//! kvbm-logical/kvbm-physical, these utilities use workspace-local types directly
//! rather than re-exporting from kvbm-logical::testing.
use std::marker::PhantomData;
use std::ops::Range;
use std::sync::Arc;
use anyhow::Result;
use kvbm_logical::{
blocks::{BlockMetadata, BlockRegistry},
events::EventsManager,
manager::{BlockManager, FrequencyTrackingCapacity},
};
use crate::{BlockId, SequenceHash};
use kvbm_common::tokens::TokenBlockSequence;
use kvbm_logical::KvbmSequenceHashProvider;
use kvbm_physical::transfer::FillPattern;
use super::token_blocks;
/// Builder for creating test BlockRegistry with optional events integration.
#[derive(Default)]
pub struct TestRegistryBuilder {
events_manager: Option<Arc<EventsManager>>,
frequency_tracking: FrequencyTrackingCapacity,
}
impl TestRegistryBuilder {
pub fn new() -> Self {
Self {
events_manager: None,
frequency_tracking: FrequencyTrackingCapacity::Medium,
}
}
pub fn events_manager(mut self, manager: Arc<EventsManager>) -> Self {
self.events_manager = Some(manager);
self
}
pub fn frequency_tracking(mut self, capacity: FrequencyTrackingCapacity) -> Self {
self.frequency_tracking = capacity;
self
}
pub fn build(self) -> BlockRegistry {
let mut builder =
BlockRegistry::builder().frequency_tracker(self.frequency_tracking.create_tracker());
if let Some(events_manager) = self.events_manager {
builder = builder.event_manager(events_manager);
}
builder.build()
}
}
/// Builder for creating test BlockManagers.
pub struct TestManagerBuilder<T: BlockMetadata> {
block_count: Option<usize>,
block_size: Option<usize>,
registry: Option<BlockRegistry>,
events_manager: Option<Arc<EventsManager>>,
frequency_tracking: FrequencyTrackingCapacity,
_phantom: PhantomData<T>,
}
impl<T: BlockMetadata> Default for TestManagerBuilder<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: BlockMetadata> TestManagerBuilder<T> {
pub fn new() -> Self {
Self {
block_count: None,
block_size: None,
registry: None,
events_manager: None,
frequency_tracking: FrequencyTrackingCapacity::Medium,
_phantom: PhantomData,
}
}
pub fn block_count(mut self, count: usize) -> Self {
self.block_count = Some(count);
self
}
pub fn block_size(mut self, size: usize) -> Self {
self.block_size = Some(size);
self
}
pub fn registry(mut self, registry: BlockRegistry) -> Self {
self.registry = Some(registry);
self
}
pub fn events_manager(mut self, manager: Arc<EventsManager>) -> Self {
self.events_manager = Some(manager);
self
}
pub fn frequency_tracking(mut self, capacity: FrequencyTrackingCapacity) -> Self {
self.frequency_tracking = capacity;
self
}
pub fn build(self) -> BlockManager<T> {
let block_count = self.block_count.expect("block_count is required");
let block_size = self.block_size.expect("block_size is required");
let registry = self.registry.unwrap_or_else(|| {
let mut builder =
TestRegistryBuilder::new().frequency_tracking(self.frequency_tracking);
if let Some(events_manager) = self.events_manager {
builder = builder.events_manager(events_manager);
}
builder.build()
});
BlockManager::<T>::builder()
.block_count(block_count)
.block_size(block_size)
.registry(registry)
.with_lru_backend()
.build()
.expect("Should build test manager")
}
}
/// Populate a BlockManager with token blocks and return their sequence hashes.
pub fn populate_manager_with_blocks<T: BlockMetadata>(
manager: &BlockManager<T>,
token_blocks: &[kvbm_common::tokens::TokenBlock],
) -> Result<Vec<SequenceHash>> {
let blocks = manager
.allocate_blocks(token_blocks.len())
.ok_or_else(|| anyhow::anyhow!("Failed to allocate {} blocks", token_blocks.len()))?;
let complete_blocks: Vec<_> = blocks
.into_iter()
.zip(token_blocks.iter())
.map(|(block, token_block)| {
block
.complete(token_block)
.map_err(|e| anyhow::anyhow!("Failed to complete block: {:?}", e))
})
.collect::<Result<Vec<_>>>()?;
let seq_hashes: Vec<SequenceHash> = complete_blocks.iter().map(|b| b.sequence_hash()).collect();
let immutable_blocks = manager.register_blocks(complete_blocks);
drop(immutable_blocks);
Ok(seq_hashes)
}
/// Quick setup: create manager and populate with sequential token blocks.
pub fn create_and_populate_manager<T: BlockMetadata>(
block_count: usize,
block_size: usize,
start_token: u32,
registry: BlockRegistry,
) -> Result<(BlockManager<T>, Vec<SequenceHash>)> {
let manager = TestManagerBuilder::<T>::new()
.block_count(block_count)
.block_size(block_size)
.registry(registry)
.build();
let token_sequence = token_blocks::create_token_sequence(block_count, block_size, start_token);
let seq_hashes = populate_manager_with_blocks(&manager, token_sequence.blocks())?;
Ok((manager, seq_hashes))
}
// =============================================================================
// Multi-Instance Population Helper
// =============================================================================
/// Specification for a single instance's population.
pub struct InstancePopulationSpec<'a, M: BlockMetadata> {
pub manager: &'a BlockManager<M>,
pub block_range: Range<usize>,
pub fill_pattern: Option<FillPattern>,
}
/// Result of populating a single instance.
pub struct InstancePopulationResult {
pub instance_index: usize,
pub block_ids: Vec<BlockId>,
pub hashes: Vec<SequenceHash>,
}
/// Results from populating multiple instances.
pub struct PopulatedInstances {
token_sequence: TokenBlockSequence,
all_hashes: Vec<SequenceHash>,
instance_results: Vec<InstancePopulationResult>,
}
impl PopulatedInstances {
pub fn all_hashes(&self) -> &[SequenceHash] {
&self.all_hashes
}
pub fn token_sequence(&self) -> &TokenBlockSequence {
&self.token_sequence
}
pub fn instance_block_ids(&self, instance_index: usize) -> Option<&[BlockId]> {
self.instance_results
.get(instance_index)
.map(|r| r.block_ids.as_slice())
}
pub fn instance_hashes(&self, instance_index: usize) -> Option<&[SequenceHash]> {
self.instance_results
.get(instance_index)
.map(|r| r.hashes.as_slice())
}
pub fn instance_count(&self) -> usize {
self.instance_results.len()
}
pub fn instance_results(&self) -> &[InstancePopulationResult] {
&self.instance_results
}
}
/// Builder for populating multiple instances with blocks from a shared token sequence.
pub struct MultiInstancePopulatorBuilder<'a, M: BlockMetadata> {
total_blocks: Option<usize>,
block_size: Option<usize>,
start_token: u32,
instances: Vec<InstancePopulationSpec<'a, M>>,
}
impl<'a, M: BlockMetadata> Default for MultiInstancePopulatorBuilder<'a, M> {
fn default() -> Self {
Self::new()
}
}
impl<'a, M: BlockMetadata> MultiInstancePopulatorBuilder<'a, M> {
pub fn new() -> Self {
Self {
total_blocks: None,
block_size: None,
start_token: 0,
instances: Vec::new(),
}
}
pub fn total_blocks(mut self, count: usize) -> Self {
self.total_blocks = Some(count);
self
}
pub fn block_size(mut self, size: usize) -> Self {
self.block_size = Some(size);
self
}
pub fn start_token(mut self, token: u32) -> Self {
self.start_token = token;
self
}
pub fn add_instance(mut self, manager: &'a BlockManager<M>, block_range: Range<usize>) -> Self {
self.instances.push(InstancePopulationSpec {
manager,
block_range,
fill_pattern: None,
});
self
}
pub fn add_instance_with_pattern(
mut self,
manager: &'a BlockManager<M>,
block_range: Range<usize>,
fill_pattern: FillPattern,
) -> Self {
self.instances.push(InstancePopulationSpec {
manager,
block_range,
fill_pattern: Some(fill_pattern),
});
self
}
pub fn build(self) -> Result<PopulatedInstances> {
let total_blocks = self.total_blocks.expect("total_blocks is required");
let block_size = self.block_size.expect("block_size is required");
let token_sequence =
token_blocks::create_token_sequence(total_blocks, block_size, self.start_token);
let full_blocks = token_sequence.blocks();
let all_hashes: Vec<SequenceHash> =
full_blocks.iter().map(|b| b.kvbm_sequence_hash()).collect();
let mut instance_results = Vec::with_capacity(self.instances.len());
for (idx, spec) in self.instances.into_iter().enumerate() {
if spec.block_range.end > total_blocks {
anyhow::bail!(
"Instance {} block_range {:?} exceeds total_blocks {}",
idx,
spec.block_range,
total_blocks
);
}
let instance_blocks: Vec<_> = full_blocks[spec.block_range.clone()].to_vec();
let hashes = populate_manager_with_blocks(spec.manager, &instance_blocks)?;
let matched = spec.manager.match_blocks(&hashes);
let block_ids: Vec<BlockId> = matched.into_iter().map(|b| b.block_id()).collect();
instance_results.push(InstancePopulationResult {
instance_index: idx,
block_ids,
hashes,
});
}
Ok(PopulatedInstances {
token_sequence,
all_hashes,
instance_results,
})
}
}
/// Convenience type alias for the builder.
pub type MultiInstancePopulator<'a, M> = MultiInstancePopulatorBuilder<'a, M>;
#[cfg(test)]
mod tests {
use super::*;
#[derive(Clone, Debug)]
struct TestMetadata;
#[test]
fn test_create_test_manager() {
let manager = TestManagerBuilder::<TestMetadata>::new()
.block_count(100)
.block_size(16)
.build();
assert_eq!(manager.total_blocks(), 100);
assert_eq!(manager.block_size(), 16);
assert_eq!(manager.available_blocks(), 100);
}
#[test]
fn test_populate_manager_with_blocks() {
let manager = TestManagerBuilder::<TestMetadata>::new()
.block_count(50)
.block_size(4)
.build();
let token_seq = token_blocks::create_token_sequence(10, 4, 0);
let seq_hashes =
populate_manager_with_blocks(&manager, token_seq.blocks()).expect("Should populate");
assert_eq!(seq_hashes.len(), 10);
assert_eq!(manager.available_blocks(), 50);
}
#[test]
fn test_create_and_populate_manager() {
let registry = TestRegistryBuilder::new().build();
let (manager, hashes) = create_and_populate_manager::<TestMetadata>(32, 4, 100, registry)
.expect("Should create");
assert_eq!(hashes.len(), 32);
assert_eq!(manager.total_blocks(), 32);
assert_eq!(manager.available_blocks(), 32);
let matched = manager.match_blocks(&hashes);
assert_eq!(matched.len(), 32);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Messenger instance setup utilities for testing.
use anyhow::Result;
use std::net::TcpListener;
use std::sync::Arc;
use tokio::time::Duration;
use velo::Messenger;
use velo::backend::Transport;
use velo::backend::tcp::TcpTransportBuilder;
/// Create a single Messenger instance with TCP transport on a random port.
///
/// # Returns
/// Messenger instance
///
/// # Example
/// ```ignore
/// let messenger = create_messenger_tcp().await?;
/// ```
pub async fn create_messenger_tcp() -> Result<Arc<Messenger>> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let transport: Arc<dyn Transport> = Arc::new(
TcpTransportBuilder::new()
.from_listener(listener)?
.build()?,
);
let messenger = Messenger::builder()
.add_transport(transport)
.build()
.await?;
// Give transport a moment to bind
tokio::time::sleep(Duration::from_millis(100)).await;
Ok(messenger)
}
/// Container for a pair of connected Messenger instances.
pub struct MessengerPair {
pub messenger_a: Arc<Messenger>,
pub messenger_b: Arc<Messenger>,
}
/// Create a pair of Messenger instances with bidirectional peer registration.
///
/// Both instances:
/// - Use TCP transport on random ports
/// - Are registered as peers of each other
/// - Ready for communication
///
/// # Example
/// ```ignore
/// let pair = create_messenger_pair_tcp().await?;
///
/// // Can now send messages between messenger_a and messenger_b
/// pair.messenger_a.unary("handler")?
/// .instance(pair.messenger_b.instance_id())
/// .send().await?;
/// ```
pub async fn create_messenger_pair_tcp() -> Result<MessengerPair> {
// Create first Messenger instance
let messenger_a = create_messenger_tcp().await?;
// Create second Messenger instance
let messenger_b = create_messenger_tcp().await?;
// Register each as peer of the other
messenger_a.register_peer(messenger_b.peer_info())?;
messenger_b.register_peer(messenger_a.peer_info())?;
// Give time for peer registration to propagate
tokio::time::sleep(Duration::from_millis(200)).await;
Ok(MessengerPair {
messenger_a,
messenger_b,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_create_messenger_instance() {
let messenger = create_messenger_tcp()
.await
.expect("Should create Messenger");
let peer_info = messenger.peer_info();
assert_eq!(
peer_info.instance_id().worker_id(),
messenger.instance_id().worker_id()
);
assert!(!peer_info.worker_address().as_bytes().is_empty());
// Local handlers should include system entries
let handlers = messenger.list_local_handlers();
assert!(
handlers.contains(&"_list_handlers".to_string()),
"Expected _list_handlers in local handler list: {:?}",
handlers
);
assert!(
handlers.contains(&"_hello".to_string()),
"Expected _hello in local handler list: {:?}",
handlers
);
}
#[tokio::test]
async fn test_create_messenger_pair() {
let pair = create_messenger_pair_tcp()
.await
.expect("Should create pair");
// Verify both instances have different IDs
assert_ne!(
pair.messenger_a.instance_id(),
pair.messenger_b.instance_id()
);
// Verify worker addresses differ
assert_ne!(
pair.messenger_a.peer_info().worker_address().checksum(),
pair.messenger_b.peer_info().worker_address().checksum()
);
// Verify system handlers are discoverable across peers
let handlers_from_a = pair
.messenger_a
.available_handlers(pair.messenger_b.instance_id())
.await
.expect("Handlers from messenger_b should be available");
assert!(
handlers_from_a.contains(&"_list_handlers".to_string()),
"messenger_a should see _list_handlers on messenger_b: {:?}",
handlers_from_a
);
assert!(
handlers_from_a.contains(&"_hello".to_string()),
"messenger_a should see _hello on messenger_b: {:?}",
handlers_from_a
);
let handlers_from_b = pair
.messenger_b
.available_handlers(pair.messenger_a.instance_id())
.await
.expect("Handlers from messenger_a should be available");
assert!(
handlers_from_b.contains(&"_list_handlers".to_string()),
"messenger_b should see _list_handlers on messenger_a: {:?}",
handlers_from_b
);
assert!(
handlers_from_b.contains(&"_hello".to_string()),
"messenger_b should see _hello on messenger_a: {:?}",
handlers_from_b
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#![doc = include_str!("../../docs/testing.md")]
pub mod distributed;
pub mod events;
pub mod managers;
pub mod messenger;
pub mod offloading;
pub mod physical;
pub mod token_blocks;
// Re-export commonly used testing utilities
pub use distributed::TestSession;
pub use events::{EventsPipelineConfig, EventsPipelineConfigBuilder, EventsPipelineFixture};
pub use managers::{
InstancePopulationResult, InstancePopulationSpec, MultiInstancePopulator,
MultiInstancePopulatorBuilder, PopulatedInstances, TestManagerBuilder, TestRegistryBuilder,
create_and_populate_manager, populate_manager_with_blocks,
};
pub use messenger::{MessengerPair, create_messenger_pair_tcp, create_messenger_tcp};
pub use physical::{TestAgent, TestAgentBuilder, TransferChecksums};
pub use token_blocks::*;
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Only compile proto files if grpc feature is enabled
#[cfg(feature = "grpc")]
{
tonic_build::compile_protos("proto/velo.proto")?;
}
Ok(())
}
//! End-to-end tests for the offload engine.
mod object_flow;
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! End-to-end test for G2->G4 (object storage) offload flow with distributed locking.
//!
//! This test demonstrates:
//! - Using the locking mechanism (.lock/.meta files) for distributed coordination
//! - Verifying lock acquisition and release
//! - Verifying meta file creation marks blocks as offloaded
//! - Verifying re-offload is skipped for blocks that already have meta files
//!
//! Note: Uses a mock in-memory object storage implementation for testing without
//! requiring a real S3/MinIO backend.
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::Result;
use bytes::Bytes;
use futures::future::BoxFuture;
use crate::LogicalLayoutHandle;
use crate::object::{LockFileContent, ObjectBlockOps, ObjectLockManager};
use crate::offload::{
BoxFuture as PolicyBoxFuture, EvalContext, ObjectLockPresenceFilter, OffloadPolicy,
PendingTracker,
};
use crate::{BlockId, G2, SequenceHash};
/// Create a test sequence hash from a simple integer.
fn test_hash(n: u64) -> SequenceHash {
SequenceHash::new(n, None, 0)
}
/// Get current time as seconds since epoch
fn now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
}
/// Format timestamp as RFC3339-like string
fn timestamp_to_string(secs: u64) -> String {
format!("{}", secs)
}
/// Parse timestamp from string
fn parse_timestamp(s: &str) -> Option<u64> {
s.parse().ok()
}
/// Check if deadline timestamp is expired
fn is_expired_timestamp(deadline_str: &str) -> bool {
if let Some(deadline) = parse_timestamp(deadline_str) {
now_secs() > deadline
} else {
true // Can't parse = treat as expired
}
}
// =========================================================================
// Mock Object Storage Implementation
// =========================================================================
/// In-memory mock object storage for testing.
#[derive(Debug, Default)]
struct MockObjectStorage {
objects: RwLock<HashMap<String, Bytes>>,
}
impl MockObjectStorage {
fn new() -> Self {
Self {
objects: RwLock::new(HashMap::new()),
}
}
fn has_object(&self, key: &str) -> bool {
self.objects.read().unwrap().contains_key(key)
}
fn get_object(&self, key: &str) -> Option<Bytes> {
self.objects.read().unwrap().get(key).cloned()
}
fn put_object(&self, key: &str, data: Bytes) {
self.objects.write().unwrap().insert(key.to_string(), data);
}
fn delete_object(&self, key: &str) -> bool {
self.objects.write().unwrap().remove(key).is_some()
}
fn put_if_not_exists(&self, key: &str, data: Bytes) -> bool {
let mut objects = self.objects.write().unwrap();
if objects.contains_key(key) {
false
} else {
objects.insert(key.to_string(), data);
true
}
}
#[allow(dead_code)]
fn list_keys(&self) -> Vec<String> {
self.objects.read().unwrap().keys().cloned().collect()
}
}
/// Mock ObjectBlockOps implementation using in-memory storage.
#[allow(dead_code)]
struct MockObjectBlockClient {
storage: Arc<MockObjectStorage>,
}
#[allow(dead_code)]
impl MockObjectBlockClient {
fn new(storage: Arc<MockObjectStorage>) -> Self {
Self { storage }
}
}
impl ObjectBlockOps for MockObjectBlockClient {
fn has_blocks(
&self,
keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<(SequenceHash, Option<usize>)>> {
let storage = self.storage.clone();
Box::pin(async move {
keys.into_iter()
.map(|hash| {
let key = format!("{:?}", hash);
let size = storage.get_object(&key).map(|b| b.len());
(hash, size)
})
.collect()
})
}
fn put_blocks(
&self,
keys: Vec<SequenceHash>,
_layout: LogicalLayoutHandle,
_block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
let storage = self.storage.clone();
Box::pin(async move {
keys.into_iter()
.map(|hash| {
let key = format!("{:?}", hash);
storage.put_object(&key, Bytes::from("mock_block_data"));
Ok(hash)
})
.collect()
})
}
fn get_blocks(
&self,
keys: Vec<SequenceHash>,
_layout: LogicalLayoutHandle,
_block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
let storage = self.storage.clone();
Box::pin(async move {
keys.into_iter()
.map(|hash| {
let key = format!("{:?}", hash);
if storage.has_object(&key) {
Ok(hash)
} else {
Err(hash)
}
})
.collect()
})
}
}
/// Mock ObjectLockManager implementation using in-memory storage.
struct MockLockManager {
storage: Arc<MockObjectStorage>,
instance_id: String,
lock_timeout_secs: u64,
}
impl MockLockManager {
fn new(storage: Arc<MockObjectStorage>, instance_id: String) -> Self {
Self {
storage,
instance_id,
lock_timeout_secs: 300,
}
}
fn lock_key(hash: &SequenceHash) -> String {
format!("{:?}.lock", hash)
}
fn meta_key(hash: &SequenceHash) -> String {
format!("{:?}.meta", hash)
}
fn create_lock_content(&self) -> LockFileContent {
LockFileContent {
instance_id: self.instance_id.clone(),
acquired_at: timestamp_to_string(now_secs()),
deadline: timestamp_to_string(now_secs() + self.lock_timeout_secs),
}
}
}
impl ObjectLockManager for MockLockManager {
fn has_meta(&self, hash: SequenceHash) -> PolicyBoxFuture<'static, Result<bool>> {
let storage = self.storage.clone();
Box::pin(async move {
let meta_key = Self::meta_key(&hash);
Ok(storage.has_object(&meta_key))
})
}
fn try_acquire_lock(&self, hash: SequenceHash) -> PolicyBoxFuture<'static, Result<bool>> {
let storage = self.storage.clone();
let lock_content = self.create_lock_content();
let our_instance_id = self.instance_id.clone();
Box::pin(async move {
let lock_key = Self::lock_key(&hash);
let lock_data =
serde_json::to_vec(&lock_content).expect("Failed to serialize lock content");
// Try conditional put
if storage.put_if_not_exists(&lock_key, Bytes::from(lock_data.clone())) {
return Ok(true); // Acquired lock
}
// Lock exists, check if we own it or if it's expired
if let Some(existing_data) = storage.get_object(&lock_key)
&& let Ok(existing_lock) =
serde_json::from_slice::<LockFileContent>(&existing_data)
{
// Check if we own the lock
if existing_lock.instance_id == our_instance_id {
return Ok(true);
}
// Check if expired
if is_expired_timestamp(&existing_lock.deadline) {
// Expired, overwrite
storage.put_object(&lock_key, Bytes::from(lock_data));
return Ok(true);
}
}
Ok(false) // Lock held by another instance
})
}
fn create_meta(&self, hash: SequenceHash) -> PolicyBoxFuture<'static, Result<()>> {
let storage = self.storage.clone();
Box::pin(async move {
let meta_key = Self::meta_key(&hash);
storage.put_object(&meta_key, Bytes::new());
Ok(())
})
}
fn release_lock(&self, hash: SequenceHash) -> PolicyBoxFuture<'static, Result<()>> {
let storage = self.storage.clone();
Box::pin(async move {
let lock_key = Self::lock_key(&hash);
storage.delete_object(&lock_key);
Ok(())
})
}
}
// =========================================================================
// Lock Manager Tests
// =========================================================================
/// Test basic lock acquisition and release.
#[tokio::test]
async fn test_lock_manager_acquire_and_release() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
let lock_manager = MockLockManager::new(storage.clone(), "test-instance".to_string());
let hash = test_hash(12345);
// Initially no lock or meta
assert!(!storage.has_object(&MockLockManager::lock_key(&hash)));
assert!(!storage.has_object(&MockLockManager::meta_key(&hash)));
// Acquire lock
let acquired = lock_manager.try_acquire_lock(hash).await?;
assert!(acquired, "Should acquire lock");
assert!(storage.has_object(&MockLockManager::lock_key(&hash)));
// Verify lock content
let lock_data = storage
.get_object(&MockLockManager::lock_key(&hash))
.unwrap();
let lock_content: LockFileContent = serde_json::from_slice(&lock_data)?;
assert_eq!(lock_content.instance_id, "test-instance");
// Create meta
lock_manager.create_meta(hash).await?;
assert!(storage.has_object(&MockLockManager::meta_key(&hash)));
// Release lock
lock_manager.release_lock(hash).await?;
assert!(!storage.has_object(&MockLockManager::lock_key(&hash)));
// Meta should still exist
assert!(storage.has_object(&MockLockManager::meta_key(&hash)));
eprintln!("✓ Lock acquisition and release test passed");
Ok(())
}
/// Test that same instance can re-acquire its own lock.
#[tokio::test]
async fn test_lock_manager_reacquire_own_lock() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
let lock_manager = MockLockManager::new(storage.clone(), "test-instance".to_string());
let hash = test_hash(12345);
// Acquire lock
let acquired1 = lock_manager.try_acquire_lock(hash).await?;
assert!(acquired1, "Should acquire lock first time");
// Re-acquire same lock (same instance)
let acquired2 = lock_manager.try_acquire_lock(hash).await?;
assert!(
acquired2,
"Same instance should be able to re-acquire its lock"
);
eprintln!("✓ Lock re-acquisition test passed");
Ok(())
}
/// Test that different instance cannot acquire a valid lock.
#[tokio::test]
async fn test_lock_manager_contention() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
let lock_manager1 = MockLockManager::new(storage.clone(), "instance-1".to_string());
let lock_manager2 = MockLockManager::new(storage.clone(), "instance-2".to_string());
let hash = test_hash(12345);
// Instance 1 acquires lock
let acquired1 = lock_manager1.try_acquire_lock(hash).await?;
assert!(acquired1, "Instance 1 should acquire lock");
// Instance 2 tries to acquire same lock
let acquired2 = lock_manager2.try_acquire_lock(hash).await?;
assert!(
!acquired2,
"Instance 2 should NOT acquire lock held by instance 1"
);
// Instance 1 can still re-acquire its own lock
let acquired1_again = lock_manager1.try_acquire_lock(hash).await?;
assert!(acquired1_again, "Instance 1 should re-acquire its own lock");
eprintln!("✓ Lock contention test passed");
Ok(())
}
/// Test that expired locks can be overwritten.
#[tokio::test]
async fn test_lock_manager_expired_lock_overwrite() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
let lock_manager = MockLockManager::new(storage.clone(), "new-instance".to_string());
let hash = test_hash(12345);
// Pre-populate an expired lock from another instance
let expired_lock = LockFileContent {
instance_id: "old-instance".to_string(),
acquired_at: timestamp_to_string(0), // Ancient time
deadline: timestamp_to_string(1), // Long expired
};
let lock_key = MockLockManager::lock_key(&hash);
storage.put_object(&lock_key, Bytes::from(serde_json::to_vec(&expired_lock)?));
// New instance should be able to overwrite expired lock
let acquired = lock_manager.try_acquire_lock(hash).await?;
assert!(acquired, "Should acquire expired lock");
// Verify new instance owns the lock
let lock_data = storage.get_object(&lock_key).unwrap();
let lock_content: LockFileContent = serde_json::from_slice(&lock_data)?;
assert_eq!(lock_content.instance_id, "new-instance");
eprintln!("✓ Expired lock overwrite test passed");
Ok(())
}
/// Test has_meta checks correctly.
#[tokio::test]
async fn test_lock_manager_has_meta() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
let lock_manager = MockLockManager::new(storage.clone(), "test-instance".to_string());
let hash = test_hash(12345);
// Initially no meta
let has_meta_before = lock_manager.has_meta(hash).await?;
assert!(!has_meta_before, "Should not have meta initially");
// Create meta
lock_manager.create_meta(hash).await?;
// Now has meta
let has_meta_after = lock_manager.has_meta(hash).await?;
assert!(has_meta_after, "Should have meta after creation");
eprintln!("✓ Has meta test passed");
Ok(())
}
// =========================================================================
// Policy Tests
// =========================================================================
/// Test ObjectLockPresenceFilter passes blocks without meta/lock.
#[tokio::test]
async fn test_policy_passes_new_blocks() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
let lock_manager: Arc<dyn ObjectLockManager> = Arc::new(MockLockManager::new(
storage.clone(),
"test-instance".to_string(),
));
let filter = ObjectLockPresenceFilter::<G2>::new(lock_manager);
let hash = test_hash(12345);
let ctx = EvalContext::<G2>::from_weak(BlockId::default(), hash);
// Evaluate policy
let result = match filter.evaluate(&ctx) {
futures::future::Either::Left(ready) => ready.await,
futures::future::Either::Right(boxed) => boxed.await,
};
assert!(result?, "Policy should pass for new block");
// Lock should be acquired during evaluation
assert!(storage.has_object(&MockLockManager::lock_key(&hash)));
eprintln!("✓ Policy passes new blocks test passed");
Ok(())
}
/// Test ObjectLockPresenceFilter filters blocks with existing meta.
#[tokio::test]
async fn test_policy_filters_blocks_with_meta() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
// Pre-populate meta file
let hash = test_hash(12345);
let meta_key = MockLockManager::meta_key(&hash);
storage.put_object(&meta_key, Bytes::new());
let lock_manager: Arc<dyn ObjectLockManager> = Arc::new(MockLockManager::new(
storage.clone(),
"test-instance".to_string(),
));
let filter = ObjectLockPresenceFilter::<G2>::new(lock_manager);
let ctx = EvalContext::<G2>::from_weak(BlockId::default(), hash);
// Evaluate policy
let result = match filter.evaluate(&ctx) {
futures::future::Either::Left(ready) => ready.await,
futures::future::Either::Right(boxed) => boxed.await,
};
assert!(!result?, "Policy should filter block with existing meta");
eprintln!("✓ Policy filters blocks with meta test passed");
Ok(())
}
/// Test ObjectLockPresenceFilter filters blocks with valid lock from another instance.
#[tokio::test]
async fn test_policy_filters_locked_blocks() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
// Pre-populate lock from another instance
let hash = test_hash(12345);
let lock_content = LockFileContent {
instance_id: "other-instance".to_string(),
acquired_at: timestamp_to_string(now_secs()),
deadline: timestamp_to_string(now_secs() + 300), // 5 min in future
};
let lock_key = MockLockManager::lock_key(&hash);
storage.put_object(&lock_key, Bytes::from(serde_json::to_vec(&lock_content)?));
let lock_manager: Arc<dyn ObjectLockManager> = Arc::new(MockLockManager::new(
storage.clone(),
"test-instance".to_string(),
));
let filter = ObjectLockPresenceFilter::<G2>::new(lock_manager);
let ctx = EvalContext::<G2>::from_weak(BlockId::default(), hash);
// Evaluate policy
let result = match filter.evaluate(&ctx) {
futures::future::Either::Left(ready) => ready.await,
futures::future::Either::Right(boxed) => boxed.await,
};
assert!(
!result?,
"Policy should filter block locked by another instance"
);
eprintln!("✓ Policy filters locked blocks test passed");
Ok(())
}
/// Test ObjectLockPresenceFilter with pending tracker.
#[tokio::test]
async fn test_policy_filters_pending_blocks() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
let lock_manager: Arc<dyn ObjectLockManager> = Arc::new(MockLockManager::new(
storage.clone(),
"test-instance".to_string(),
));
let pending_tracker = Arc::new(PendingTracker::new());
let filter = ObjectLockPresenceFilter::<G2>::new(lock_manager)
.with_pending_tracker(pending_tracker.clone());
let hash = test_hash(12345);
// Mark as pending
let _guard = pending_tracker.guard(hash);
let ctx = EvalContext::<G2>::from_weak(BlockId::default(), hash);
// Evaluate policy
let result = match filter.evaluate(&ctx) {
futures::future::Either::Left(ready) => ready.await,
futures::future::Either::Right(boxed) => boxed.await,
};
assert!(!result?, "Policy should filter pending block");
eprintln!("✓ Policy filters pending blocks test passed");
Ok(())
}
/// Test batch evaluation filters correctly.
#[tokio::test]
async fn test_policy_batch_evaluation() -> Result<()> {
let storage = Arc::new(MockObjectStorage::new());
// Pre-populate meta for some blocks
let hash1 = test_hash(1);
let hash2 = test_hash(2);
let hash3 = test_hash(3);
storage.put_object(&MockLockManager::meta_key(&hash1), Bytes::new()); // Has meta
let lock_manager: Arc<dyn ObjectLockManager> = Arc::new(MockLockManager::new(
storage.clone(),
"test-instance".to_string(),
));
let filter = ObjectLockPresenceFilter::<G2>::new(lock_manager);
let contexts = vec![
EvalContext::<G2>::from_weak(0, hash1), // Has meta - should filter
EvalContext::<G2>::from_weak(1, hash2), // New - should pass
EvalContext::<G2>::from_weak(2, hash3), // New - should pass
];
// Evaluate batch
let result = match filter.evaluate_batch(&contexts) {
futures::future::Either::Left(ready) => ready.await,
futures::future::Either::Right(boxed) => boxed.await,
};
let results = result?;
assert_eq!(results.len(), 3);
assert!(!results[0], "Block with meta should be filtered");
assert!(results[1], "New block should pass");
assert!(results[2], "New block should pass");
eprintln!("✓ Policy batch evaluation test passed");
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#![doc = include_str!("../README.md")]
//! Physical layout and transfer testing utilities.
//!
//! Note: These are local implementations using workspace-local types.
//! When kvbm-physical moves to a workspace path dependency, these can
//! be replaced with re-exports from `kvbm_physical::testing`.
mod address;
mod identity;
mod transport;
// Re-export all public types
pub use address::{PeerInfo, WorkerAddress, WorkerAddressError};
pub use identity::{InstanceId, WorkerId};
pub use transport::TransportKey;
pub use kvbm_physical::testing::*;
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Token block creation utilities for testing.
//!
//! Note: These are local implementations using workspace-local `dynamo-tokens`
//! types. When kvbm-logical moves to a workspace path dependency, these can
//! be replaced with re-exports from `kvbm_logical::testing`.
use crate::SequenceHash;
use kvbm_common::tokens::{TokenBlock, TokenBlockSequence, compute_hash_v2};
use kvbm_logical::KvbmSequenceHashProvider;
/// Compute the default salt hash for requests with no salt and no lora.
pub fn default_request_salt_hash() -> u64 {
compute_hash_v2(b"{}", 0)
}
/// Create a token block from a slice of tokens.
pub fn create_token_block(tokens: &[u32]) -> TokenBlock {
let salt = default_request_salt_hash();
let token_sequence = TokenBlockSequence::from_slice(tokens, tokens.len() as u32, Some(salt));
if let Some(block) = token_sequence.blocks().first() {
block.clone()
} else {
let mut partial = token_sequence.into_parts().1;
partial.commit().expect("Should be able to commit")
}
}
/// Create a token block with sequential tokens starting from `start`.
pub fn create_sequential_block(start: u32, count: usize) -> TokenBlock {
let tokens: Vec<u32> = (start..start + count as u32).collect();
create_token_block(&tokens)
}
/// Create a token sequence with multiple blocks.
pub fn create_token_sequence(
num_blocks: usize,
block_size: usize,
start_token: u32,
) -> TokenBlockSequence {
let salt = default_request_salt_hash();
let total_tokens = num_blocks * block_size;
let tokens: Vec<u32> = (start_token..start_token + total_tokens as u32).collect();
TokenBlockSequence::from_slice(&tokens, block_size as u32, Some(salt))
}
/// Generate sequence hashes from a token sequence.
pub fn generate_sequence_hashes(token_sequence: &TokenBlockSequence) -> Vec<SequenceHash> {
token_sequence
.blocks()
.iter()
.map(|block| block.kvbm_sequence_hash())
.collect()
}
/// Create multiple disjoint token sequences with gaps between them.
pub fn create_disjoint_sequences(
segments: Vec<(usize, u32)>,
block_size: usize,
) -> (Vec<TokenBlock>, Vec<SequenceHash>) {
let mut all_blocks = Vec::new();
let mut all_hashes = Vec::new();
for (num_blocks, start_token) in segments {
let token_sequence = create_token_sequence(num_blocks, block_size, start_token);
let blocks = token_sequence.blocks().to_vec();
let hashes = generate_sequence_hashes(&token_sequence);
all_blocks.extend(blocks);
all_hashes.extend(hashes);
}
let mut combined: Vec<_> = all_blocks.into_iter().zip(all_hashes).collect();
combined.sort_by_key(|(_, hash)| hash.position());
combined.into_iter().unzip()
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CoordinatedWorker - Leader's view of a worker with coordination state.
//!
//! This module provides a wrapper around the Worker trait that adds coordination
//! state needed by the leader, including local layout handles and remote handle
//! mappings for cross-leader transfers.
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};
use anyhow::Result;
use futures::future::BoxFuture;
use crate::object::ObjectBlockOps;
use crate::{BlockId, InstanceId, SequenceHash};
use kvbm_physical::manager::LayoutHandle;
use kvbm_physical::transfer::TransferOptions;
use super::{
LogicalLayoutHandle, RemoteDescriptor, SerializedLayout, TransferCompleteNotification, Worker,
WorkerLayoutResponse,
};
/// Leader's view of a worker with coordination state.
///
/// # Coordination State vs Execution State
///
/// CoordinatedWorker maintains **coordination state** - the leader's view of what
/// handles a worker has and how to route transfers. This is distinct from
/// **execution state** which [`DirectWorker`] maintains for actual transfer execution.
///
/// | State Type | Owner | Purpose |
/// |------------|-------|---------|
/// | Execution | DirectWorker | Handles needed by TransferManager to execute |
/// | Coordination | CoordinatedWorker | Leader's tracking for routing decisions |
///
/// When the inner worker is a DirectWorker, handles exist in both places. This
/// duplication is intentional:
/// - DirectWorker needs handles to call TransferManager
/// - CoordinatedWorker provides uniform API for local AND remote workers
/// - VeloWorkerClient is stateless, so leader must track handles somewhere
///
/// # Usage
///
/// ```ignore
/// // Leader creates CoordinatedWorker wrapping actual worker
/// let worker = CoordinatedWorker::new(
/// Box::new(direct_worker),
/// rank,
/// host_instance,
/// );
///
/// // After configure_layouts RPC, populate coordination state
/// worker.apply_layout_response(&response)?;
///
/// // Leader can now query handles for routing
/// if let Some(g2) = worker.local_g2() {
/// // Route G2 transfers through this worker
/// }
/// ```
///
/// # Remote Handle Mappings
///
/// For cross-leader transfers (e.g., Prefill pulling from Decode), the leader
/// imports remote worker metadata and stores rank-aware mappings:
///
/// ```ignore
/// // Prefill leader imports Decode workers' metadata
/// worker.import_remote_metadata(decode_leader_id, decode_rank, metadata).await?;
///
/// // Later, execute transfer using stored mapping
/// worker.transfer_from_remote(
/// decode_leader_id,
/// decode_rank,
/// LogicalLayoutHandle::G2, // source
/// src_block_ids,
/// LogicalLayoutHandle::G2, // destination
/// dst_block_ids,
/// options,
/// )?;
/// ```
///
/// [`DirectWorker`]: super::DirectWorker
pub struct CoordinatedWorker {
/// The actual worker (local DirectWorker or remote VeloWorkerClient).
/// CoordinatedWorker delegates execution to this inner worker.
inner: Box<dyn Worker>,
/// This worker's rank under its leader (0-indexed).
/// Used for asymmetric TP routing between leaders with different worker counts.
rank: usize,
/// Instance ID of the process hosting this worker.
/// For DirectWorker: same as leader's instance.
/// For VeloWorkerClient: the remote worker's instance.
host_instance: InstanceId,
// =========================================================================
// Coordination State - leader's view of this worker's handles
// =========================================================================
/// G1 (GPU KV cache) layout handle.
/// Populated from WorkerLayoutResponse after configure_layouts RPC.
local_g1: OnceLock<LayoutHandle>,
/// G2 (Host/pinned cache) layout handle.
/// Populated from WorkerLayoutResponse after configure_layouts RPC.
local_g2: OnceLock<LayoutHandle>,
/// G3 (Disk cache) layout handle.
/// Populated from WorkerLayoutResponse after configure_layouts RPC.
local_g3: OnceLock<LayoutHandle>,
/// Remote handle mappings for cross-leader transfers.
/// Key: (remote_leader_id, remote_rank, logical_type) → physical_handle
///
/// Unlike DirectWorker's remote_handles (keyed by instance only), this
/// includes rank for asymmetric TP routing. When Prefill (TP=4) pulls from
/// Decode (TP=2), each Prefill worker needs to know which Decode worker(s)
/// to pull from.
remote_handles: RwLock<HashMap<(InstanceId, usize, LogicalLayoutHandle), LayoutHandle>>,
}
impl CoordinatedWorker {
/// Create a new CoordinatedWorker wrapping an existing Worker.
pub fn new(inner: Box<dyn Worker>, rank: usize, host_instance: InstanceId) -> Self {
Self {
inner,
rank,
host_instance,
local_g1: OnceLock::new(),
local_g2: OnceLock::new(),
local_g3: OnceLock::new(),
remote_handles: RwLock::new(HashMap::new()),
}
}
/// Get this worker's rank.
pub fn rank(&self) -> usize {
self.rank
}
/// Get the instance ID of the process hosting this worker.
pub fn host_instance(&self) -> InstanceId {
self.host_instance
}
/// Get a reference to the underlying Worker.
pub fn inner(&self) -> &dyn Worker {
&*self.inner
}
/// Set the local G1 (GPU KV) handle.
///
/// # Arguments
/// * `handle` - G1 layout handle
///
/// # Errors
/// Returns error if G1 handle was already set.
pub fn set_local_g1(&self, handle: LayoutHandle) -> Result<()> {
self.local_g1
.set(handle)
.map_err(|_| anyhow::anyhow!("G1 handle already set"))
}
/// Set the local G2 (Host) handle.
///
/// # Arguments
/// * `handle` - G2 layout handle
///
/// # Errors
/// Returns error if G2 handle was already set.
pub fn set_local_g2(&self, handle: LayoutHandle) -> Result<()> {
self.local_g2
.set(handle)
.map_err(|_| anyhow::anyhow!("G2 handle already set"))
}
/// Set the local G3 (Disk) handle.
///
/// # Arguments
/// * `handle` - G3 layout handle
///
/// # Errors
/// Returns error if G3 handle was already set.
pub fn set_local_g3(&self, handle: LayoutHandle) -> Result<()> {
self.local_g3
.set(handle)
.map_err(|_| anyhow::anyhow!("G3 handle already set"))
}
/// Apply layout response from configure_layouts RPC.
///
/// This is the primary way to populate coordination state. After the leader
/// sends a configure_layouts RPC to the worker, the response contains the
/// handles that were created. This method extracts those handles from the
/// serialized metadata.
///
/// # Arguments
/// * `response` - The WorkerLayoutResponse from configure_layouts RPC
///
/// # Example
/// ```ignore
/// // Leader calls configure_layouts on worker
/// let response = worker_client.configure_layouts(config).await?;
///
/// // Populate coordination state from response
/// coordinated_worker.apply_layout_response(&response)?;
/// ```
pub fn apply_layout_response(&self, response: &WorkerLayoutResponse) -> Result<()> {
// Extract handles from the metadata
let unpacked = response.metadata.unpack()?;
for descriptor in &unpacked.layouts {
match descriptor.logical_type {
LogicalLayoutHandle::G1 => {
let _ = self.local_g1.set(descriptor.handle);
}
LogicalLayoutHandle::G2 => {
let _ = self.local_g2.set(descriptor.handle);
}
LogicalLayoutHandle::G3 => {
let _ = self.local_g3.set(descriptor.handle);
}
LogicalLayoutHandle::G4 => {
// G4 (object store) not tracked locally
}
}
}
Ok(())
}
/// Get the local G1 handle if set.
pub fn local_g1(&self) -> Option<LayoutHandle> {
self.local_g1.get().copied()
}
/// Get the local G2 handle if set.
pub fn local_g2(&self) -> Option<LayoutHandle> {
self.local_g2.get().copied()
}
/// Get the local G3 handle if set.
pub fn local_g3(&self) -> Option<LayoutHandle> {
self.local_g3.get().copied()
}
/// Import metadata from a remote worker and store handle mappings.
///
/// This is called when the leader receives metadata from another leader's
/// workers during cross-leader coordination (e.g., prefill→decode).
///
/// # Arguments
/// * `remote_leader_id` - Instance ID of the remote leader
/// * `remote_rank` - Rank of the remote worker under its leader
/// * `metadata` - Serialized layout metadata from the remote worker
pub async fn import_remote_metadata(
&self,
remote_leader_id: InstanceId,
remote_rank: usize,
metadata: SerializedLayout,
) -> Result<()> {
// Unpack metadata to get logical type info
let unpacked = metadata.unpack()?;
// Import into the underlying worker so NIXL knows about the remote
let repacked = SerializedLayout::pack(
unpacked.worker_address.clone(),
unpacked.nixl_metadata.clone(),
unpacked.layouts.clone(),
)?;
let response = self.inner.import_metadata(repacked)?;
let _handles = response.await?;
// Store mappings for later lookups
let mut mapping = self.remote_handles.write().unwrap();
for descriptor in &unpacked.layouts {
mapping.insert(
(remote_leader_id, remote_rank, descriptor.logical_type),
descriptor.handle,
);
}
Ok(())
}
/// Look up physical handle for a remote transfer.
///
/// # Arguments
/// * `remote_leader_id` - Instance ID of the remote leader
/// * `remote_rank` - Rank of the remote worker
/// * `logical_type` - Logical layout type (G1/G2/G3)
pub fn resolve_remote_handle(
&self,
remote_leader_id: InstanceId,
remote_rank: usize,
logical_type: LogicalLayoutHandle,
) -> Option<LayoutHandle> {
self.remote_handles
.read()
.unwrap()
.get(&(remote_leader_id, remote_rank, logical_type))
.copied()
}
/// Check if remote metadata has been imported for a specific remote worker.
pub fn has_remote_metadata(&self, remote_leader_id: InstanceId, remote_rank: usize) -> bool {
let handles = self.remote_handles.read().unwrap();
handles
.keys()
.any(|(leader, rank, _)| *leader == remote_leader_id && *rank == remote_rank)
}
/// Execute transfer from a remote worker.
///
/// This method looks up the remote handle from stored mappings and
/// executes an RDMA transfer to pull data from the remote worker.
///
/// # Arguments
/// * `remote_leader_id` - Instance ID of the remote leader
/// * `remote_rank` - Rank of the source worker under its leader
/// * `src_logical` - Source logical layout type (e.g., G2)
/// * `src_block_ids` - Block IDs on the remote to pull
/// * `dst_logical` - Destination logical layout type on this worker
/// * `dst_block_ids` - Destination block IDs
/// * `options` - Transfer options
#[allow(clippy::too_many_arguments)]
pub fn transfer_from_remote(
&self,
remote_leader_id: InstanceId,
remote_rank: usize,
src_logical: LogicalLayoutHandle,
src_block_ids: Vec<BlockId>,
dst_logical: LogicalLayoutHandle,
dst_block_ids: std::sync::Arc<[BlockId]>,
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
let src_handle = self
.resolve_remote_handle(remote_leader_id, remote_rank, src_logical)
.ok_or_else(|| {
anyhow::anyhow!(
"No mapping for remote ({}, rank {}, {:?})",
remote_leader_id,
remote_rank,
src_logical
)
})?;
let src = RemoteDescriptor::Layout {
handle: src_handle,
block_ids: src_block_ids,
};
self.inner
.execute_remote_onboard(src, dst_logical, dst_block_ids, options)
}
}
impl ObjectBlockOps for CoordinatedWorker {
fn has_blocks(
&self,
keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<(SequenceHash, Option<usize>)>> {
// Delegate to inner worker - Worker trait now extends ObjectBlockOps
self.inner.has_blocks(keys)
}
fn put_blocks(
&self,
keys: Vec<SequenceHash>,
src_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// Delegate to inner worker - inner worker resolves logical handle
self.inner.put_blocks(keys, src_layout, block_ids)
}
fn get_blocks(
&self,
keys: Vec<SequenceHash>,
dst_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// Delegate to inner worker - inner worker resolves logical handle
self.inner.get_blocks(keys, dst_layout, block_ids)
}
}
#[cfg(test)]
mod tests {
// TODO: Add tests with mock Worker implementation
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Distributed Workers
//!
//! This module provides the interface for how the leader will drive multiple workers.
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod spmd;
use std::sync::Arc;
use super::{
ImportMetadataResponse, SerializedLayout, SerializedLayoutResponse, Worker, WorkerTransfers, *,
};
use crate::object::ObjectBlockOps;
use anyhow::Result;
pub use spmd::SpmdParallelWorkers;
/// A cohort of parallel workers.
///
/// This trait is used to drive one or more parallel workers.
pub trait ParallelWorkers: WorkerTransfers + ObjectBlockOps + Send + Sync {
/// Export the local metadata for a set of workers.
///
/// Layouts will be returned in rank order.
///
/// # Returns
/// A [`kvbm_physical::manager::SerializedLayout`] containing the local metadata
fn export_metadata(&self) -> Result<Vec<SerializedLayoutResponse>>;
/// Import the remote metadata for this worker.
///
/// Handles will be returned in rank order.
///
/// # Arguments
/// * `metadata` - A [`kvbm_physical::manager::SerializedLayout`] containing the remote metadata
///
/// # Returns
/// A vector of [`kvbm_physical::manager::LayoutHandle`] for the imported remote layouts
fn import_metadata(
&self,
metadata: Vec<SerializedLayout>,
) -> Result<Vec<ImportMetadataResponse>>;
/// Get the number of workers.
fn worker_count(&self) -> usize;
/// Get access to the underlying workers for metadata/handle queries.
///
/// This is useful for operations that need to query individual workers
/// (e.g., collecting layout handles) without executing transfers.
fn workers(&self) -> &[Arc<dyn Worker>];
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::object::ObjectBlockOps;
use anyhow::Result;
// velo event types used via fully-qualified paths (::velo::Event, ::velo::EventManager)
use futures::future::BoxFuture;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// SPMD (Single Program, Multiple Data) parallel worker group.
///
/// Wraps a set of rank-indexed [`Worker`]s and executes every operation on
/// all of them in parallel. Each worker has its own rank, physical layout
/// handles, and `TransferManager`, but they all receive the same logical
/// commands (transfer, connect, import/export metadata).
///
/// Transfer completion notifications from individual workers are aggregated
/// into a single notification via the event system, so callers see one
/// completion event per logical operation regardless of worker count.
///
/// Remote handle mappings are stored per `(InstanceId, worker_idx,
/// LogicalLayoutHandle)` so that each rank resolves to its own peer handle
/// during RDMA transfers.
pub struct SpmdParallelWorkers {
workers: Vec<Arc<dyn Worker>>,
events: Arc<::velo::EventManager>,
runtime: tokio::runtime::Handle,
/// Remote handle mappings: (InstanceId, worker_idx, LogicalLayoutHandle) -> remote LayoutHandle.
/// Populated by `connect_remote` for later use by `execute_remote_onboard_for_instance`.
remote_handles: RwLock<HashMap<(InstanceId, usize, LogicalLayoutHandle), LayoutHandle>>,
}
impl SpmdParallelWorkers {
/// Create a new SpmdParallelWorkers.
///
/// # Arguments
/// * `workers` - The underlying workers (one per rank)
/// * `events` - The event system for aggregating completion notifications
/// * `runtime` - The tokio runtime handle for spawning aggregation tasks
pub fn new(
workers: Vec<Arc<dyn Worker>>,
events: Arc<::velo::EventManager>,
runtime: tokio::runtime::Handle,
) -> Self {
Self {
workers,
events,
runtime,
remote_handles: RwLock::new(HashMap::new()),
}
}
/// Get the number of workers.
pub fn worker_count(&self) -> usize {
self.workers.len()
}
}
impl WorkerTransfers for SpmdParallelWorkers {
fn execute_local_transfer(
&self,
src: LogicalLayoutHandle,
dst: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification> {
let notifications = self
.workers
.iter()
.map(|worker| {
worker.execute_local_transfer(
src,
dst,
src_block_ids.clone(),
dst_block_ids.clone(),
options.clone(),
)
})
.collect::<Result<Vec<_>>>()?;
TransferCompleteNotification::aggregate(notifications, &self.events, &self.runtime)
}
fn execute_remote_onboard(
&self,
src: RemoteDescriptor,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification> {
let notifications = self
.workers
.iter()
.map(|worker| {
worker.execute_remote_onboard(
src.clone(),
dst,
dst_block_ids.clone(),
options.clone(),
)
})
.collect::<Result<Vec<_>>>()?;
TransferCompleteNotification::aggregate(notifications, &self.events, &self.runtime)
}
fn execute_remote_offload(
&self,
src: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst: RemoteDescriptor,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification> {
let notifications = self
.workers
.iter()
.map(|worker| {
worker.execute_remote_offload(
src,
src_block_ids.clone(),
dst.clone(),
options.clone(),
)
})
.collect::<Result<Vec<_>>>()?;
TransferCompleteNotification::aggregate(notifications, &self.events, &self.runtime)
}
fn connect_remote(
&self,
instance_id: InstanceId,
metadata: Vec<SerializedLayout>,
) -> Result<ConnectRemoteResponse> {
// Validate metadata count matches worker count
if metadata.len() != self.workers.len() {
anyhow::bail!(
"Metadata count ({}) doesn't match worker count ({})",
metadata.len(),
self.workers.len()
);
}
// Collect handles to store and responses to await
let mut new_handles = Vec::new();
let mut import_responses = Vec::new();
for (worker_idx, (worker, meta)) in
self.workers.iter().zip(metadata.into_iter()).enumerate()
{
// Unpack to extract logical type info
let unpacked = meta.unpack()?;
// Collect handle mappings
for descriptor in &unpacked.layouts {
new_handles.push((
(instance_id, worker_idx, descriptor.logical_type),
descriptor.handle,
));
}
// Repack for the underlying worker's import_metadata
let repacked = SerializedLayout::pack(
unpacked.worker_address,
unpacked.nixl_metadata,
unpacked.layouts,
)?;
// Call underlying worker's import_metadata
import_responses.push(worker.import_metadata(repacked)?);
}
// Store all handle mappings
{
let mut handles = self.remote_handles.write().unwrap();
for (key, value) in new_handles {
handles.insert(key, value);
}
}
// If all responses are ready (synchronous), return immediately
if import_responses.iter().all(|r| !r.could_yield()) {
return Ok(ConnectRemoteResponse::ready());
}
// Create an event to aggregate all import completions
let event = self.events.new_event()?;
let awaiter = self.events.awaiter(event.handle())?;
// Spawn task to await all import responses and signal completion
self.runtime
.spawn(await_import_responses(import_responses, event));
Ok(ConnectRemoteResponse::from_awaiter(awaiter))
}
fn has_remote_metadata(&self, instance_id: InstanceId) -> bool {
let handles = self.remote_handles.read().unwrap();
handles.keys().any(|(id, _, _)| *id == instance_id)
}
fn execute_remote_onboard_for_instance(
&self,
instance_id: InstanceId,
remote_logical_type: LogicalLayoutHandle,
src_block_ids: Vec<BlockId>,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification> {
let handles = self.remote_handles.read().unwrap();
let mut notifications = Vec::with_capacity(self.workers.len());
// SPMD: Execute SAME transfer on EVERY worker, each with its own remote handle
for (worker_idx, worker) in self.workers.iter().enumerate() {
let remote_handle = handles
.get(&(instance_id, worker_idx, remote_logical_type))
.ok_or_else(|| {
anyhow::anyhow!(
"No remote {:?} handle for instance {} worker {}",
remote_logical_type,
instance_id,
worker_idx
)
})?;
let descriptor = RemoteDescriptor::Layout {
handle: *remote_handle,
block_ids: src_block_ids.clone(),
};
notifications.push(worker.execute_remote_onboard(
descriptor,
dst,
dst_block_ids.clone(),
options.clone(),
)?);
}
TransferCompleteNotification::aggregate(notifications, &self.events, &self.runtime)
}
}
/// Helper to await all import metadata responses and signal completion via an event.
/// Helper to await all import metadata responses and signal completion via an event.
async fn await_import_responses(responses: Vec<ImportMetadataResponse>, event: ::velo::Event) {
let results: Vec<Result<Vec<LayoutHandle>>> =
futures::future::join_all(responses.into_iter().map(|r| r.into_future())).await;
// Check for any failures
let errors: Vec<_> = results.into_iter().filter_map(|r| r.err()).collect();
if errors.is_empty() {
let _ = event.trigger();
} else {
let error_msg = errors
.iter()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join("; ");
let _ = event.poison(error_msg);
}
}
impl ParallelWorkers for SpmdParallelWorkers {
fn export_metadata(&self) -> Result<Vec<SerializedLayoutResponse>> {
let metadata = self
.workers
.iter()
.map(|worker| worker.export_metadata())
.collect::<Result<Vec<_>>>()?;
Ok(metadata)
}
fn import_metadata(
&self,
metadata: Vec<SerializedLayout>,
) -> Result<Vec<ImportMetadataResponse>> {
// validate the size of the metadata is the same as the number of workers
if metadata.len() != self.workers.len() {
return Err(anyhow::anyhow!(
"Metadata size does not match number of workers"
));
}
let results = self
.workers
.iter()
.zip(metadata.iter())
.map(|(worker, metadata)| worker.import_metadata(metadata.clone()))
.collect::<Result<Vec<_>>>()?;
Ok(results)
}
fn worker_count(&self) -> usize {
self.workers.len()
}
fn workers(&self) -> &[Arc<dyn Worker>] {
&self.workers
}
}
impl ObjectBlockOps for SpmdParallelWorkers {
fn has_blocks(
&self,
keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<(SequenceHash, Option<usize>)>> {
// For has_blocks, we query all workers and verify consistency.
// All workers should agree on block presence for SPMD semantics.
// We return the results from worker 0 but verify all workers agree.
let workers = self.workers.clone();
let _runtime = self.runtime.clone();
Box::pin(async move {
if workers.is_empty() {
return keys.into_iter().map(|k| (k, None)).collect();
}
// Query all workers in parallel
let futures: Vec<_> = workers
.iter()
.map(|worker| worker.has_blocks(keys.clone()))
.collect();
let results: Vec<Vec<(SequenceHash, Option<usize>)>> =
futures::future::join_all(futures).await;
// Return results from first worker (all should agree in SPMD)
// In debug mode, we could verify consistency across workers
results.into_iter().next().unwrap_or_default()
})
}
fn put_blocks(
&self,
keys: Vec<SequenceHash>,
src_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// For put_blocks, each worker writes with its own rank-prefixed key.
// Each worker resolves the logical handle to its own physical layout.
// All workers must succeed for the operation to be considered successful.
let workers = self.workers.clone();
Box::pin(async move {
if workers.is_empty() {
return keys.into_iter().map(Err).collect();
}
// Execute put on all workers in parallel
// Each worker resolves src_layout to its own physical layout
let futures: Vec<_> = workers
.iter()
.map(|worker| worker.put_blocks(keys.clone(), src_layout, block_ids.clone()))
.collect();
let results: Vec<Vec<Result<SequenceHash, SequenceHash>>> =
futures::future::join_all(futures).await;
// Aggregate: a key succeeded only if ALL workers succeeded
let num_keys = keys.len();
let mut aggregated = Vec::with_capacity(num_keys);
for (key_idx, key) in keys.iter().enumerate() {
let all_succeeded = results.iter().all(|worker_results| {
worker_results
.get(key_idx)
.map(|r| r.is_ok())
.unwrap_or(false)
});
if all_succeeded {
aggregated.push(Ok(*key));
} else {
aggregated.push(Err(*key));
}
}
aggregated
})
}
fn get_blocks(
&self,
keys: Vec<SequenceHash>,
dst_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// For get_blocks, each worker reads from its own rank-prefixed key.
// Each worker resolves the logical handle to its own physical layout.
// All workers must succeed for the operation to be considered successful.
let workers = self.workers.clone();
Box::pin(async move {
if workers.is_empty() {
return keys.into_iter().map(Err).collect();
}
// Execute get on all workers in parallel
// Each worker resolves dst_layout to its own physical layout
let futures: Vec<_> = workers
.iter()
.map(|worker| worker.get_blocks(keys.clone(), dst_layout, block_ids.clone()))
.collect();
let results: Vec<Vec<Result<SequenceHash, SequenceHash>>> =
futures::future::join_all(futures).await;
// Aggregate: a key succeeded only if ALL workers succeeded
let num_keys = keys.len();
let mut aggregated = Vec::with_capacity(num_keys);
for (key_idx, key) in keys.iter().enumerate() {
let all_succeeded = results.iter().all(|worker_results| {
worker_results
.get(key_idx)
.map(|r| r.is_ok())
.unwrap_or(false)
});
if all_succeeded {
aggregated.push(Ok(*key));
} else {
aggregated.push(Err(*key));
}
}
aggregated
})
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod coordinated;
#[doc = include_str!("../../docs/worker-group.md")]
pub mod group;
mod physical;
mod protocol;
pub mod velo;
pub use coordinated::CoordinatedWorker;
pub use physical::{PhysicalWorker, PhysicalWorkerBuilder};
/// Compatibility alias for [`PhysicalWorker`].
pub use physical::PhysicalWorker as DirectWorker;
use anyhow::Result;
use std::{pin::Pin, sync::Arc};
use crate::object::ObjectBlockOps;
pub use crate::{BlockId, InstanceId, SequenceHash};
pub use kvbm_common::LogicalLayoutHandle;
pub use kvbm_physical::{
manager::{LayoutHandle, SerializedLayout},
transfer::TransferCompleteNotification,
};
pub use velo::{VeloWorkerClient, VeloWorkerService, VeloWorkerServiceBuilder};
/// Boxed future for serialized layout responses - allows both typed_unary and raw unary results
pub type SerializedResponseAwaiter = Pin<Box<dyn Future<Output = Result<SerializedLayout>> + Send>>;
/// Boxed future for import metadata responses
pub type ImportMetadataResponseAwaiter =
Pin<Box<dyn Future<Output = Result<Vec<LayoutHandle>>> + Send>>;
pub use protocol::*;
pub trait WorkerTransfers: Send + Sync {
/// Execute a local transfer between two logical layouts.
///
/// # Arguments
/// * `src` - The source layout handle
/// * `dst` - The destination layout handle
/// * `src_block_ids` - The source block IDs
/// * `dst_block_ids` - The destination block IDs
/// * `options` - Transfer options (layer range, bounce buffers, etc.)
///
/// # Returns
/// A future that completes when the transfer is complete
fn execute_local_transfer(
&self,
src: LogicalLayoutHandle,
dst: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification>;
/// Execute a remote transfer from a remote layout to a local logical layout.
///
/// This represents a NIXL transfer.
///
/// # Arguments
/// * `src` - Remote sources can take several forms, see [`RemoteDescriptor`]
/// * `dst` - The destination layout handle
/// * `dst_block_ids` - The destination block IDs
/// * `options` - Transfer options (layer range, bounce buffers, etc.)
///
/// # Returns
/// A future that completes when the transfer is complete
fn execute_remote_onboard(
&self,
src: RemoteDescriptor,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification>;
/// Execute a remote offload from a local logical layout to a remote descriptor.
///
/// This represents a NIXL offload.
///
/// # Arguments
/// * `src` - The source layout handle
/// * `dst` - The destination remote descriptor
/// * `src_block_ids` - The source block IDs
/// * `options` - Transfer options (layer range, bounce buffers, etc.)
///
/// # Returns
/// A future that completes when the offload is complete
fn execute_remote_offload(
&self,
src: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst: RemoteDescriptor,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification>;
/// Connect to a remote instance by importing its metadata and storing handle mappings.
///
/// This method stores the handle mappings internally for later use by
/// `execute_remote_onboard_for_instance`. The metadata is also imported into
/// the underlying transfer manager so NIXL knows about the remote.
///
/// # Arguments
/// * `instance_id` - The unique identifier of the remote instance
/// * `metadata` - Serialized layout metadata from the remote instance.
/// For DirectWorker, expects exactly 1 element.
/// For ReplicatedWorker, expects one element per worker (in rank order).
///
/// # Returns
/// A response that completes when the metadata has been imported and mappings stored.
fn connect_remote(
&self,
instance_id: InstanceId,
metadata: Vec<SerializedLayout>,
) -> Result<ConnectRemoteResponse>;
/// Check if remote metadata has been imported for an instance.
///
/// Returns true if `connect_remote` has been successfully called for this instance.
fn has_remote_metadata(&self, instance_id: InstanceId) -> bool;
/// Execute a remote onboard transfer using stored handle mapping.
///
/// This method looks up the remote handle from the stored mapping
/// (established via `connect_remote`) and executes the transfer.
///
/// # Arguments
/// * `instance_id` - The remote instance to pull from
/// * `remote_logical_type` - The logical layout type on the remote (e.g., G2)
/// * `src_block_ids` - Block IDs on the remote to pull
/// * `dst` - Local destination logical layout
/// * `dst_block_ids` - Local destination block IDs
/// * `options` - Transfer options
///
/// # Errors
/// Returns error if remote metadata hasn't been imported for this instance.
fn execute_remote_onboard_for_instance(
&self,
instance_id: InstanceId,
remote_logical_type: LogicalLayoutHandle,
src_block_ids: Vec<BlockId>,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification>;
}
pub trait Worker: WorkerTransfers + ObjectBlockOps + Send + Sync {
/// Get the G1 layout handle for this worker (if configured).
///
/// Returns None if no G1 layout has been registered with this worker.
fn g1_handle(&self) -> Option<LayoutHandle>;
/// Get the G2 layout handle for this worker (if configured).
///
/// Returns None if no G2 layout has been registered with this worker.
fn g2_handle(&self) -> Option<LayoutHandle>;
/// Get the G3 layout handle for this worker (if configured).
///
/// Returns None if no G3 layout has been registered with this worker.
fn g3_handle(&self) -> Option<LayoutHandle>;
/// Export the local metadata for this worker.
///
/// # Returns
/// A [`kvbm_physical::manager::SerializedLayout`] containing the local metadata
fn export_metadata(&self) -> Result<SerializedLayoutResponse>;
/// Import the remote metadata for this worker.
///
/// # Arguments
/// * `metadata` - A [`kvbm_physical::manager::SerializedLayout`] containing the remote metadata
///
/// # Returns
/// A vector of [`kvbm_physical::manager::LayoutHandle`] for the imported remote layouts
fn import_metadata(&self, metadata: SerializedLayout) -> Result<ImportMetadataResponse>;
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Base worker implementation for single-worker transfer execution.
//!
//! This module provides the [`PhysicalWorker`] type which executes transfer operations
//! using a local [`TransferManager`]. It serves as the foundation for both standalone
//! worker scenarios and as a building block for parallel worker implementations.
#[cfg(feature = "collectives")]
mod replicated;
#[cfg(feature = "collectives")]
#[allow(unused_imports)]
pub use replicated::ReplicatedDataWorker;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
#[cfg(feature = "nccl")]
use cudarc::driver::CudaEvent;
use derive_builder::Builder;
use futures::future::BoxFuture;
use crate::object::ObjectBlockOps;
use kvbm_physical::layout::PhysicalLayout;
use kvbm_physical::{
manager::{SerializedLayout, TransferManager},
transfer::{BounceBuffer, TransferOptions, context::TransferCompleteNotification},
};
use super::*;
/// PhysicalWorker executes transfer operations using a local TransferManager.
///
/// This is the fundamental worker type that directly owns a `TransferManager` and
/// layout handles for executing data transfers. It implements the [`Worker`] and
/// [`WorkerTransfers`] traits for single-worker scenarios.
///
/// # Builder fields
///
/// | Field | Required | Description |
/// |-------|----------|-------------|
/// | `manager` | **yes** | `TransferManager` that executes actual data movement |
/// | `g1_handle` | no | GPU KV cache layout handle (for GPU transfers) |
/// | `g2_handle` | no | Host/pinned cache layout handle (for host transfers) |
/// | `g3_handle` | no | Disk cache layout handle (for disk-tier transfers) |
/// | `rank` | no | Worker rank for object-key prefixing in SPMD setups |
/// | `object_client` | no | Object storage client for G4 tier (S3, etc.) |
///
/// # Execution State vs Coordination State
///
/// PhysicalWorker maintains **execution state** -- the handles and manager needed
/// to actually perform RDMA/local transfers. This is distinct from
/// **coordination state** which the leader tracks in [`CoordinatedWorker`].
///
/// When a leader wraps a PhysicalWorker in a CoordinatedWorker:
/// - PhysicalWorker: owns handles for TransferManager execution
/// - CoordinatedWorker: tracks the same handles for leader coordination
///
/// This duplication is intentional -- PhysicalWorker needs handles to execute,
/// and CoordinatedWorker provides a uniform API regardless of whether the
/// inner worker is local (PhysicalWorker) or remote (VeloWorkerClient).
///
/// # Typical lifecycle
///
/// 1. Created via `PhysicalWorker::builder()` during deferred initialization
/// 2. Wrapped by [`VeloWorkerService`] to expose RPC handlers
/// 3. Wrapped by [`CoordinatedWorker`] for leader coordination
/// 4. Used as a building block in parallel workers (e.g., `SpmdParallelWorkers`)
///
/// [`CoordinatedWorker`]: super::CoordinatedWorker
/// [`VeloWorkerService`]: super::VeloWorkerService
#[derive(Builder)]
#[builder(pattern = "owned")]
pub struct PhysicalWorker {
// =========================================================================
// Execution State - needed by TransferManager to perform operations
// =========================================================================
/// The transfer manager that executes actual data movement.
manager: TransferManager,
/// G1 (GPU KV cache) layout handle - set during initialization.
/// Required for GPU-to-GPU and GPU-to-Host transfers.
#[builder(default, setter(strip_option))]
g1_handle: Option<LayoutHandle>,
/// G2 (Host/pinned cache) layout handle - set during initialization.
/// Required for Host-to-GPU and Host-to-Disk transfers.
#[builder(default, setter(strip_option))]
g2_handle: Option<LayoutHandle>,
/// G3 (Disk cache) layout handle - set during initialization if disk tier enabled.
/// Required for Disk-to-Host transfers.
#[builder(default, setter(strip_option))]
g3_handle: Option<LayoutHandle>,
/// Remote handle mappings for peer-to-peer transfers.
/// Key: (InstanceId, LogicalLayoutHandle) → remote LayoutHandle
///
/// Populated by `connect_remote` when this worker imports metadata from
/// a peer instance. Used by `execute_remote_onboard_for_instance` to
/// resolve logical handles to physical handles for RDMA transfers.
///
/// Note: This is per-instance mapping (no rank), suitable for single-worker
/// scenarios. For multi-worker asymmetric TP, use CoordinatedWorker's
/// rank-aware remote_handles instead.
#[builder(default = "RwLock::new(HashMap::new())")]
remote_handles: RwLock<HashMap<(InstanceId, LogicalLayoutHandle), LayoutHandle>>,
// =========================================================================
// Object Storage State
// =========================================================================
/// Worker rank (set during initialization from LeaderLayoutConfig).
/// Used to augment object keys for unique storage across SPMD workers.
#[builder(default, setter(strip_option))]
rank: Option<usize>,
/// Optional object storage client for G4 tier operations.
/// Set during initialization if object storage is enabled.
#[builder(default, setter(strip_option))]
object_client: Option<Arc<dyn ObjectBlockOps>>,
}
impl PhysicalWorker {
/// Create a new builder for PhysicalWorker.
///
/// # Example
/// ```rust,ignore
/// let worker = PhysicalWorker::builder()
/// .manager(manager)
/// .g1_handle(g1_handle)
/// .g2_handle(g2_handle)
/// .g3_handle(g3_handle)
/// .build();
/// ```
pub fn builder() -> PhysicalWorkerBuilder {
PhysicalWorkerBuilder::default()
}
/// Get the worker rank (if set).
pub fn rank(&self) -> Option<usize> {
self.rank
}
/// Get the object storage client (if set).
pub fn object_client(&self) -> Option<&Arc<dyn ObjectBlockOps>> {
self.object_client.as_ref()
}
/// Get the G1 layout handle (if set).
pub fn g1_handle(&self) -> Option<LayoutHandle> {
self.g1_handle
}
/// Get the G2 layout handle (if set).
pub fn g2_handle(&self) -> Option<LayoutHandle> {
self.g2_handle
}
/// Get the G3 layout handle (if set).
pub fn g3_handle(&self) -> Option<LayoutHandle> {
self.g3_handle
}
/// Get a reference to the TransferManager.
pub fn transfer_manager(&self) -> &TransferManager {
&self.manager
}
/// Resolve a logical layout handle to a physical layout.
///
/// # Arguments
/// * `logical` - The logical layout handle (G1, G2, G3)
///
/// # Returns
/// The physical layout for the given logical handle, or an error if not found.
pub fn resolve_layout(&self, logical: LogicalLayoutHandle) -> Result<PhysicalLayout> {
use LogicalLayoutHandle::*;
let physical_handle = match logical {
G1 => self.g1_handle(),
G2 => self.g2_handle(),
G3 => self.g3_handle(),
_ => None,
}
.ok_or_else(|| anyhow::anyhow!("No layout registered for {:?}", logical))?;
self.manager
.get_physical_layout(physical_handle)
.ok_or_else(|| {
anyhow::anyhow!(
"Layout handle {:?} not found in TransferManager",
physical_handle
)
})
}
/// Create a bounce buffer specification from a layout handle and block IDs.
pub fn create_bounce_buffer(
&self,
handle: LayoutHandle,
block_ids: Vec<BlockId>,
) -> Result<BounceBuffer> {
Ok(BounceBuffer::from_handle(handle, block_ids))
}
/// Export serialized layout metadata with proper logical type mappings.
///
/// This exports layouts with their logical types (G1, G2, G3) so that
/// remote instances can correctly identify which handle corresponds to
/// which tier during RDMA transfers.
pub fn export_metadata(&self) -> Result<SerializedLayout> {
self.export_metadata_with_logical_types()
}
/// Export metadata with logical type annotations for each registered handle.
fn export_metadata_with_logical_types(&self) -> Result<SerializedLayout> {
let mut descriptors = Vec::new();
// Build descriptors for each registered logical handle
if let Some(handle) = self.g1_handle() {
descriptors.push(
self.manager
.build_logical_descriptor(handle, LogicalLayoutHandle::G1)?,
);
}
if let Some(handle) = self.g2_handle() {
descriptors.push(
self.manager
.build_logical_descriptor(handle, LogicalLayoutHandle::G2)?,
);
}
if let Some(handle) = self.g3_handle() {
descriptors.push(
self.manager
.build_logical_descriptor(handle, LogicalLayoutHandle::G3)?,
);
}
// Pack with worker address and NIXL metadata
let worker_address = self.manager.worker_address();
let nixl_metadata = self.manager.get_nixl_metadata()?;
SerializedLayout::pack(worker_address, nixl_metadata, descriptors)
}
/// Import serialized layout metadata into the transfer manager.
pub fn import_metadata(&self, metadata: SerializedLayout) -> Result<Vec<LayoutHandle>> {
self.manager.import_metadata(metadata)
}
/// Execute layer-wise local transfer from G2 to G1.
///
/// This method transfers blocks from the host cache (G2) to the GPU cache (G1)
/// one layer at a time, recording an event after each layer's transfer completes.
/// All transfers execute on the same CUDA stream to ensure proper ordering.
///
/// The caller provides pre-allocated events that are reused across iterations.
/// After calling this method, the caller can use `cudaStreamWaitEvent` on the
/// torch stream to synchronize each layer's load before attention computation.
///
/// # Arguments
/// * `src_block_ids` - Source block IDs in G2 (host cache)
/// * `dst_block_ids` - Destination block IDs in G1 (GPU cache)
/// * `layer_events` - Pre-allocated CUDA events, one per layer. Must have length == num_layers.
///
/// # Returns
/// `Ok(())` on success. The caller owns synchronization via the recorded events.
///
/// # Errors
/// Returns an error if:
/// - src_block_ids and dst_block_ids have different lengths
/// - layer_events length doesn't match num_layers
/// - G1 or G2 handles are not registered
/// - Any layer transfer fails
#[cfg(feature = "nccl")]
pub fn execute_local_layerwise_onboard(
&self,
src_block_ids: &[BlockId],
dst_block_ids: &[BlockId],
layer_events: &[Arc<CudaEvent>],
) -> Result<()> {
// Validate block ID lengths match
if src_block_ids.len() != dst_block_ids.len() {
return Err(anyhow::anyhow!(
"Block ID length mismatch: src={}, dst={}",
src_block_ids.len(),
dst_block_ids.len()
));
}
// Get layout handles
let g2_handle = self
.g2_handle()
.ok_or_else(|| anyhow::anyhow!("G2 layout not registered"))?;
let g1_handle = self
.g1_handle()
.ok_or_else(|| anyhow::anyhow!("G1 layout not registered"))?;
// Get num_layers from layout config
let g2_config = self.manager.get_layout_config(g2_handle)?;
let num_layers = g2_config.num_layers;
// Validate layer_events length
if layer_events.len() != num_layers {
return Err(anyhow::anyhow!(
"layer_events length ({}) doesn't match num_layers ({})",
layer_events.len(),
num_layers
));
}
// Acquire a dedicated stream for all layer transfers
let stream = self.manager.context().acquire_h2d_stream();
tracing::debug!(
num_layers,
num_blocks = src_block_ids.len(),
"Starting layer-wise onboard from G2 to G1"
);
// Execute transfer for each layer and record event
for layer in 0..num_layers {
// Execute single-layer transfer on our dedicated stream
let options = TransferOptions::builder()
.layer_range(layer..layer + 1)
.cuda_stream(stream.clone())
.build()?;
self.manager.execute_transfer(
g2_handle,
src_block_ids,
g1_handle,
dst_block_ids,
options,
)?;
// Record event on the stream for this layer
layer_events[layer].record(stream.as_ref())?;
}
tracing::debug!(num_layers, "Layer-wise onboard complete - events recorded");
Ok(())
}
}
impl WorkerTransfers for PhysicalWorker {
fn execute_local_transfer(
&self,
src: LogicalLayoutHandle,
dst: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst_block_ids: Arc<[BlockId]>,
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
use LogicalLayoutHandle::*;
let src_layout = match &src {
G1 => self.g1_handle(),
G2 => self.g2_handle(),
G3 => self.g3_handle(),
G4 => return Err(anyhow::anyhow!("G4 is not supported for local transfers")),
}
.ok_or_else(|| anyhow::anyhow!("Source layout not registered: {:?}", src))?;
let dst_layout = match &dst {
G1 => self.g1_handle(),
G2 => self.g2_handle(),
G3 => self.g3_handle(),
G4 => return Err(anyhow::anyhow!("G4 is not supported for local transfers")),
}
.ok_or_else(|| anyhow::anyhow!("Destination layout not registered: {:?}", dst))?;
self.manager.execute_transfer(
src_layout,
&src_block_ids,
dst_layout,
&dst_block_ids,
options,
)
}
fn execute_remote_onboard(
&self,
src: RemoteDescriptor,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
use LogicalLayoutHandle::*;
let dst_layout = match &dst {
G1 => self.g1_handle(),
G2 => self.g2_handle(),
G3 => self.g3_handle(),
G4 => return Err(anyhow::anyhow!("G4 is not supported for remote transfers")),
}
.ok_or_else(|| anyhow::anyhow!("Destination layout not registered: {:?}", dst))?;
match src {
RemoteDescriptor::Layout { handle, block_ids } => {
// RDMA onboard from remote layout
let block_ids_arc: Arc<[BlockId]> = block_ids.into();
self.manager.execute_transfer(
handle,
&block_ids_arc,
dst_layout,
&dst_block_ids,
options,
)
}
RemoteDescriptor::Object { keys } => {
// Object storage onboard (e.g., S3 → G2)
let object_client = self
.object_client
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Object client not configured"))?
.clone();
// Resolve destination physical layout
let dst_physical = self.resolve_layout(dst)?;
let block_ids_vec: Vec<BlockId> = dst_block_ids.to_vec();
// Create event for completion notification
let ctx = self.manager.context();
let event = ctx.event_system().new_event()?;
let handle = event.handle();
let awaiter = ctx.event_system().awaiter(handle)?;
// Spawn task to execute object storage read
ctx.tokio().spawn(async move {
let results = object_client
.get_blocks_with_layout(keys.clone(), dst_physical, block_ids_vec)
.await;
// Check if any failed
let failed: Vec<_> = results.iter().filter(|r| r.is_err()).collect();
if failed.is_empty() {
let _ = event.trigger();
} else {
let error_msg = format!(
"{} of {} blocks failed to download",
failed.len(),
results.len()
);
let _ = event.poison(error_msg);
}
});
Ok(TransferCompleteNotification::from_awaiter(awaiter))
}
}
}
fn execute_remote_offload(
&self,
src: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst: RemoteDescriptor,
_options: TransferOptions,
) -> Result<TransferCompleteNotification> {
match dst {
RemoteDescriptor::Layout { handle, block_ids } => {
// RDMA offload to remote layout
let src_layout = match &src {
LogicalLayoutHandle::G1 => self.g1_handle(),
LogicalLayoutHandle::G2 => self.g2_handle(),
LogicalLayoutHandle::G3 => self.g3_handle(),
LogicalLayoutHandle::G4 => {
return Err(anyhow::anyhow!("G4 cannot be used as source for offload"));
}
}
.ok_or_else(|| anyhow::anyhow!("Source layout not registered: {:?}", src))?;
let block_ids_arc: Arc<[BlockId]> = block_ids.into();
self.manager.execute_transfer(
src_layout,
&src_block_ids,
handle,
&block_ids_arc,
_options,
)
}
RemoteDescriptor::Object { keys } => {
// Object storage offload (e.g., G2 → S3)
let object_client = self
.object_client
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Object client not configured"))?
.clone();
// Resolve source physical layout
let src_physical = self.resolve_layout(src)?;
let block_ids_vec: Vec<BlockId> = src_block_ids.to_vec();
// Create event for completion notification
let ctx = self.manager.context();
let event = ctx.event_system().new_event()?;
let handle = event.handle();
let awaiter = ctx.event_system().awaiter(handle)?;
// Spawn task to execute object storage write
ctx.tokio().spawn(async move {
let results = object_client
.put_blocks_with_layout(keys.clone(), src_physical, block_ids_vec)
.await;
// Check if any failed
let failed: Vec<_> = results.iter().filter(|r| r.is_err()).collect();
if failed.is_empty() {
let _ = event.trigger();
} else {
let error_msg = format!(
"{} of {} blocks failed to upload",
failed.len(),
results.len()
);
let _ = event.poison(error_msg);
}
});
Ok(TransferCompleteNotification::from_awaiter(awaiter))
}
}
}
fn connect_remote(
&self,
instance_id: InstanceId,
metadata: Vec<SerializedLayout>,
) -> Result<ConnectRemoteResponse> {
// PhysicalWorker expects exactly 1 metadata item
if metadata.len() != 1 {
anyhow::bail!(
"PhysicalWorker expects exactly 1 metadata item, got {}",
metadata.len()
);
}
let meta = metadata.into_iter().next().unwrap();
// Unpack to extract logical type info
let unpacked = meta.unpack()?;
// Store mappings
{
let mut handles = self.remote_handles.write().unwrap();
for descriptor in &unpacked.layouts {
handles.insert((instance_id, descriptor.logical_type), descriptor.handle);
}
}
// Import so NIXL knows about the remote (repack to pass ownership)
let repacked = SerializedLayout::pack(
unpacked.worker_address,
unpacked.nixl_metadata,
unpacked.layouts,
)?;
self.manager.import_metadata(repacked)?;
Ok(ConnectRemoteResponse::ready())
}
fn has_remote_metadata(&self, instance_id: InstanceId) -> bool {
let handles = self.remote_handles.read().unwrap();
handles.keys().any(|(id, _)| *id == instance_id)
}
fn execute_remote_onboard_for_instance(
&self,
instance_id: InstanceId,
remote_logical_type: LogicalLayoutHandle,
src_block_ids: Vec<BlockId>,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
let handles = self.remote_handles.read().unwrap();
let remote_handle = handles
.get(&(instance_id, remote_logical_type))
.ok_or_else(|| {
anyhow::anyhow!(
"No remote {:?} handle for instance {}",
remote_logical_type,
instance_id
)
})?;
let descriptor = RemoteDescriptor::Layout {
handle: *remote_handle,
block_ids: src_block_ids,
};
self.execute_remote_onboard(descriptor, dst, dst_block_ids, options)
}
}
impl Worker for PhysicalWorker {
fn g1_handle(&self) -> Option<LayoutHandle> {
self.g1_handle
}
fn g2_handle(&self) -> Option<LayoutHandle> {
self.g2_handle
}
fn g3_handle(&self) -> Option<LayoutHandle> {
self.g3_handle
}
fn export_metadata(&self) -> Result<SerializedLayoutResponse> {
// Use the logical-type-aware export
self.export_metadata_with_logical_types()
.map(SerializedLayoutResponse::ready)
}
fn import_metadata(&self, metadata: SerializedLayout) -> Result<ImportMetadataResponse> {
self.manager
.import_metadata(metadata)
.map(ImportMetadataResponse::ready)
}
}
impl ObjectBlockOps for PhysicalWorker {
fn has_blocks(
&self,
keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<(SequenceHash, Option<usize>)>> {
// Object client handles rank-based key prefixing internally
if let Some(client) = self.object_client.as_ref() {
client.has_blocks(keys)
} else {
// No object client configured - return all keys as not found
Box::pin(async move { keys.into_iter().map(|k| (k, None)).collect() })
}
}
fn put_blocks(
&self,
keys: Vec<SequenceHash>,
src_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// Resolve logical handle to physical layout
let physical_layout = match self.resolve_layout(src_layout) {
Ok(layout) => layout,
Err(e) => {
tracing::error!(?src_layout, error = %e, "Failed to resolve layout for put_blocks");
return Box::pin(async move { keys.into_iter().map(Err).collect() });
}
};
// Object client handles rank-based key prefixing internally
if let Some(client) = self.object_client.as_ref() {
client.put_blocks_with_layout(keys, physical_layout, block_ids)
} else {
// No object client configured - return all keys as failed
tracing::warn!("put_blocks called but no object client configured");
Box::pin(async move { keys.into_iter().map(Err).collect() })
}
}
fn get_blocks(
&self,
keys: Vec<SequenceHash>,
dst_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// Resolve logical handle to physical layout
let physical_layout = match self.resolve_layout(dst_layout) {
Ok(layout) => layout,
Err(e) => {
tracing::error!(?dst_layout, error = %e, "Failed to resolve layout for get_blocks");
return Box::pin(async move { keys.into_iter().map(Err).collect() });
}
};
// Object client handles rank-based key prefixing internally
if let Some(client) = self.object_client.as_ref() {
client.get_blocks_with_layout(keys, physical_layout, block_ids)
} else {
// No object client configured - return all keys as failed
tracing::warn!("get_blocks called but no object client configured");
Box::pin(async move { keys.into_iter().map(Err).collect() })
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Replicated data worker for MLA (Multi-head Latent Attention) scenarios.
//!
//! In MLA architectures, KV blocks are replicated across all workers rather than
//! sharded. This means only rank 0 needs G2/G3 storage - other ranks receive
//! data via broadcast from rank 0 after it loads from G2/G3.
//!
//! # Architecture
//!
//! ```text
//! Rank 0: G3 (disk) ←→ G2 (host) ←→ G1 (GPU) ───broadcast──→ Other ranks G1
//! Rank 1-N: [no G2/G3] G1 (GPU) ←──────────────────────┘
//! ```
//!
//! # Transfer Semantics
//!
//! | Operation | Behavior |
//! |-----------|----------|
//! | G2/G3 → G1 (onboard) | Rank 0 transfers, then broadcasts to all ranks |
//! | G1 → G2/G3 (offload) | Rank 0 only (other ranks don't have G2/G3) |
//! | G2 ↔ G3 | Rank 0 only |
//! | G1 → G1 (local) | All ranks execute (data is replicated) |
use super::*;
use crate::KvbmRuntime;
use crate::collectives::CollectiveOps;
use anyhow::Result;
use std::sync::Arc;
/// Replicated data worker for MLA scenarios.
///
/// Only rank 0 has G2/G3 storage. When loading data to G1, rank 0 transfers
/// from G2/G3 and then broadcasts to all other ranks via collective operations.
///
/// # Requirements
///
/// - Workers must be initialized such that only rank 0 has G2/G3 handles
/// - A [`CollectiveOps`] implementation must be provided for broadcasting
///
/// # Trait Implementations
///
/// - [`WorkerTransfers`]: Specialized routing based on source/destination tiers
/// - [`ParallelWorker`]: Delegates to inner SpmdWorker
/// - [`ObjectBlockOps`]: Routes to rank 0 only (it has the G2 layout for resolution)
#[allow(dead_code)]
pub struct ReplicatedDataWorker {
inner: Arc<PhysicalWorker>,
runtime: Arc<KvbmRuntime>,
collective: Arc<dyn CollectiveOps>,
}
#[allow(dead_code)]
impl ReplicatedDataWorker {
/// Create a new ReplicatedDataWorker.
///
/// # Arguments
/// * `workers` - The underlying workers (one per rank). Only workers[0] should have G2/G3.
/// * `events` - The event system for aggregating completion notifications
/// * `runtime` - The tokio runtime handle for spawning aggregation tasks
/// * `collective` - The collective ops implementation for broadcasting
///
/// # Panics
///
/// Debug builds will panic if workers.len() < 1.
pub fn new(
worker: Arc<PhysicalWorker>, // perhaps use a trait to abstract this
runtime: Arc<KvbmRuntime>,
collective: Arc<dyn CollectiveOps>,
) -> Self {
// todo: ensure worker has a rank
Self {
inner: worker,
runtime,
collective,
}
}
/// Get access to the underlying SpmdWorker.
pub fn inner(&self) -> &PhysicalWorker {
&self.inner
}
/// Get the rank of the underlying worker.
pub fn rank(&self) -> usize {
self.inner.rank().expect("Worker must have a rank")
}
#[expect(unused_variables)]
fn broadcast(
&self,
xfer_completion: TransferCompleteNotification,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification> {
unimplemented!()
}
}
impl WorkerTransfers for ReplicatedDataWorker {
fn execute_local_transfer(
&self,
src: LogicalLayoutHandle,
dst: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification> {
let is_rank0 = self.rank() == 0;
let use_bcast = dst == LogicalLayoutHandle::G1;
if src == LogicalLayoutHandle::G1 && dst == LogicalLayoutHandle::G1 {
return self.inner.execute_local_transfer(
src,
dst,
src_block_ids,
dst_block_ids.clone(),
options,
);
}
if !is_rank0 && !use_bcast {
return Ok(TransferCompleteNotification::completed());
} else if is_rank0 {
let xfer_completion = self.inner.execute_local_transfer(
src,
dst,
src_block_ids,
dst_block_ids.clone(),
options.clone(),
)?;
if use_bcast {
self.broadcast(xfer_completion, dst, dst_block_ids, options)
} else {
Ok(xfer_completion)
}
} else {
let xfer_completion = TransferCompleteNotification::completed();
self.broadcast(xfer_completion, dst, dst_block_ids, options)
}
}
#[expect(unused_variables)]
fn execute_remote_onboard(
&self,
src: RemoteDescriptor,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification> {
unimplemented!()
}
#[expect(unused_variables)]
fn execute_remote_offload(
&self,
src: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst: RemoteDescriptor,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification> {
unimplemented!()
}
fn connect_remote(
&self,
instance_id: InstanceId,
metadata: Vec<SerializedLayout>,
) -> Result<ConnectRemoteResponse> {
// Use the shared implementation
self.inner.connect_remote(instance_id, metadata)
}
fn has_remote_metadata(&self, instance_id: InstanceId) -> bool {
self.inner.has_remote_metadata(instance_id)
}
#[expect(unused_variables)]
fn execute_remote_onboard_for_instance(
&self,
instance_id: InstanceId,
remote_logical_type: LogicalLayoutHandle,
src_block_ids: Vec<BlockId>,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: kvbm_physical::transfer::TransferOptions,
) -> Result<TransferCompleteNotification> {
unimplemented!()
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use futures::future::{Either, Ready, ready};
use serde::{Deserialize, Serialize};
use std::{
pin::Pin,
task::{Context, Poll},
};
pub use crate::worker::{ImportMetadataResponseAwaiter, SerializedResponseAwaiter};
pub use crate::{BlockId, SequenceHash};
pub use kvbm_common::LogicalLayoutHandle;
pub use kvbm_physical::manager::{LayoutHandle, SerializedLayout};
pub struct SerializedLayoutResponse {
awaiter: Either<Ready<Result<SerializedLayout>>, SerializedResponseAwaiter>,
}
impl SerializedLayoutResponse {
pub fn ready(layout: SerializedLayout) -> Self {
Self {
awaiter: Either::Left(ready(Ok(layout))),
}
}
pub fn from_boxed(awaiter: SerializedResponseAwaiter) -> Self {
Self {
awaiter: Either::Right(awaiter),
}
}
pub fn could_yield(&self) -> bool {
matches!(self.awaiter, Either::Right(_))
}
}
impl std::future::IntoFuture for SerializedLayoutResponse {
type Output = Result<SerializedLayout>;
type IntoFuture = Either<Ready<Result<SerializedLayout>>, SerializedResponseAwaiter>;
fn into_future(self) -> Self::IntoFuture {
self.awaiter
}
}
pub struct ImportMetadataResponse {
awaiter: Either<Ready<Result<Vec<LayoutHandle>>>, ImportMetadataResponseAwaiter>,
}
impl ImportMetadataResponse {
pub fn ready(handles: Vec<LayoutHandle>) -> Self {
Self {
awaiter: Either::Left(ready(Ok(handles))),
}
}
pub fn from_boxed(awaiter: ImportMetadataResponseAwaiter) -> Self {
Self {
awaiter: Either::Right(awaiter),
}
}
pub fn could_yield(&self) -> bool {
matches!(self.awaiter, Either::Right(_))
}
}
impl std::future::IntoFuture for ImportMetadataResponse {
type Output = Result<Vec<LayoutHandle>>;
type IntoFuture = Either<Ready<Result<Vec<LayoutHandle>>>, ImportMetadataResponseAwaiter>;
fn into_future(self) -> Self::IntoFuture {
self.awaiter
}
}
/// Response type for `connect_remote` operations.
///
/// This type represents the completion state of a remote metadata import
/// with handle mapping storage. Like other response types, it can be awaited.
///
/// For direct workers, this is typically ready immediately.
/// For replicated workers, this aggregates multiple underlying imports.
pub struct ConnectRemoteResponse {
awaiter: ConnectRemoteAwaiter,
}
pub enum ConnectRemoteAwaiter {
Ready(Ready<Result<()>>),
Event(::velo::EventAwaiter),
}
impl std::future::Future for ConnectRemoteAwaiter {
type Output = Result<()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
Self::Ready(ready) => Pin::new(ready).poll(cx),
Self::Event(waiter) => Pin::new(waiter).poll(cx),
}
}
}
impl ConnectRemoteResponse {
/// Create a response that is already completed.
///
/// This is used when the connect operation completes synchronously,
/// such as for DirectWorker with local metadata import.
pub fn ready() -> Self {
Self {
awaiter: ConnectRemoteAwaiter::Ready(ready(Ok(()))),
}
}
/// Create a response from an event waiter.
///
/// This is used when the connect operation requires waiting for
/// multiple underlying operations to complete (e.g., ReplicatedWorker).
pub fn from_awaiter(awaiter: ::velo::EventAwaiter) -> Self {
Self {
awaiter: ConnectRemoteAwaiter::Event(awaiter),
}
}
/// Check if the response can yield the current task.
pub fn could_yield(&self) -> bool {
matches!(self.awaiter, ConnectRemoteAwaiter::Event(_))
}
}
impl std::future::IntoFuture for ConnectRemoteResponse {
type Output = Result<()>;
type IntoFuture = ConnectRemoteAwaiter;
fn into_future(self) -> Self::IntoFuture {
self.awaiter
}
}
/// Remote descriptor for transfer operations.
#[derive(Serialize, Deserialize, Clone)]
pub enum RemoteDescriptor {
Layout {
handle: LayoutHandle,
block_ids: Vec<BlockId>,
},
Object {
keys: Vec<SequenceHash>,
},
}
/// Configuration sent from leader to workers for G2/G3/G4 layout creation.
///
/// This message is sent via Nova RPC during Phase 3 coordination.
/// Workers use this to create additional cache tiers beyond G1 (GPU KV).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderLayoutConfig {
/// Leader provided rank of this worker
///
/// The Connector framework provides us with an ordered list of workers. To ensure
/// leaders and workers are all in-sync on this information, the leader will send
/// each worker the rank provided by the Connector framework.
pub rank: usize,
/// Number of host/pinned blocks for G2 tier.
pub host_block_count: usize,
/// Number of disk blocks for G3 tier (None = no disk tier).
pub disk_block_count: Option<usize>,
/// Object storage configuration for G4 tier (None = no object tier).
///
/// When present, workers should instantiate object clients for storing
/// blocks in external object storage (S3/MinIO).
#[serde(default)]
pub object: Option<kvbm_config::ObjectConfig>,
/// Parallelism mode for this worker.
///
/// When `ReplicatedData` and rank > 0, the worker skips G2/G3 creation
/// since only rank 0 has host/disk storage in replicated mode.
#[serde(default)]
pub parallelism: kvbm_config::ParallelismMode,
}
/// Worker's response after configuring additional layouts (G2, G3).
///
/// Returned in response to a `LeaderLayoutConfig` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerLayoutResponse {
/// Full exported metadata including all registered layouts (G1, G2, G3).
pub metadata: SerializedLayout,
/// Which logical layouts were successfully created in this operation.
pub created_layouts: Vec<LogicalLayoutHandle>,
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::object::ObjectBlockOps;
use futures::future::BoxFuture;
use parking_lot::RwLock;
use std::collections::HashSet;
use std::sync::OnceLock;
#[derive(Clone)]
pub struct VeloWorkerClient {
messenger: Arc<Messenger>,
remote: InstanceId,
g1_handle: Arc<OnceLock<LayoutHandle>>,
g2_handle: Arc<OnceLock<LayoutHandle>>,
g3_handle: Arc<OnceLock<LayoutHandle>>,
/// Track which remote instances we've connected to for has_remote_metadata()
connected_instances: Arc<RwLock<HashSet<InstanceId>>>,
}
impl WorkerTransfers for VeloWorkerClient {
fn execute_local_transfer(
&self,
src: LogicalLayoutHandle,
dst: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst_block_ids: Arc<[BlockId]>,
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
// Create a single local event for this operation
let event = self.messenger.events().new_event()?;
let awaiter = self.messenger.events().awaiter(event.handle())?;
// Convert to serializable options
// TODO: Extract bounce buffer handle if present in options.bounce_buffer
let options = SerializableTransferOptions {
layer_range: options.layer_range,
nixl_write_notification: options.nixl_write_notification,
bounce_buffer_handle: None,
bounce_buffer_block_ids: None,
};
// Create the message
let message = LocalTransferMessage {
src,
dst,
src_block_ids: src_block_ids.to_vec(),
dst_block_ids: dst_block_ids.to_vec(),
options,
};
let bytes = Bytes::from(serde_json::to_vec(&message)?);
// Spawn a task for the remote instance
let nova = self.messenger.clone();
let remote_instance = self.remote;
// Use unary (not am_sync) to wait for transfer completion
self.messenger.tracker().spawn_on(
async move {
let result = nova
.unary("kvbm.worker.local_transfer")?
.raw_payload(bytes)
.instance(remote_instance)
.send()
.await;
match result {
Ok(_) => event.trigger(),
Err(e) => event.poison(e.to_string()),
}
},
self.messenger.runtime(),
);
Ok(TransferCompleteNotification::from_awaiter(awaiter))
}
fn execute_remote_onboard(
&self,
src: RemoteDescriptor,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
let event = self.messenger.events().new_event()?;
let awaiter = self.messenger.events().awaiter(event.handle())?;
let options = SerializableTransferOptions {
layer_range: options.layer_range,
nixl_write_notification: options.nixl_write_notification,
bounce_buffer_handle: None,
bounce_buffer_block_ids: None,
};
let message = RemoteOnboardMessage {
src,
dst,
dst_block_ids: dst_block_ids.to_vec(),
options,
};
let bytes = Bytes::from(serde_json::to_vec(&message)?);
let nova = self.messenger.clone();
let remote_instance = self.remote;
self.messenger.tracker().spawn_on(
async move {
// Use unary instead of am_sync for explicit response handling
let result = nova
.unary("kvbm.worker.remote_onboard")?
.raw_payload(bytes)
.instance(remote_instance)
.send()
.await;
match result {
Ok(_) => event.trigger(),
Err(e) => event.poison(e.to_string()),
}
},
self.messenger.runtime(),
);
Ok(TransferCompleteNotification::from_awaiter(awaiter))
}
fn execute_remote_offload(
&self,
src: LogicalLayoutHandle,
src_block_ids: Arc<[BlockId]>,
dst: RemoteDescriptor,
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
let event = self.messenger.events().new_event()?;
let awaiter = self.messenger.events().awaiter(event.handle())?;
let options = SerializableTransferOptions {
layer_range: options.layer_range,
nixl_write_notification: options.nixl_write_notification,
bounce_buffer_handle: None,
bounce_buffer_block_ids: None,
};
let message = RemoteOffloadMessage {
src,
dst,
src_block_ids: src_block_ids.to_vec(),
options,
};
let bytes = Bytes::from(serde_json::to_vec(&message)?);
let nova = self.messenger.clone();
let remote_instance = self.remote;
self.messenger.tracker().spawn_on(
async move {
// Use unary instead of am_sync for explicit response handling
let result = nova
.unary("kvbm.worker.remote_offload")?
.raw_payload(bytes)
.instance(remote_instance)
.send()
.await;
match result {
Ok(_) => event.trigger(),
Err(e) => event.poison(e.to_string()),
}
},
self.messenger.runtime(),
);
Ok(TransferCompleteNotification::from_awaiter(awaiter))
}
fn connect_remote(
&self,
instance_id: InstanceId,
metadata: Vec<SerializedLayout>,
) -> Result<ConnectRemoteResponse> {
// Serialize metadata to bytes (SerializedLayout uses bincode internally)
let serialized_metadata: Vec<Vec<u8>> =
metadata.iter().map(|m| m.as_bytes().to_vec()).collect();
let message = ConnectRemoteMessage {
instance_id,
metadata: serialized_metadata,
};
let bytes = Bytes::from(serde_json::to_vec(&message)?);
// Create event for completion tracking
let event = self.messenger.events().new_event()?;
let awaiter = self.messenger.events().awaiter(event.handle())?;
let nova = self.messenger.clone();
let remote_instance = self.remote;
let connected = self.connected_instances.clone();
let target_instance = instance_id;
self.messenger.tracker().spawn_on(
async move {
let result = nova
.unary("kvbm.worker.connect_remote")?
.raw_payload(bytes)
.instance(remote_instance)
.send()
.await;
match result {
Ok(_) => {
// Track that we've connected to this instance
connected.write().insert(target_instance);
event.trigger()
}
Err(e) => event.poison(e.to_string()),
}
},
self.messenger.runtime(),
);
Ok(ConnectRemoteResponse::from_awaiter(awaiter))
}
fn has_remote_metadata(&self, instance_id: InstanceId) -> bool {
// Check if we've successfully connected to this instance
self.connected_instances.read().contains(&instance_id)
}
fn execute_remote_onboard_for_instance(
&self,
instance_id: InstanceId,
remote_logical_type: LogicalLayoutHandle,
src_block_ids: Vec<BlockId>,
dst: LogicalLayoutHandle,
dst_block_ids: Arc<[BlockId]>,
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
let message = ExecuteRemoteOnboardForInstanceMessage {
instance_id,
remote_logical_type,
src_block_ids,
dst,
dst_block_ids: dst_block_ids.to_vec(),
options: SerializableTransferOptions::from(options),
};
let bytes = Bytes::from(serde_json::to_vec(&message)?);
// Create event for completion tracking
let event = self.messenger.events().new_event()?;
let awaiter = self.messenger.events().awaiter(event.handle())?;
let nova = self.messenger.clone();
let remote_instance = self.remote;
self.messenger.tracker().spawn_on(
async move {
let result = nova
.unary("kvbm.worker.remote_onboard_for_instance")?
.raw_payload(bytes)
.instance(remote_instance)
.send()
.await;
match result {
Ok(_) => event.trigger(),
Err(e) => event.poison(e.to_string()),
}
},
self.messenger.runtime(),
);
Ok(TransferCompleteNotification::from_awaiter(awaiter))
}
}
impl Worker for VeloWorkerClient {
fn g1_handle(&self) -> Option<LayoutHandle> {
self.g1_handle.get().copied()
}
fn g2_handle(&self) -> Option<LayoutHandle> {
self.g2_handle.get().copied()
}
fn g3_handle(&self) -> Option<LayoutHandle> {
self.g3_handle.get().copied()
}
fn export_metadata(&self) -> Result<SerializedLayoutResponse> {
// Use unary (not typed_unary) to avoid JSON serialization of bincode data
let unary_result = self
.messenger
.unary("kvbm.worker.export_metadata")?
.instance(self.remote)
.send();
// Wrap UnaryResult to convert Bytes to SerializedLayout
let future = async move {
let bytes = unary_result.await?;
Ok(SerializedLayout::from_bytes(bytes.to_vec()))
};
Ok(SerializedLayoutResponse::from_boxed(Box::pin(future)))
}
fn import_metadata(&self, metadata: SerializedLayout) -> Result<ImportMetadataResponse> {
// Use raw_payload to avoid JSON serialization of bincode data
let unary_result = self
.messenger
.unary("kvbm.worker.import_metadata")?
.raw_payload(Bytes::from(metadata.as_bytes().to_vec()))
.instance(self.remote)
.send();
// Response is JSON-serialized Vec<LayoutHandle>
let future = async move {
let bytes = unary_result.await?;
serde_json::from_slice(&bytes).map_err(|e| {
anyhow::anyhow!("Failed to deserialize import_metadata response: {}", e)
})
};
Ok(ImportMetadataResponse::from_boxed(Box::pin(future)))
}
}
impl VeloWorkerClient {
/// Create a new VeloWorkerClient for communicating with a remote worker.
pub fn new(messenger: Arc<Messenger>, remote: InstanceId) -> Self {
Self {
messenger,
remote,
g1_handle: Arc::new(OnceLock::new()),
g2_handle: Arc::new(OnceLock::new()),
g3_handle: Arc::new(OnceLock::new()),
connected_instances: Arc::new(RwLock::new(HashSet::new())),
}
}
/// Configure layout handles from serialized metadata.
///
/// Call this after worker initialization when handles are known from WorkerLayoutResponse.
/// This allows the VeloWorkerClient to provide layout handles like DirectWorker does.
///
/// # Arguments
/// * `metadata` - SerializedLayout from WorkerLayoutResponse.metadata
///
/// # Example
/// ```ignore
/// let response: WorkerLayoutResponse = worker.initialize(config).await?;
/// worker_client.configure_layout_handles(&response.metadata)?;
/// ```
pub fn configure_layout_handles(&self, metadata: &SerializedLayout) -> Result<()> {
let unpacked = metadata.unpack()?;
for desc in &unpacked.layouts {
match desc.logical_type {
LogicalLayoutHandle::G1 => {
self.g1_handle.set(desc.handle).ok();
}
LogicalLayoutHandle::G2 => {
self.g2_handle.set(desc.handle).ok();
}
LogicalLayoutHandle::G3 => {
self.g3_handle.set(desc.handle).ok();
}
_ => {}
}
}
Ok(())
}
/// Get the layout configuration from the remote worker.
///
/// This calls the `kvbm.worker.get_layout_config` handler on the remote worker.
/// Used by the leader during Phase 3 to gather G1 layout configs from all workers
/// and validate they match before creating G2/G3 layouts.
///
/// # Returns
/// A typed unary result that resolves to the layout configuration
pub fn get_layout_config(&self) -> Result<::velo::TypedUnaryResult<LayoutConfig>> {
let instance = self.remote;
let awaiter = self
.messenger
.typed_unary::<LayoutConfig>("kvbm.worker.get_layout_config")?
.instance(instance)
.send();
Ok(awaiter)
}
/// Configure additional layouts (G2, G3) on the remote worker.
///
/// This calls the `kvbm.worker.configure_layouts` handler on the remote worker.
/// The worker will create host/pinned cache (G2) and optionally disk cache (G3)
/// based on the provided configuration.
///
/// # Arguments
/// * `config` - Leader-provided configuration specifying block counts and backends
///
/// # Returns
/// A typed unary result that resolves to the worker's response with updated metadata
pub fn configure_layouts(
&self,
config: LeaderLayoutConfig,
) -> Result<::velo::TypedUnaryResult<WorkerLayoutResponse>> {
let instance = self.remote;
let awaiter = self
.messenger
.typed_unary::<WorkerLayoutResponse>("kvbm.worker.configure_layouts")?
.payload(config)?
.instance(instance)
.send();
Ok(awaiter)
}
}
impl ObjectBlockOps for VeloWorkerClient {
fn has_blocks(
&self,
keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<(SequenceHash, Option<usize>)>> {
let message = ObjectHasBlocksMessage { keys: keys.clone() };
let bytes = match serde_json::to_vec(&message) {
Ok(b) => Bytes::from(b),
Err(_) => {
return Box::pin(async move { keys.into_iter().map(|k| (k, None)).collect() });
}
};
let nova = self.messenger.clone();
let remote = self.remote;
Box::pin(async move {
let result = nova
.unary("kvbm.worker.object_has_blocks")
.ok()
.map(|u| u.raw_payload(bytes).instance(remote).send());
match result {
Some(unary_result) => match unary_result.await {
Ok(response_bytes) => {
match serde_json::from_slice::<ObjectHasBlocksResponse>(&response_bytes) {
Ok(response) => response.results,
Err(_) => keys.into_iter().map(|k| (k, None)).collect(),
}
}
Err(_) => keys.into_iter().map(|k| (k, None)).collect(),
},
None => keys.into_iter().map(|k| (k, None)).collect(),
}
})
}
fn put_blocks(
&self,
keys: Vec<SequenceHash>,
src_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// For remote workers, we send the logical layout handle - they resolve it locally
let message = ObjectPutBlocksMessage {
keys: keys.clone(),
layout: src_layout,
block_ids,
};
let bytes = match serde_json::to_vec(&message) {
Ok(b) => Bytes::from(b),
Err(_) => return Box::pin(async move { keys.into_iter().map(Err).collect() }),
};
let nova = self.messenger.clone();
let remote = self.remote;
Box::pin(async move {
let result = nova
.unary("kvbm.worker.object_put_blocks")
.ok()
.map(|u| u.raw_payload(bytes).instance(remote).send());
match result {
Some(unary_result) => match unary_result.await {
Ok(response_bytes) => {
match serde_json::from_slice::<ObjectPutGetBlocksResponse>(&response_bytes)
{
Ok(response) => response.into_results(),
Err(_) => keys.into_iter().map(Err).collect(),
}
}
Err(_) => keys.into_iter().map(Err).collect(),
},
None => keys.into_iter().map(Err).collect(),
}
})
}
fn get_blocks(
&self,
keys: Vec<SequenceHash>,
dst_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// For remote workers, we send the logical layout handle - they resolve it locally
let message = ObjectGetBlocksMessage {
keys: keys.clone(),
layout: dst_layout,
block_ids,
};
let bytes = match serde_json::to_vec(&message) {
Ok(b) => Bytes::from(b),
Err(_) => return Box::pin(async move { keys.into_iter().map(Err).collect() }),
};
let nova = self.messenger.clone();
let remote = self.remote;
Box::pin(async move {
let result = nova
.unary("kvbm.worker.object_get_blocks")
.ok()
.map(|u| u.raw_payload(bytes).instance(remote).send());
match result {
Some(unary_result) => match unary_result.await {
Ok(response_bytes) => {
match serde_json::from_slice::<ObjectPutGetBlocksResponse>(&response_bytes)
{
Ok(response) => response.into_results(),
Err(_) => keys.into_iter().map(Err).collect(),
}
}
Err(_) => keys.into_iter().map(Err).collect(),
},
None => keys.into_iter().map(Err).collect(),
}
})
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Nova-based RPC implementation for distributed worker communication.
//!
//! # RPC Pattern Guidelines
//!
//! This module uses only two Nova RPC patterns:
//!
//! 1. **`am_send` (fire-and-forget)**: Use when no response is needed.
//! - Client sends message and returns immediately
//! - Handler processes asynchronously, no response sent back
//! - Use `Handler::am_handler` or `am_handler_async`
//!
//! 2. **`unary` (request-response)**: Use when waiting for completion.
//! - Client sends request and awaits response
//! - Handler returns `Ok(Some(Bytes))` or `Ok(None)` which is sent back
//! - Use `Handler::unary_handler` or `unary_handler_async`
//!
//! # Why Not `am_sync`?
//!
//! We avoid `am_sync` due to observed issues where it does not reliably
//! receive completion signals when paired with `am_handler_async`. While
//! `am_sync` should theoretically behave like `unary` (both await completion),
//! in practice pairing `am_sync` client with `am_handler_async` handler caused
//! indefinite blocking during RDMA transfer tests.
//!
//! The root cause appears to be a mismatch in how responses are routed:
//! - `am_handler_async` returns `Result<()>` - the return value is NOT sent back
//! - `unary_handler_async` returns `Result<Option<Bytes>>` - the return value IS sent back
//!
//! Until the `am_sync` completion path is validated, prefer the simpler and
//! more predictable patterns: `am_send` for fire-and-forget, `unary` for
//! request-response.
mod client;
mod service;
pub use client::VeloWorkerClient;
pub use service::{VeloWorkerService, VeloWorkerServiceBuilder};
use super::DirectWorker;
use super::*;
use kvbm_physical::layout::LayoutConfig;
use kvbm_physical::transfer::TransferOptions;
use ::velo::Messenger;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
// Serializable transfer options for remote operations
#[derive(Serialize, Deserialize, Clone)]
struct SerializableTransferOptions {
layer_range: Option<std::ops::Range<usize>>,
nixl_write_notification: Option<u64>,
bounce_buffer_handle: Option<LayoutHandle>,
bounce_buffer_block_ids: Option<Vec<BlockId>>,
}
impl From<SerializableTransferOptions> for TransferOptions {
fn from(opts: SerializableTransferOptions) -> Self {
TransferOptions {
layer_range: opts.layer_range,
nixl_write_notification: opts.nixl_write_notification,
// bounce_buffer requires TransportManager to resolve handle to layout
bounce_buffer: None,
cuda_stream: None,
// KV layout overrides are not serialized; they must be set locally
src_kv_layout: None,
dst_kv_layout: None,
}
}
}
impl SerializableTransferOptions {
/// Extract bounce buffer handle and block IDs if present
fn bounce_buffer_parts(&self) -> Option<(LayoutHandle, Vec<BlockId>)> {
match (&self.bounce_buffer_handle, &self.bounce_buffer_block_ids) {
(Some(handle), Some(block_ids)) => Some((*handle, block_ids.clone())),
_ => None,
}
}
}
impl From<TransferOptions> for SerializableTransferOptions {
fn from(opts: TransferOptions) -> Self {
// Extract bounce buffer parts if present using into_parts()
let (bounce_buffer_handle, bounce_buffer_block_ids) = opts
.bounce_buffer
.map(|bb| {
let (handle, block_ids) = bb.into_parts();
(Some(handle), Some(block_ids))
})
.unwrap_or((None, None));
Self {
layer_range: opts.layer_range,
nixl_write_notification: opts.nixl_write_notification,
bounce_buffer_handle,
bounce_buffer_block_ids,
}
}
}
// Message types for remote worker operations
#[derive(Serialize, Deserialize)]
struct LocalTransferMessage {
src: LogicalLayoutHandle,
dst: LogicalLayoutHandle,
src_block_ids: Vec<BlockId>,
dst_block_ids: Vec<BlockId>,
options: SerializableTransferOptions,
}
#[derive(Serialize, Deserialize)]
struct RemoteOnboardMessage {
src: RemoteDescriptor,
dst: LogicalLayoutHandle,
dst_block_ids: Vec<BlockId>,
options: SerializableTransferOptions,
}
#[derive(Serialize, Deserialize)]
struct RemoteOffloadMessage {
src: LogicalLayoutHandle,
dst: RemoteDescriptor,
src_block_ids: Vec<BlockId>,
options: SerializableTransferOptions,
}
/// Message for connect_remote RPC - stores remote instance metadata in local worker
#[derive(Serialize, Deserialize)]
struct ConnectRemoteMessage {
instance_id: InstanceId,
/// Metadata serialized as raw bytes (SerializedLayout uses bincode internally)
metadata: Vec<Vec<u8>>,
}
/// Message for execute_remote_onboard_for_instance RPC - pulls from remote using instance ID
#[derive(Serialize, Deserialize)]
struct ExecuteRemoteOnboardForInstanceMessage {
instance_id: InstanceId,
remote_logical_type: LogicalLayoutHandle,
src_block_ids: Vec<BlockId>,
dst: LogicalLayoutHandle,
dst_block_ids: Vec<BlockId>,
options: SerializableTransferOptions,
}
// ============================================================================
// Object Storage Message Types
// ============================================================================
/// Message for object_has_blocks RPC - check if blocks exist in object storage
#[derive(Serialize, Deserialize)]
struct ObjectHasBlocksMessage {
keys: Vec<SequenceHash>,
}
/// Response for object_has_blocks RPC
#[derive(Serialize, Deserialize)]
struct ObjectHasBlocksResponse {
results: Vec<(SequenceHash, Option<usize>)>,
}
/// Message for object_put_blocks RPC - upload blocks to object storage
#[derive(Serialize, Deserialize)]
struct ObjectPutBlocksMessage {
keys: Vec<SequenceHash>,
layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
}
/// Message for object_get_blocks RPC - download blocks from object storage
#[derive(Serialize, Deserialize)]
struct ObjectGetBlocksMessage {
keys: Vec<SequenceHash>,
layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
}
/// Response for object put/get operations
#[derive(Serialize, Deserialize)]
struct ObjectPutGetBlocksResponse {
/// Ok(key) for success, Err(key) for failure - serialized as (bool, key)
results: Vec<(bool, SequenceHash)>,
}
impl ObjectPutGetBlocksResponse {
fn from_results(results: Vec<Result<SequenceHash, SequenceHash>>) -> Self {
Self {
results: results
.into_iter()
.map(|r| match r {
Ok(k) => (true, k),
Err(k) => (false, k),
})
.collect(),
}
}
fn into_results(self) -> Vec<Result<SequenceHash, SequenceHash>> {
self.results
.into_iter()
.map(|(ok, k)| if ok { Ok(k) } else { Err(k) })
.collect()
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use kvbm_physical::manager::SerializedLayout;
use super::{
Arc, ConnectRemoteMessage, DirectWorker, ExecuteRemoteOnboardForInstanceMessage,
LocalTransferMessage, ObjectGetBlocksMessage, ObjectHasBlocksMessage, ObjectHasBlocksResponse,
ObjectPutBlocksMessage, ObjectPutGetBlocksResponse, RemoteOffloadMessage, RemoteOnboardMessage,
Result, TransferOptions, WorkerTransfers,
};
use crate::object::ObjectBlockOps;
use bytes::Bytes;
use derive_builder::Builder;
use ::velo::{Handler, Messenger};
/// Builder for VeloWorkerService - provides flexibility in construction.
///
/// Use this builder when you need to:
/// - Pass a pre-built DirectWorker (when caller manages layout registration)
/// - Pass a pre-built TransferManager (service creates DirectWorker)
/// - Have more control over worker configuration
#[derive(Builder)]
#[builder(pattern = "owned")]
pub struct VeloWorkerService {
messenger: Arc<Messenger>,
worker: Arc<DirectWorker>,
}
impl VeloWorkerService {
pub fn new(messenger: Arc<Messenger>, worker: Arc<DirectWorker>) -> Result<Self> {
let service = Self { messenger, worker };
service.register_handlers()?;
Ok(service)
}
/// Access the underlying DirectWorker.
///
/// This is useful for:
/// - Registering additional layouts after service creation
/// - Exporting metadata for handshake
/// - Accessing the TransferManager
pub fn worker(&self) -> &Arc<DirectWorker> {
&self.worker
}
/// Register all worker handlers with Nova
fn register_handlers(&self) -> Result<()> {
self.register_local_transfer_handler()?;
self.register_remote_onboard_handler()?;
self.register_remote_offload_handler()?;
self.register_import_metadata_handler()?;
self.register_export_metadata_handler()?;
self.register_connect_remote_handler()?;
self.register_execute_remote_onboard_for_instance_handler()?;
// Object storage handlers
self.register_object_has_blocks_handler()?;
self.register_object_put_blocks_handler()?;
self.register_object_get_blocks_handler()?;
Ok(())
}
fn register_local_transfer_handler(&self) -> Result<()> {
let worker = self.worker.clone();
// Use unary_handler_async for explicit response (client waits for transfer completion)
let handler = Handler::unary_handler_async("kvbm.worker.local_transfer", move |ctx| {
let worker = worker.clone();
async move {
// Deserialize the message
let message: LocalTransferMessage = serde_json::from_slice(&ctx.payload)?;
// Convert options and resolve bounce buffer if present
let bounce_buffer_parts = message.options.bounce_buffer_parts();
let mut options: TransferOptions = message.options.into();
if let Some((handle, block_ids)) = bounce_buffer_parts {
options.bounce_buffer = Some(worker.create_bounce_buffer(handle, block_ids)?);
}
let notification = worker.execute_local_transfer(
message.src,
message.dst,
Arc::from(message.src_block_ids),
Arc::from(message.dst_block_ids),
options,
)?;
// Await the transfer completion
notification.await?;
// Return empty response to signal success
Ok(Some(Bytes::new()))
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
fn register_remote_onboard_handler(&self) -> Result<()> {
let worker = self.worker.clone();
// Use unary_handler_async for explicit response (works with unary client)
let handler = Handler::unary_handler_async("kvbm.worker.remote_onboard", move |ctx| {
let worker = worker.clone();
async move {
let message: RemoteOnboardMessage = serde_json::from_slice(&ctx.payload)?;
// Convert options and resolve bounce buffer if present
let bounce_buffer_parts = message.options.bounce_buffer_parts();
let mut options: TransferOptions = message.options.into();
if let Some((handle, block_ids)) = bounce_buffer_parts {
options.bounce_buffer = Some(worker.create_bounce_buffer(handle, block_ids)?);
}
let notification = worker.execute_remote_onboard(
message.src,
message.dst,
Arc::from(message.dst_block_ids),
options,
)?;
notification.await?;
Ok(Some(Bytes::new()))
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
fn register_remote_offload_handler(&self) -> Result<()> {
let worker = self.worker.clone();
// Use unary_handler_async for explicit response (works with unary client)
let handler = Handler::unary_handler_async("kvbm.worker.remote_offload", move |ctx| {
let worker = worker.clone();
async move {
let message: RemoteOffloadMessage = serde_json::from_slice(&ctx.payload)?;
// Convert options and resolve bounce buffer if present
let bounce_buffer_parts = message.options.bounce_buffer_parts();
let mut options: TransferOptions = message.options.into();
if let Some((handle, block_ids)) = bounce_buffer_parts {
options.bounce_buffer = Some(worker.create_bounce_buffer(handle, block_ids)?);
}
let notification = worker.execute_remote_offload(
message.src,
Arc::from(message.src_block_ids),
message.dst,
options,
)?;
notification.await?;
Ok(Some(Bytes::new()))
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
fn register_import_metadata_handler(&self) -> Result<()> {
let worker = self.worker.clone();
let handler = Handler::unary_handler("kvbm.worker.import_metadata", move |ctx| {
let metadata = SerializedLayout::from_bytes(ctx.payload.to_vec());
let handles = worker.import_metadata(metadata)?;
Ok(Some(Bytes::from(serde_json::to_vec(&handles)?)))
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
fn register_export_metadata_handler(&self) -> Result<()> {
let worker = self.worker.clone();
let handler = Handler::unary_handler("kvbm.worker.export_metadata", move |_ctx| {
let response = worker.export_metadata()?;
Ok(Some(Bytes::from(response.as_bytes().to_vec())))
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
/// Register handler for connect_remote - stores remote instance metadata in local worker
fn register_connect_remote_handler(&self) -> Result<()> {
let worker = self.worker.clone();
let handler = Handler::unary_handler("kvbm.worker.connect_remote", move |ctx| {
let message: ConnectRemoteMessage = serde_json::from_slice(&ctx.payload)?;
// Deserialize metadata (SerializedLayout stored as raw bytes)
let metadata: Vec<SerializedLayout> = message
.metadata
.into_iter()
.map(SerializedLayout::from_bytes)
.collect();
// Call DirectWorker.connect_remote()
worker.connect_remote(message.instance_id, metadata)?;
// Return empty response to signal success
Ok(Some(Bytes::new()))
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
/// Register handler for execute_remote_onboard_for_instance - pulls from remote using instance ID
fn register_execute_remote_onboard_for_instance_handler(&self) -> Result<()> {
let worker = self.worker.clone();
let handler =
Handler::unary_handler_async("kvbm.worker.remote_onboard_for_instance", move |ctx| {
let worker = worker.clone();
async move {
let message: ExecuteRemoteOnboardForInstanceMessage =
serde_json::from_slice(&ctx.payload)?;
// Convert options and resolve bounce buffer if present
let bounce_buffer_parts = message.options.bounce_buffer_parts();
let mut options: TransferOptions = message.options.into();
if let Some((handle, block_ids)) = bounce_buffer_parts {
options.bounce_buffer =
Some(worker.create_bounce_buffer(handle, block_ids)?);
}
let notification = worker.execute_remote_onboard_for_instance(
message.instance_id,
message.remote_logical_type,
message.src_block_ids,
message.dst,
Arc::from(message.dst_block_ids),
options,
)?;
notification.await?;
Ok(Some(Bytes::new()))
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
// ========================================================================
// Object Storage Handlers
// ========================================================================
/// Register handler for object_has_blocks - check if blocks exist in object storage
fn register_object_has_blocks_handler(&self) -> Result<()> {
let worker = self.worker.clone();
let handler = Handler::unary_handler_async("kvbm.worker.object_has_blocks", move |ctx| {
let worker = worker.clone();
async move {
let message: ObjectHasBlocksMessage = serde_json::from_slice(&ctx.payload)?;
// Call DirectWorker's ObjectBlockOps implementation
let results = worker.has_blocks(message.keys).await;
let response = ObjectHasBlocksResponse { results };
Ok(Some(Bytes::from(serde_json::to_vec(&response)?)))
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
/// Register handler for object_put_blocks - upload blocks to object storage
fn register_object_put_blocks_handler(&self) -> Result<()> {
let worker = self.worker.clone();
let handler = Handler::unary_handler_async("kvbm.worker.object_put_blocks", move |ctx| {
let worker = worker.clone();
async move {
let message: ObjectPutBlocksMessage = serde_json::from_slice(&ctx.payload)?;
// Call DirectWorker's ObjectBlockOps implementation
// DirectWorker resolves logical handle to physical layout internally
let results = worker
.put_blocks(message.keys, message.layout, message.block_ids)
.await;
let response = ObjectPutGetBlocksResponse::from_results(results);
Ok(Some(Bytes::from(serde_json::to_vec(&response)?)))
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
/// Register handler for object_get_blocks - download blocks from object storage
fn register_object_get_blocks_handler(&self) -> Result<()> {
let worker = self.worker.clone();
let handler = Handler::unary_handler_async("kvbm.worker.object_get_blocks", move |ctx| {
let worker = worker.clone();
async move {
let message: ObjectGetBlocksMessage = serde_json::from_slice(&ctx.payload)?;
// Call DirectWorker's ObjectBlockOps implementation
// DirectWorker resolves logical handle to physical layout internally
let results = worker
.get_blocks(message.keys, message.layout, message.block_ids)
.await;
let response = ObjectPutGetBlocksResponse::from_results(results);
Ok(Some(Bytes::from(serde_json::to_vec(&response)?)))
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
}
......@@ -3,7 +3,7 @@
[package]
name = "kvbm-kernels"
version = "0.1.0"
version = "1.0.0"
edition.workspace = true
authors.workspace = true
license.workspace = true
......
......@@ -3,7 +3,7 @@
[package]
name = "kvbm-logical"
version = "0.1.0"
version = "1.0.0"
edition.workspace = true
authors.workspace = true
license.workspace = true
......
......@@ -3,7 +3,7 @@
[package]
name = "kvbm-physical"
version = "0.1.0"
version = "1.0.0"
edition.workspace = true
authors.workspace = true
license.workspace = true
......@@ -13,17 +13,16 @@ repository.workspace = true
dynamo-memory = { workspace = true }
kvbm-common = { workspace = true }
kvbm-kernels = { workspace = true }
velo-events = { workspace = true }
velo = { workspace = true }
aligned-vec = "0.6.4"
anyhow = { workspace = true }
bincode = { version = "2.0.0", features = ["serde", "derive"] }
blake3 = { version = "1" }
cudarc = { workspace = true }
cudarc = { workspace = true, features = ["nccl"] }
derive_builder = { workspace = true }
futures = { workspace = true }
derive-getters = { version = "0.5" }
parking_lot = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment