Unverified Commit 008683d6 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: adding kvbm-engine (#6773)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent cf79c4fc
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod accessor;
mod instance;
mod onboarding;
#[doc = include_str!("../../docs/session.md")]
pub mod session;
mod state;
mod types;
pub mod velo;
pub use accessor::{BlockAccessor, PolicyContext, TieredBlock};
pub use instance::InstanceLeader;
pub use onboarding::*;
pub use session::{
ControllableSessionOptions, ControllableSessionResult, InitiatorSession, ResponderSession,
ServerSession, ServerSessionHandle, ServerSessionOptions, SessionId,
};
pub use state::{LeaderState, RemoteLeaderInfo, route_local_to_remote};
pub use types::*;
pub use velo::VeloLeaderService;
use anyhow::Result;
use crate::SequenceHash;
/// Leader trait for distributed block onboarding operations.
pub trait Leader: Send + Sync {
/// Find matching blocks with default options.
fn find_matches(&self, sequence_hashes: &[SequenceHash]) -> Result<FindMatchesResult> {
self.find_matches_with_options(sequence_hashes, FindMatchesOptions::default())
}
/// Find matching blocks with custom options.
fn find_matches_with_options(
&self,
sequence_hashes: &[SequenceHash],
options: FindMatchesOptions,
) -> Result<FindMatchesResult>;
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use tokio::sync::mpsc;
use super::session::SessionId;
use super::types::StagingMode;
/// Status of an onboarding operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OnboardingStatus {
/// Searching for blocks (local or remote).
Searching,
/// Holding blocks without staging (StagingMode::Hold).
/// Provides location breakdown for cost analysis.
/// - `local_g2`: number of blocks in local G2 (ready to use)
/// - `local_g3`: number of blocks in local G3 (needs local staging)
/// - `remote_g2`: number of blocks in remote G2 (needs RDMA pull)
/// - `remote_g3`: number of blocks in remote G3 (needs remote staging + RDMA)
/// - `pending_g4`: number of blocks with G4 load in progress
/// - `loaded_g4`: number of blocks successfully loaded from G4 (included in local_g2)
/// - `failed_g4`: number of blocks that failed to load from G4
Holding {
local_g2: usize,
local_g3: usize,
remote_g2: usize,
remote_g3: usize,
pending_g4: usize,
loaded_g4: usize,
failed_g4: usize,
},
/// Preparing: staging G3→G2 (StagingMode::Prepare or Full).
/// - `matched`: total number of blocks matched during search
/// - `staging_local`: number of local G3→G2 transfers in progress
/// - `staging_remote`: number of remote G3→G2 transfers in progress
Preparing {
matched: usize,
staging_local: usize,
staging_remote: usize,
},
/// Prepared: all blocks in G2, session still alive (StagingMode::Prepare).
/// - `local_g2`: number of blocks in local G2
/// - `remote_g2`: number of blocks in remote G2 instances
Prepared { local_g2: usize, remote_g2: usize },
/// Staging: full mode with RDMA pulls (StagingMode::Full).
/// - `matched`: total number of blocks matched
/// - `staging_local`: local G3→G2 in progress
/// - `staging_remote`: remote G3→G2 in progress
/// - `pulling`: remote G2→local G2 (RDMA) in progress
Staging {
matched: usize,
staging_local: usize,
staging_remote: usize,
pulling: usize,
},
/// Operation complete - all blocks are in initiator's G2 (StagingMode::Full).
/// Or terminal state for Hold/Prepare modes.
/// - `matched`: total number of blocks in local G2
Complete { matched_blocks: usize },
}
/// Control commands for managing live sessions.
#[derive(Debug)]
pub(crate) enum SessionControl {
/// Trigger prepare operation (Hold → Prepare): stage all G3→G2
Prepare,
/// Trigger pull operation (Prepare → Full): RDMA pull remote G2→local G2
Pull,
/// Cancel session and release all blocks
Cancel,
/// Shutdown session (normal completion)
Shutdown,
}
/// Handle to a live onboarding session for deferred operations.
///
/// Only available for StagingMode::Hold and StagingMode::Prepare.
#[derive(Debug)]
pub struct SessionHandle {
session_id: SessionId,
mode: StagingMode,
control_tx: mpsc::Sender<SessionControl>,
}
impl SessionHandle {
pub(crate) fn new(
session_id: SessionId,
mode: StagingMode,
control_tx: mpsc::Sender<SessionControl>,
) -> Self {
Self {
session_id,
mode,
control_tx,
}
}
/// Get the session ID.
pub fn session_id(&self) -> SessionId {
self.session_id
}
/// Get the current staging mode.
pub fn mode(&self) -> StagingMode {
self.mode
}
/// Trigger G3→G2 staging on all instances (Hold → Prepare).
///
/// The server validates that the session is in Hold mode before processing.
/// After this completes, the session transitions to Prepare mode internally.
pub async fn prepare(&self) -> Result<()> {
self.control_tx
.send(SessionControl::Prepare)
.await
.map_err(|_| anyhow::anyhow!("session task has exited"))
}
/// Trigger RDMA pull from remote G2→local G2 (Prepare → Complete).
///
/// The server validates that the session is in Prepare mode before processing.
/// After this completes, the session transitions to Complete status.
pub async fn pull(&self) -> Result<()> {
self.control_tx
.send(SessionControl::Pull)
.await
.map_err(|_| anyhow::anyhow!("session task has exited"))
}
/// Cancel session and release all held blocks.
pub async fn cancel(&self) -> Result<()> {
self.control_tx
.send(SessionControl::Cancel)
.await
.map_err(|_| anyhow::anyhow!("session task has exited"))
}
/// Shutdown session (used internally).
#[expect(dead_code)]
pub(crate) async fn shutdown(&self) -> Result<()> {
self.control_tx
.send(SessionControl::Shutdown)
.await
.map_err(|_| anyhow::anyhow!("session task has exited"))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! RAII block holding for sessions.
//!
//! This module provides [`BlockHolder<T>`], a tier-agnostic container for
//! holding blocks during session operations. Blocks are automatically
//! released when the holder is dropped.
//!
//! # Design Philosophy
//!
//! `BlockHolder` is intentionally simple - it's pure RAII with no staging logic.
//! This allows flexibility for different staging patterns:
//! - G3→G2 staging
//! - G4→G2 staging
//! - G1→G2 staging
//! - G2→G3 offload
//!
//! The caller decides when and how to stage; `BlockHolder` just holds.
use crate::SequenceHash;
use kvbm_logical::blocks::{BlockMetadata, ImmutableBlock};
/// RAII block holder - tier-agnostic, just holds blocks.
///
/// # Type Parameter
///
/// `T` is the tier metadata type (e.g., `G2`, `G3`). It must implement
/// `BlockMetadata` which is `Clone + Send + Sync + 'static`.
///
/// # RAII Semantics
///
/// When `BlockHolder` is dropped, all held blocks are released. This ensures
/// blocks don't leak even if session handling panics.
///
/// # Example
///
/// ```ignore
/// // Create holder with searched blocks
/// let mut holder = BlockHolder::new(g2_blocks);
///
/// // Check what we have
/// println!("Holding {} blocks", holder.count());
///
/// // Release some blocks (e.g., after RDMA pull)
/// holder.release(&pulled_hashes);
///
/// // Holder drops here, releasing any remaining blocks
/// ```
#[derive(Debug)]
pub struct BlockHolder<T: BlockMetadata> {
blocks: Vec<ImmutableBlock<T>>,
}
impl<T: BlockMetadata> BlockHolder<T> {
/// Create a new `BlockHolder` with the given blocks.
pub fn new(blocks: Vec<ImmutableBlock<T>>) -> Self {
Self { blocks }
}
/// Create an empty `BlockHolder`.
pub fn empty() -> Self {
Self { blocks: Vec::new() }
}
/// Get a reference to the held blocks.
pub fn blocks(&self) -> &[ImmutableBlock<T>] {
&self.blocks
}
/// Get the number of held blocks.
pub fn count(&self) -> usize {
self.blocks.len()
}
/// Check if the holder is empty.
pub fn is_empty(&self) -> bool {
self.blocks.is_empty()
}
/// Add blocks to this holder.
pub fn extend(&mut self, blocks: impl IntoIterator<Item = ImmutableBlock<T>>) {
self.blocks.extend(blocks);
}
/// Release blocks matching the given sequence hashes.
///
/// Removes blocks from the holder whose sequence hash is in `hashes`.
/// The blocks are dropped, releasing their references.
pub fn release(&mut self, hashes: &[SequenceHash]) {
self.blocks.retain(|b| !hashes.contains(&b.sequence_hash()));
}
/// Retain only blocks matching the given sequence hashes.
///
/// Removes blocks from the holder whose sequence hash is NOT in `hashes`.
/// The removed blocks are dropped, releasing their references.
pub fn retain(&mut self, hashes: &[SequenceHash]) {
self.blocks.retain(|b| hashes.contains(&b.sequence_hash()));
}
/// Take all blocks out of this holder.
///
/// The holder becomes empty. Useful for transferring blocks to another
/// location or for processing before dropping.
pub fn take_all(&mut self) -> Vec<ImmutableBlock<T>> {
std::mem::take(&mut self.blocks)
}
/// Get sequence hashes of all held blocks.
pub fn sequence_hashes(&self) -> Vec<SequenceHash> {
self.blocks.iter().map(|b| b.sequence_hash()).collect()
}
/// Find a block by sequence hash.
pub fn find(&self, hash: &SequenceHash) -> Option<&ImmutableBlock<T>> {
self.blocks.iter().find(|b| &b.sequence_hash() == hash)
}
/// Check if a block with the given hash is held.
pub fn contains(&self, hash: &SequenceHash) -> bool {
self.blocks.iter().any(|b| &b.sequence_hash() == hash)
}
/// Iterate over held blocks.
pub fn iter(&self) -> impl Iterator<Item = &ImmutableBlock<T>> {
self.blocks.iter()
}
}
impl<T: BlockMetadata> Default for BlockHolder<T> {
fn default() -> Self {
Self::empty()
}
}
impl<T: BlockMetadata> FromIterator<ImmutableBlock<T>> for BlockHolder<T> {
fn from_iter<I: IntoIterator<Item = ImmutableBlock<T>>>(iter: I) -> Self {
Self::new(iter.into_iter().collect())
}
}
impl<T: BlockMetadata> IntoIterator for BlockHolder<T> {
type Item = ImmutableBlock<T>;
type IntoIter = std::vec::IntoIter<ImmutableBlock<T>>;
fn into_iter(self) -> Self::IntoIter {
self.blocks.into_iter()
}
}
impl<'a, T: BlockMetadata> IntoIterator for &'a BlockHolder<T> {
type Item = &'a ImmutableBlock<T>;
type IntoIter = std::slice::Iter<'a, ImmutableBlock<T>>;
fn into_iter(self) -> Self::IntoIter {
self.blocks.iter()
}
}
#[cfg(test)]
mod tests {
use super::*;
// Note: Full tests require test infrastructure to create ImmutableBlock instances.
// These tests verify the basic container operations.
#[test]
fn test_empty_holder() {
let holder: BlockHolder<()> = BlockHolder::empty();
assert!(holder.is_empty());
assert_eq!(holder.count(), 0);
}
#[test]
fn test_default_is_empty() {
let holder: BlockHolder<()> = BlockHolder::default();
assert!(holder.is_empty());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! SessionEndpoint: Point-to-point session primitive.
//!
//! This is the core building block for unified sessions. It handles:
//! - State machine (control role + attachment state + phase)
//! - Message receive channel for incoming [`SessionMessage`]
//! - State publication via watch channel for observers
//! - Transport for sending messages to peer
//!
//! It does NOT handle:
//! - Block holding (use [`BlockHolder`] for that)
//! - Staging logic (caller invokes staging)
//! - Multi-peer orchestration (that's [`InitiatorSession`]'s job)
//!
//! # Usage
//!
//! ```ignore
//! // Create an endpoint
//! let (tx, rx) = mpsc::channel(32);
//! let endpoint = SessionEndpoint::new(
//! SessionId::new_v4(),
//! my_instance_id,
//! transport,
//! rx,
//! );
//!
//! // Process messages
//! while let Some(msg) = endpoint.recv().await {
//! match msg {
//! SessionMessage::Attach { peer, as_role, .. } => {
//! endpoint.accept_attachment(peer, as_role);
//! // ... handle attachment
//! }
//! // ... other messages
//! }
//! }
//! ```
use std::sync::Arc;
use tokio::sync::{mpsc, watch};
use anyhow::Result;
use crate::InstanceId;
use super::{
SessionId,
messages::{BlockInfo, SessionMessage, SessionStateSnapshot},
state::{AttachmentState, ControlRole, SessionPhase},
transport::MessageTransport,
};
/// A point-to-point session endpoint.
///
/// This is the common building block for all session types. It encapsulates:
/// - Identity (session_id, instance_id)
/// - State machine (control role, attachment, phase)
/// - Communication (message receive, state publication)
///
/// The endpoint starts in `Neutral + Unattached` state by default.
pub struct SessionEndpoint {
session_id: SessionId,
instance_id: InstanceId,
// State
control_role: ControlRole,
attachment: AttachmentState,
phase: SessionPhase,
// Communication
transport: Arc<MessageTransport>,
msg_rx: mpsc::Receiver<SessionMessage>,
state_tx: watch::Sender<SessionStateSnapshot>,
}
impl SessionEndpoint {
/// Create a new endpoint in `Neutral + Unattached` state.
pub fn new(
session_id: SessionId,
instance_id: InstanceId,
transport: Arc<MessageTransport>,
msg_rx: mpsc::Receiver<SessionMessage>,
) -> Self {
let initial_state = SessionStateSnapshot {
phase: SessionPhase::default(),
control_role: ControlRole::default(),
g2_blocks: Vec::new(),
g3_pending: 0,
ready_layer_range: None,
};
let (state_tx, _) = watch::channel(initial_state);
Self {
session_id,
instance_id,
control_role: ControlRole::default(),
attachment: AttachmentState::default(),
phase: SessionPhase::default(),
transport,
msg_rx,
state_tx,
}
}
/// Create a new endpoint with pre-attached state.
///
/// Used when creating a session that is already attached to a peer
/// (e.g., ResponderSession which is pre-attached to the initiator).
pub fn new_attached(
session_id: SessionId,
instance_id: InstanceId,
peer: InstanceId,
role: ControlRole,
phase: SessionPhase,
transport: Arc<MessageTransport>,
msg_rx: mpsc::Receiver<SessionMessage>,
) -> Self {
let initial_state = SessionStateSnapshot {
phase,
control_role: role,
g2_blocks: Vec::new(),
g3_pending: 0,
ready_layer_range: None,
};
let (state_tx, _) = watch::channel(initial_state);
Self {
session_id,
instance_id,
control_role: role,
attachment: AttachmentState::Attached { peer },
phase,
transport,
msg_rx,
state_tx,
}
}
// =========================================================================
// State Accessors
// =========================================================================
/// Get the session ID.
pub fn session_id(&self) -> SessionId {
self.session_id
}
/// Get this endpoint's instance ID.
pub fn instance_id(&self) -> InstanceId {
self.instance_id
}
/// Get the current control role.
pub fn control_role(&self) -> ControlRole {
self.control_role
}
/// Check if a peer is attached.
pub fn is_attached(&self) -> bool {
self.attachment.is_attached()
}
/// Get the attached peer's instance ID.
pub fn peer(&self) -> Option<InstanceId> {
self.attachment.peer()
}
/// Get the current session phase.
pub fn phase(&self) -> SessionPhase {
self.phase
}
/// Check if the session is in a terminal state.
pub fn is_complete(&self) -> bool {
self.phase.is_terminal()
}
// =========================================================================
// State Transitions
// =========================================================================
/// Set the session phase.
pub fn set_phase(&mut self, phase: SessionPhase) {
self.phase = phase;
}
/// Set the control role.
///
/// Use this for direct role changes (e.g., when processing YieldControl
/// or AcquireControl messages).
pub fn set_control_role(&mut self, role: ControlRole) {
self.control_role = role;
}
/// Accept an attachment from a peer.
///
/// Transitions from `Unattached` to `Attached` and sets the control role.
pub fn accept_attachment(&mut self, peer: InstanceId, role: ControlRole) {
self.attachment = AttachmentState::Attached { peer };
self.control_role = role;
}
/// Detach from the current peer.
///
/// Returns the detached peer's instance ID if there was one.
pub fn detach(&mut self) -> Option<InstanceId> {
let peer = self.attachment.peer();
self.attachment = AttachmentState::Unattached;
self.control_role = ControlRole::Neutral;
peer
}
/// Yield control to peer.
///
/// Transitions from `Controller` to `Neutral`.
/// Returns `Err` if not currently `Controller`.
pub fn yield_control(&mut self) -> Result<()> {
if self.control_role != ControlRole::Controller {
anyhow::bail!("Cannot yield control: not currently Controller");
}
self.control_role = ControlRole::Neutral;
Ok(())
}
/// Acquire control from peer.
///
/// Transitions from `Neutral` or `Controllee` to `Controller`.
/// The peer must be in `Neutral` state for this to succeed.
pub fn acquire_control(&mut self) -> Result<()> {
if self.control_role == ControlRole::Controller {
// Already controller, no-op
return Ok(());
}
self.control_role = ControlRole::Controller;
Ok(())
}
/// Handle a peer yielding control to us.
///
/// Transitions to `Neutral` (peer has yielded, we can now acquire if we want).
pub fn peer_yielded_control(&mut self) {
// When peer yields, we stay in our current role or become neutral
// The peer is now Neutral, so we can acquire if desired
if self.control_role == ControlRole::Controllee {
self.control_role = ControlRole::Neutral;
}
}
/// Handle a peer acquiring control.
///
/// Transitions from `Neutral` to `Controllee`.
pub fn peer_acquired_control(&mut self) -> Result<()> {
if self.control_role == ControlRole::Controller {
anyhow::bail!("Cannot transition to Controllee: currently Controller");
}
self.control_role = ControlRole::Controllee;
Ok(())
}
// =========================================================================
// Message Handling
// =========================================================================
/// Receive the next message.
///
/// Returns `None` when the channel is closed.
pub async fn recv(&mut self) -> Option<SessionMessage> {
self.msg_rx.recv().await
}
/// Try to receive a message without blocking.
pub fn try_recv(&mut self) -> Result<SessionMessage, mpsc::error::TryRecvError> {
self.msg_rx.try_recv()
}
// =========================================================================
// Outbound Messages
// =========================================================================
/// Send an attach message to a peer.
pub async fn send_attach(&self, peer: InstanceId, as_role: ControlRole) -> Result<()> {
let msg = SessionMessage::Attach {
peer: self.instance_id,
session_id: self.session_id,
as_role,
};
self.send_to(peer, msg).await
}
/// Send a detach message to the current peer.
pub async fn send_detach(&self) -> Result<()> {
let peer = self
.peer()
.ok_or_else(|| anyhow::anyhow!("Cannot detach: not attached"))?;
let msg = SessionMessage::Detach {
peer: self.instance_id,
session_id: self.session_id,
};
self.send_to(peer, msg).await
}
/// Send yield control message to peer.
pub async fn send_yield_control(&self) -> Result<()> {
let peer = self
.peer()
.ok_or_else(|| anyhow::anyhow!("Cannot yield: not attached"))?;
let msg = SessionMessage::YieldControl {
peer: self.instance_id,
session_id: self.session_id,
};
self.send_to(peer, msg).await
}
/// Send acquire control message to peer.
pub async fn send_acquire_control(&self) -> Result<()> {
let peer = self
.peer()
.ok_or_else(|| anyhow::anyhow!("Cannot acquire: not attached"))?;
let msg = SessionMessage::AcquireControl {
peer: self.instance_id,
session_id: self.session_id,
};
self.send_to(peer, msg).await
}
/// Send a message to a specific peer.
pub async fn send_to(&self, peer: InstanceId, msg: SessionMessage) -> Result<()> {
self.transport.send_session(peer, msg).await
}
/// Send a message to the currently attached peer.
pub async fn send(&self, msg: SessionMessage) -> Result<()> {
let peer = self
.peer()
.ok_or_else(|| anyhow::anyhow!("Cannot send: not attached"))?;
self.send_to(peer, msg).await
}
// =========================================================================
// State Publication
// =========================================================================
/// Publish the current state snapshot.
///
/// This updates all watchers with the new state.
pub fn publish_state(&self, g2_blocks: Vec<BlockInfo>, g3_pending: usize) {
let _ = self.state_tx.send(SessionStateSnapshot {
phase: self.phase,
control_role: self.control_role,
g2_blocks,
g3_pending,
ready_layer_range: None,
});
}
/// Publish state with layer range information.
///
/// Used for layerwise transfer where specific layers are ready.
pub fn publish_state_with_layer_range(
&self,
g2_blocks: Vec<BlockInfo>,
g3_pending: usize,
layer_range: Option<std::ops::Range<usize>>,
) {
let _ = self.state_tx.send(SessionStateSnapshot {
phase: self.phase,
control_role: self.control_role,
g2_blocks,
g3_pending,
ready_layer_range: layer_range,
});
}
/// Get a receiver for state updates.
///
/// Returns a watch receiver that will receive state snapshots
/// whenever they are published.
pub fn state_rx(&self) -> watch::Receiver<SessionStateSnapshot> {
self.state_tx.subscribe()
}
/// Get the transport for direct access (for legacy interop).
pub fn transport(&self) -> &Arc<MessageTransport> {
&self.transport
}
}
/// Channel type for sending SessionMessages to an endpoint.
pub type SessionMessageTx = mpsc::Sender<SessionMessage>;
/// Create a new session message channel.
pub fn session_message_channel(
buffer: usize,
) -> (SessionMessageTx, mpsc::Receiver<SessionMessage>) {
mpsc::channel(buffer)
}
#[cfg(test)]
mod tests {
use super::*;
use dashmap::DashMap;
fn create_test_transport() -> Arc<MessageTransport> {
Arc::new(MessageTransport::local(
Arc::new(DashMap::new()),
Arc::new(DashMap::new()),
))
}
#[test]
fn test_endpoint_initial_state() {
let (_, rx) = mpsc::channel(32);
let transport = create_test_transport();
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let endpoint = SessionEndpoint::new(session_id, instance_id, transport, rx);
assert_eq!(endpoint.session_id(), session_id);
assert_eq!(endpoint.instance_id(), instance_id);
assert_eq!(endpoint.control_role(), ControlRole::Neutral);
assert!(!endpoint.is_attached());
assert!(endpoint.peer().is_none());
assert_eq!(endpoint.phase(), SessionPhase::Searching);
assert!(!endpoint.is_complete());
}
#[test]
fn test_endpoint_attachment() {
let (_, rx) = mpsc::channel(32);
let transport = create_test_transport();
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let peer_id = InstanceId::new_v4();
let mut endpoint = SessionEndpoint::new(session_id, instance_id, transport, rx);
// Accept attachment
endpoint.accept_attachment(peer_id, ControlRole::Controllee);
assert!(endpoint.is_attached());
assert_eq!(endpoint.peer(), Some(peer_id));
assert_eq!(endpoint.control_role(), ControlRole::Controllee);
// Detach
let detached = endpoint.detach();
assert_eq!(detached, Some(peer_id));
assert!(!endpoint.is_attached());
assert_eq!(endpoint.control_role(), ControlRole::Neutral);
}
#[test]
fn test_endpoint_pre_attached() {
let (_, rx) = mpsc::channel(32);
let transport = create_test_transport();
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let peer_id = InstanceId::new_v4();
let endpoint = SessionEndpoint::new_attached(
session_id,
instance_id,
peer_id,
ControlRole::Controller,
SessionPhase::Holding,
transport,
rx,
);
assert!(endpoint.is_attached());
assert_eq!(endpoint.peer(), Some(peer_id));
assert_eq!(endpoint.control_role(), ControlRole::Controller);
assert_eq!(endpoint.phase(), SessionPhase::Holding);
}
#[test]
fn test_control_transitions() {
let (_, rx) = mpsc::channel(32);
let transport = create_test_transport();
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let peer_id = InstanceId::new_v4();
let mut endpoint = SessionEndpoint::new(session_id, instance_id, transport, rx);
endpoint.accept_attachment(peer_id, ControlRole::Controller);
// Yield control
assert!(endpoint.yield_control().is_ok());
assert_eq!(endpoint.control_role(), ControlRole::Neutral);
// Can't yield again
assert!(endpoint.yield_control().is_err());
// Acquire control
assert!(endpoint.acquire_control().is_ok());
assert_eq!(endpoint.control_role(), ControlRole::Controller);
}
#[test]
fn test_phase_transitions() {
let (_, rx) = mpsc::channel(32);
let transport = create_test_transport();
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let mut endpoint = SessionEndpoint::new(session_id, instance_id, transport, rx);
assert_eq!(endpoint.phase(), SessionPhase::Searching);
assert!(!endpoint.is_complete());
endpoint.set_phase(SessionPhase::Holding);
assert_eq!(endpoint.phase(), SessionPhase::Holding);
endpoint.set_phase(SessionPhase::Complete);
assert!(endpoint.is_complete());
}
#[tokio::test]
async fn test_state_publication() {
let (_, rx) = mpsc::channel(32);
let transport = create_test_transport();
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let endpoint = SessionEndpoint::new(session_id, instance_id, transport, rx);
let mut state_rx = endpoint.state_rx();
// Initial state
let state = state_rx.borrow().clone();
assert_eq!(state.phase, SessionPhase::Searching);
assert_eq!(state.g2_blocks.len(), 0);
// Publish new state
endpoint.publish_state(vec![], 5);
// Wait for change
state_rx.changed().await.unwrap();
let state = state_rx.borrow().clone();
assert_eq!(state.g3_pending, 5);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! SessionHandle: Unified handle for controlling a remote session.
//!
//! This is the unified replacement for `RemoteSessionHandle` that uses the
//! new session model types (`SessionPhase`, `ControlRole`, `SessionStateSnapshot`).
//!
//! Key improvements over RemoteSessionHandle:
//! - Uses unified `SessionPhase` and `ControlRole` enums
//! - Supports bidirectional control transfer (yield/acquire)
//! - Uses `SessionStateSnapshot` for state observation
//! - Same RDMA support via `ParallelWorker`
use anyhow::Result;
use std::sync::Arc;
use tokio::sync::watch;
use crate::worker::group::ParallelWorkers;
use crate::{BlockId, InstanceId, SequenceHash};
use kvbm_common::LogicalLayoutHandle;
use kvbm_physical::transfer::{TransferCompleteNotification, TransferOptions};
use super::{
BlockInfo, ControlRole, SessionId, SessionMessage, SessionPhase, SessionStateSnapshot,
transport::MessageTransport,
};
/// Handle for controlling a remote session.
///
/// Created by attaching to a remote session. Provides methods to:
/// - Query and observe session state
/// - Issue control commands (trigger staging, release blocks)
/// - Transfer control bidirectionally (yield/acquire)
/// - Pull blocks via RDMA
///
/// ## Usage
///
/// ```ignore
/// // Attach to remote session
/// let mut handle = leader.attach_session(remote_id, session_id).await?;
///
/// // Wait for initial state
/// let state = handle.wait_for_ready().await?;
///
/// // Trigger staging if needed
/// if state.g3_pending > 0 {
/// handle.trigger_staging().await?;
/// handle.wait_for_ready().await?;
/// }
///
/// // Pull blocks via RDMA
/// let notification = handle.pull_blocks_rdma(&state.g2_blocks, &local_block_ids).await?;
/// notification.await?;
///
/// // Notify remote and detach
/// handle.mark_blocks_pulled(hashes).await?;
/// handle.detach().await?;
/// ```
pub struct SessionHandle {
session_id: SessionId,
remote_instance: InstanceId,
local_instance: InstanceId,
transport: Arc<MessageTransport>,
// State observation
state_rx: watch::Receiver<SessionStateSnapshot>,
// RDMA transfer support
parallel_worker: Option<Arc<dyn ParallelWorkers>>,
}
impl SessionHandle {
/// Create a new session handle.
///
/// Note: Currently unused during incremental migration. Will be used once
/// existing session implementations are fully migrated to the new model.
#[allow(dead_code)]
pub(crate) fn new(
session_id: SessionId,
remote_instance: InstanceId,
local_instance: InstanceId,
transport: Arc<MessageTransport>,
state_rx: watch::Receiver<SessionStateSnapshot>,
) -> Self {
Self {
session_id,
remote_instance,
local_instance,
transport,
state_rx,
parallel_worker: None,
}
}
/// Add RDMA support to this handle.
pub fn with_rdma_support(mut self, parallel_worker: Arc<dyn ParallelWorkers>) -> Self {
self.parallel_worker = Some(parallel_worker);
self
}
// =========================================================================
// Identity
// =========================================================================
/// Get the session ID.
pub fn session_id(&self) -> SessionId {
self.session_id
}
/// Get the remote instance ID.
pub fn remote_instance(&self) -> InstanceId {
self.remote_instance
}
/// Get the local instance ID.
pub fn local_instance(&self) -> InstanceId {
self.local_instance
}
// =========================================================================
// State Observation
// =========================================================================
/// Get the current state snapshot (non-blocking).
pub fn current_state(&self) -> SessionStateSnapshot {
self.state_rx.borrow().clone()
}
/// Get the current phase.
pub fn phase(&self) -> SessionPhase {
self.state_rx.borrow().phase
}
/// Get the current control role of the remote session.
pub fn remote_control_role(&self) -> ControlRole {
self.state_rx.borrow().control_role
}
/// Check if state has changed since last read.
pub fn has_changed(&self) -> bool {
self.state_rx.has_changed().unwrap_or(false)
}
/// Wait for state to change.
pub async fn wait_for_change(&mut self) -> Result<SessionStateSnapshot> {
self.state_rx
.changed()
.await
.map_err(|e| anyhow::anyhow!("State channel closed: {}", e))?;
Ok(self.state_rx.borrow().clone())
}
/// Wait for the session to reach Ready phase (all blocks in G2).
pub async fn wait_for_ready(&mut self) -> Result<SessionStateSnapshot> {
self.state_rx
.wait_for(|s| s.phase == SessionPhase::Ready || s.phase.is_terminal())
.await
.map_err(|e| anyhow::anyhow!("Failed waiting for ready: {}", e))?;
let state = self.state_rx.borrow().clone();
if state.phase == SessionPhase::Failed {
anyhow::bail!("Session failed while waiting for ready");
}
Ok(state)
}
/// Wait for the session to complete.
pub async fn wait_for_complete(&mut self) -> Result<SessionStateSnapshot> {
self.state_rx
.wait_for(|s| s.phase.is_terminal())
.await
.map_err(|e| anyhow::anyhow!("Failed waiting for complete: {}", e))?;
Ok(self.state_rx.borrow().clone())
}
/// Check if the session is complete.
pub fn is_complete(&self) -> bool {
self.state_rx.borrow().phase.is_terminal()
}
/// Check if the session is ready (all blocks in G2).
pub fn is_ready(&self) -> bool {
self.state_rx.borrow().phase == SessionPhase::Ready
}
/// Get G2 blocks from current state.
pub fn get_g2_blocks(&self) -> Vec<BlockInfo> {
self.state_rx.borrow().g2_blocks.clone()
}
/// Get count of G3 blocks pending staging.
pub fn g3_pending_count(&self) -> usize {
self.state_rx.borrow().g3_pending
}
/// Get the layer range that is ready for transfer.
///
/// Returns `None` if all layers are ready or layerwise tracking is not active.
/// Returns `Some(range)` if only specific layers are ready.
pub fn ready_layer_range(&self) -> Option<std::ops::Range<usize>> {
self.state_rx.borrow().ready_layer_range.clone()
}
// =========================================================================
// Control Commands
// =========================================================================
/// Trigger G3→G2 staging on the remote session.
///
/// Idempotent - no-op if already staging or staged.
pub async fn trigger_staging(&self) -> Result<()> {
let msg = SessionMessage::TriggerStaging {
session_id: self.session_id,
};
self.transport.send_session(self.remote_instance, msg).await
}
/// Notify remote that blocks have been pulled.
///
/// Call after successfully pulling blocks via RDMA.
pub async fn mark_blocks_pulled(&self, pulled_hashes: Vec<SequenceHash>) -> Result<()> {
let msg = SessionMessage::BlocksPulled {
session_id: self.session_id,
pulled_hashes,
};
self.transport.send_session(self.remote_instance, msg).await
}
/// Detach from the session.
///
/// Consumes the handle. The remote session will release remaining blocks.
pub async fn detach(self) -> Result<()> {
let msg = SessionMessage::Detach {
peer: self.local_instance,
session_id: self.session_id,
};
self.transport.send_session(self.remote_instance, msg).await
}
// =========================================================================
// Control Transfer (Bidirectional)
// =========================================================================
/// Yield control to the remote peer.
///
/// After yielding, this handle transitions to Neutral and the remote
/// can acquire control if desired.
pub async fn yield_control(&self) -> Result<()> {
let msg = SessionMessage::YieldControl {
peer: self.local_instance,
session_id: self.session_id,
};
self.transport.send_session(self.remote_instance, msg).await
}
/// Attempt to acquire control from the remote peer.
///
/// Valid when remote is in Neutral state.
pub async fn acquire_control(&self) -> Result<()> {
let msg = SessionMessage::AcquireControl {
peer: self.local_instance,
session_id: self.session_id,
};
self.transport.send_session(self.remote_instance, msg).await
}
// =========================================================================
// RDMA Transfer Methods
// =========================================================================
/// Check if remote metadata has been imported.
pub fn has_remote_metadata(&self) -> bool {
self.parallel_worker
.as_ref()
.map(|pw| pw.has_remote_metadata(self.remote_instance))
.unwrap_or(false)
}
/// Ensure remote metadata is imported (lazy loading).
pub async fn ensure_metadata_imported(&mut self) -> Result<()> {
let parallel_worker = self
.parallel_worker
.as_ref()
.ok_or_else(|| anyhow::anyhow!("RDMA support not configured"))?;
if parallel_worker.has_remote_metadata(self.remote_instance) {
return Ok(());
}
let remote_metadata = self
.transport
.request_metadata(self.remote_instance)
.await?;
parallel_worker
.connect_remote(self.remote_instance, remote_metadata)?
.await?;
Ok(())
}
/// Pull blocks from remote G2 to local G2 via RDMA.
///
/// This method:
/// 1. Ensures remote metadata is imported
/// 2. Executes SPMD-aware transfer (worker N pulls from remote worker N)
/// 3. Returns notification that completes when all transfers done
pub async fn pull_blocks_rdma(
&mut self,
blocks: &[BlockInfo],
local_dst_block_ids: &[BlockId],
) -> Result<TransferCompleteNotification> {
self.ensure_metadata_imported().await?;
self.pull_blocks_rdma_explicit(blocks, local_dst_block_ids)
}
/// Pull blocks with explicit metadata pre-import.
///
/// Caller must have already ensured metadata is imported.
pub fn pull_blocks_rdma_explicit(
&self,
blocks: &[BlockInfo],
local_dst_block_ids: &[BlockId],
) -> Result<TransferCompleteNotification> {
let parallel_worker = self
.parallel_worker
.as_ref()
.ok_or_else(|| anyhow::anyhow!("RDMA support not configured"))?;
if !parallel_worker.has_remote_metadata(self.remote_instance) {
anyhow::bail!(
"Remote metadata not imported for instance {}",
self.remote_instance
);
}
if blocks.len() != local_dst_block_ids.len() {
anyhow::bail!(
"Block count mismatch: source={}, destination={}",
blocks.len(),
local_dst_block_ids.len()
);
}
let src_block_ids: Vec<BlockId> = blocks.iter().map(|b| b.block_id).collect();
parallel_worker.execute_remote_onboard_for_instance(
self.remote_instance,
LogicalLayoutHandle::G2,
src_block_ids,
LogicalLayoutHandle::G2,
local_dst_block_ids.to_vec().into(),
Default::default(),
)
}
/// Pull blocks from remote G2 to local G2 via RDMA with transfer options.
///
/// This method allows specifying transfer options like layer range for
/// layerwise transfer. Use this when you only want to pull specific layers.
///
/// # Example
/// ```ignore
/// // Pull only layer 0
/// let notification = handle.pull_blocks_rdma_with_options(
/// &state.g2_blocks,
/// &local_block_ids,
/// TransferOptions::builder().layer_range(0..1).build(),
/// ).await?;
/// notification.await?;
/// ```
pub async fn pull_blocks_rdma_with_options(
&mut self,
blocks: &[BlockInfo],
local_dst_block_ids: &[BlockId],
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
self.ensure_metadata_imported().await?;
self.pull_blocks_rdma_with_options_explicit(blocks, local_dst_block_ids, options)
}
/// Pull blocks with options and explicit metadata pre-import.
///
/// Caller must have already ensured metadata is imported.
pub fn pull_blocks_rdma_with_options_explicit(
&self,
blocks: &[BlockInfo],
local_dst_block_ids: &[BlockId],
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
let parallel_worker = self
.parallel_worker
.as_ref()
.ok_or_else(|| anyhow::anyhow!("RDMA support not configured"))?;
if !parallel_worker.has_remote_metadata(self.remote_instance) {
anyhow::bail!(
"Remote metadata not imported for instance {}",
self.remote_instance
);
}
if blocks.len() != local_dst_block_ids.len() {
anyhow::bail!(
"Block count mismatch: source={}, destination={}",
blocks.len(),
local_dst_block_ids.len()
);
}
let src_block_ids: Vec<BlockId> = blocks.iter().map(|b| b.block_id).collect();
parallel_worker.execute_remote_onboard_for_instance(
self.remote_instance,
LogicalLayoutHandle::G2,
src_block_ids,
LogicalLayoutHandle::G2,
local_dst_block_ids.to_vec().into(),
options,
)
}
}
/// Sender for state updates to SessionHandle.
pub struct SessionHandleStateTx {
tx: watch::Sender<SessionStateSnapshot>,
}
impl SessionHandleStateTx {
/// Create a new state sender.
pub fn new(tx: watch::Sender<SessionStateSnapshot>) -> Self {
Self { tx }
}
/// Update state from a full snapshot.
pub fn update(&self, state: SessionStateSnapshot) {
let _ = self.tx.send(state);
}
/// Update phase only.
pub fn set_phase(&self, phase: SessionPhase) {
self.tx.send_modify(|state| {
state.phase = phase;
});
}
/// Update G2 blocks.
pub fn set_g2_blocks(&self, blocks: Vec<BlockInfo>) {
self.tx.send_modify(|state| {
state.g2_blocks = blocks;
});
}
/// Add newly staged blocks.
///
/// # Arguments
/// * `staged` - Blocks that have been staged
/// * `g3_remaining` - Count of G3 blocks still pending
/// * `layer_range` - Optional layer range that is ready for transfer
pub fn add_staged_blocks(
&self,
staged: Vec<BlockInfo>,
g3_remaining: usize,
layer_range: Option<std::ops::Range<usize>>,
) {
self.tx.send_modify(|state| {
state.g2_blocks.extend(staged);
state.g3_pending = g3_remaining;
state.ready_layer_range = layer_range;
if g3_remaining == 0 && state.ready_layer_range.is_none() {
// All blocks staged and no layer tracking = fully ready
state.phase = SessionPhase::Ready;
}
});
}
/// Set error/failed state.
pub fn set_failed(&self) {
self.tx.send_modify(|state| {
state.phase = SessionPhase::Failed;
});
}
}
/// Create a new session handle state channel.
pub fn session_handle_state_channel()
-> (SessionHandleStateTx, watch::Receiver<SessionStateSnapshot>) {
let initial = SessionStateSnapshot {
phase: SessionPhase::Searching,
control_role: ControlRole::Controllee,
g2_blocks: Vec::new(),
g3_pending: 0,
ready_layer_range: None,
};
let (tx, rx) = watch::channel(initial);
(SessionHandleStateTx::new(tx), rx)
}
#[cfg(test)]
mod tests {
use super::*;
use dashmap::DashMap;
fn create_test_transport() -> Arc<MessageTransport> {
Arc::new(MessageTransport::local(
Arc::new(DashMap::new()),
Arc::new(DashMap::new()),
))
}
#[test]
fn test_session_handle_state_channel() {
let (tx, rx) = session_handle_state_channel();
// Initial state
let state = rx.borrow().clone();
assert_eq!(state.phase, SessionPhase::Searching);
assert_eq!(state.control_role, ControlRole::Controllee);
assert!(state.g2_blocks.is_empty());
// Update state
tx.set_phase(SessionPhase::Ready);
let state = rx.borrow().clone();
assert_eq!(state.phase, SessionPhase::Ready);
}
#[test]
fn test_session_handle_creation() {
let (_, rx) = session_handle_state_channel();
let transport = create_test_transport();
let session_id = SessionId::new_v4();
let remote_id = InstanceId::new_v4();
let local_id = InstanceId::new_v4();
let handle = SessionHandle::new(session_id, remote_id, local_id, transport, rx);
assert_eq!(handle.session_id(), session_id);
assert_eq!(handle.remote_instance(), remote_id);
assert_eq!(handle.local_instance(), local_id);
assert_eq!(handle.phase(), SessionPhase::Searching);
assert!(!handle.is_ready());
assert!(!handle.is_complete());
assert!(!handle.has_remote_metadata());
}
#[tokio::test]
async fn test_wait_for_ready() {
let (tx, rx) = session_handle_state_channel();
let transport = create_test_transport();
let session_id = SessionId::new_v4();
let mut handle = SessionHandle::new(
session_id,
InstanceId::new_v4(),
InstanceId::new_v4(),
transport,
rx,
);
// Spawn task to update state
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
tx.set_phase(SessionPhase::Ready);
});
let state = handle.wait_for_ready().await.unwrap();
assert_eq!(state.phase, SessionPhase::Ready);
}
#[test]
fn test_add_staged_blocks() {
let (tx, rx) = session_handle_state_channel();
// Set initial g3 pending
tx.update(SessionStateSnapshot {
phase: SessionPhase::Staging,
control_role: ControlRole::Controllee,
g2_blocks: Vec::new(),
g3_pending: 5,
ready_layer_range: None,
});
let state = rx.borrow().clone();
assert_eq!(state.g3_pending, 5);
assert!(state.g2_blocks.is_empty());
// Add staged blocks with remaining = 0
let block = BlockInfo {
block_id: 42,
sequence_hash: crate::SequenceHash::new(1, None, 100),
layout_handle: kvbm_physical::manager::LayoutHandle::new(0, 1),
};
tx.add_staged_blocks(vec![block], 0, None);
let state = rx.borrow().clone();
assert_eq!(state.g2_blocks.len(), 1);
assert_eq!(state.g3_pending, 0);
// No layer range + g3_remaining == 0 → Ready
assert_eq!(state.phase, SessionPhase::Ready);
}
#[test]
fn test_set_failed() {
let (tx, rx) = session_handle_state_channel();
// Initially Searching
assert_eq!(rx.borrow().phase, SessionPhase::Searching);
tx.set_failed();
assert_eq!(rx.borrow().phase, SessionPhase::Failed);
}
#[tokio::test]
async fn test_wait_for_complete() {
let (tx, rx) = session_handle_state_channel();
let transport = create_test_transport();
let session_id = SessionId::new_v4();
let mut handle = SessionHandle::new(
session_id,
InstanceId::new_v4(),
InstanceId::new_v4(),
transport,
rx,
);
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
tx.set_phase(SessionPhase::Complete);
});
let state = handle.wait_for_complete().await.unwrap();
assert_eq!(state.phase, SessionPhase::Complete);
assert!(handle.is_complete());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use tokio::sync::{Mutex, mpsc, watch};
use tokio::task::JoinHandle;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::{
BlockId, G2, G3, InstanceId, SequenceHash, object::ObjectBlockOps,
worker::group::ParallelWorkers,
};
use kvbm_common::LogicalLayoutHandle;
use kvbm_logical::{blocks::ImmutableBlock, manager::BlockManager};
use kvbm_physical::transfer::TransferOptions;
use super::staging;
use super::{
super::{OnboardingStatus, SessionControl, StagingMode},
BlockHolder, SessionId,
messages::OnboardMessage,
transport::MessageTransport,
};
/// Validate that sequence hashes have contiguous positions (X, X+1, X+2, ...).
///
/// The positions don't need to start at 0, but they must be monotonically
/// increasing with no gaps.
fn validate_contiguous_positions(seq_hashes: &[SequenceHash]) -> Result<()> {
if seq_hashes.len() <= 1 {
return Ok(());
}
// Collect and sort positions
let mut positions: Vec<u64> = seq_hashes.iter().map(|h| h.position()).collect();
positions.sort();
// Check monotonically increasing with no holes: X, X+1, X+2, ...
for window in positions.windows(2) {
if window[1] != window[0] + 1 {
anyhow::bail!(
"Position gap detected in remote blocks: {} -> {} (expected {}). \
This indicates a block ordering bug.",
window[0],
window[1],
window[0] + 1
);
}
}
Ok(())
}
/// Tracks G4/object storage search state for parallel search.
///
/// This state is used when G4 search runs in parallel with G2/G3 search.
/// The first responder (local, remote, or G4) wins for each hash.
#[derive(Default)]
struct G4SearchState {
/// Hashes won by G4 in the first-responder-wins race
won_hashes: HashSet<SequenceHash>,
/// Hashes currently pending load (get_blocks in progress)
pending_load: HashSet<SequenceHash>,
/// Hashes that failed to load with error messages
failed_hashes: HashMap<SequenceHash, String>,
/// Block IDs allocated for G4→G2 loading (sequence_hash → block_id)
allocated_blocks: HashMap<SequenceHash, BlockId>,
}
impl G4SearchState {
fn new() -> Self {
Self::default()
}
/// Clear all state.
#[expect(dead_code)]
fn clear(&mut self) {
self.won_hashes.clear();
self.pending_load.clear();
self.failed_hashes.clear();
self.allocated_blocks.clear();
}
}
/// Initiator-side session for coordinating distributed block search.
///
/// Supports three staging modes:
/// - Hold: Find and hold blocks (G2+G3), no staging
/// - Prepare: Stage G3→G2 everywhere, keep session alive
/// - Full: Stage G3→G2 + RDMA pull remote G2→local G2, session completes
pub struct InitiatorSession {
session_id: SessionId,
instance_id: InstanceId,
mode: StagingMode,
g2_manager: Arc<BlockManager<G2>>,
g3_manager: Option<Arc<BlockManager<G3>>>,
parallel_worker: Option<Arc<dyn ParallelWorkers>>,
transport: Arc<MessageTransport>,
status_tx: watch::Sender<OnboardingStatus>,
// Held blocks from local search using BlockHolder for RAII semantics
local_g2_blocks: BlockHolder<G2>,
local_g3_blocks: BlockHolder<G3>,
// Track remote blocks by tier
remote_g2_blocks: HashMap<InstanceId, Vec<BlockId>>, // G2: track block IDs
remote_g2_hashes: HashMap<InstanceId, Vec<SequenceHash>>, // G2: track sequence hashes (parallel to block_ids)
remote_g3_blocks: HashMap<InstanceId, Vec<SequenceHash>>, // G3: track sequence hashes
// Shared with FindMatchesResult for block access
all_g2_blocks: Arc<Mutex<Option<Vec<ImmutableBlock<G2>>>>>,
// Control channel for deferred operations
control_rx: mpsc::Receiver<SessionControl>,
// G4/Object storage fields
/// Object storage client for G4 search and load (leader-initiated)
object_client: Option<Arc<dyn ObjectBlockOps>>,
/// G4 search state tracking won hashes, pending loads, and failures
g4_state: G4SearchState,
/// Channel for receiving G4 search/load results
g4_rx: Option<mpsc::Receiver<OnboardMessage>>,
/// Handle for G4 search task (for cancellation on drop)
#[allow(dead_code)]
g4_task_handle: Option<JoinHandle<()>>,
}
impl InitiatorSession {
/// Create a new initiator session.
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
session_id: SessionId,
instance_id: InstanceId,
mode: StagingMode,
g2_manager: Arc<BlockManager<G2>>,
g3_manager: Option<Arc<BlockManager<G3>>>,
parallel_worker: Option<Arc<dyn ParallelWorkers>>,
transport: Arc<MessageTransport>,
status_tx: watch::Sender<OnboardingStatus>,
all_g2_blocks: Arc<Mutex<Option<Vec<ImmutableBlock<G2>>>>>,
control_rx: mpsc::Receiver<SessionControl>,
object_client: Option<Arc<dyn ObjectBlockOps>>,
) -> Self {
Self {
session_id,
instance_id,
mode,
g2_manager,
g3_manager,
parallel_worker,
transport,
status_tx,
local_g2_blocks: BlockHolder::empty(),
local_g3_blocks: BlockHolder::empty(),
remote_g2_blocks: HashMap::new(),
remote_g2_hashes: HashMap::new(),
remote_g3_blocks: HashMap::new(),
all_g2_blocks,
control_rx,
object_client,
g4_state: G4SearchState::new(),
g4_rx: None,
g4_task_handle: None,
}
}
/// Run the initiator session task.
pub async fn run(
mut self,
mut rx: mpsc::Receiver<OnboardMessage>,
remote_leaders: Vec<InstanceId>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<()> {
tracing::debug!(
session_id = %self.session_id,
mode = ?self.mode,
num_hashes = sequence_hashes.len(),
num_remotes = remote_leaders.len(),
"Starting initiator session"
);
// Phase 1: Search (local G2 and G3, then remote if needed)
self.search_phase(&mut rx, &remote_leaders, &sequence_hashes)
.await?;
// Phase 1.5: Apply find policy (first-hole detection)
// Trims results to first contiguous sequence from start
self.apply_find_policy(&sequence_hashes).await?;
tracing::debug!(
session_id = %self.session_id,
"search_phase complete, entering mode handler"
);
// Phase 2: Staging based on mode
match self.mode {
StagingMode::Hold => {
tracing::debug!(session_id = %self.session_id, "Calling hold_mode()");
self.hold_mode().await?;
// Wait for control commands or shutdown
self.await_commands(rx).await?;
}
StagingMode::Prepare => {
self.prepare_mode(&mut rx).await?;
// Wait for pull command or shutdown
self.await_commands(rx).await?;
}
StagingMode::Full => {
self.full_mode(&mut rx).await?;
// Completes and exits
}
}
Ok(())
}
/// Phase 1: Search for blocks locally and remotely.
async fn search_phase(
&mut self,
rx: &mut mpsc::Receiver<OnboardMessage>,
remote_leaders: &[InstanceId],
sequence_hashes: &[SequenceHash],
) -> Result<()> {
// Local G2 search
self.local_g2_blocks = BlockHolder::new(self.g2_manager.match_blocks(sequence_hashes));
let mut matched_hashes: HashSet<SequenceHash> =
self.local_g2_blocks.sequence_hashes().into_iter().collect();
// Local G3 search
if let Some(ref g3_manager) = self.g3_manager {
let remaining: Vec<_> = sequence_hashes
.iter()
.filter(|h| !matched_hashes.contains(h))
.copied()
.collect();
if !remaining.is_empty() {
self.local_g3_blocks = BlockHolder::new(g3_manager.match_blocks(&remaining));
for hash in self.local_g3_blocks.sequence_hashes() {
matched_hashes.insert(hash);
}
}
}
// Check if remote/G4 search needed
// Continue if: not all matched locally AND (remote leaders exist OR object_client configured)
let has_object_client = self.object_client.is_some();
if matched_hashes.len() == sequence_hashes.len()
|| (remote_leaders.is_empty() && !has_object_client)
{
return Ok(());
}
// Remote search
let remaining_hashes: Vec<_> = sequence_hashes
.iter()
.filter(|h| !matched_hashes.contains(h))
.copied()
.collect();
if remaining_hashes.is_empty() {
return Ok(());
}
self.status_tx.send(OnboardingStatus::Searching).ok();
// Send CreateSession to all remotes FIRST
for remote in remote_leaders {
let msg = OnboardMessage::CreateSession {
requester: self.instance_id,
session_id: self.session_id,
sequence_hashes: remaining_hashes.clone(),
};
self.transport.send(*remote, msg).await?;
}
// Then spawn G4 search task if object storage is configured and parallel_worker is available
// We use parallel_worker.has_blocks() which fans out to workers with rank-prefixed keys
let g4_tx = if self.object_client.is_some() && self.parallel_worker.is_some() {
let (tx, rx) = mpsc::channel(16);
self.g4_rx = Some(rx);
// Spawn G4 search - searches the same remaining hashes as remote search
let handle = self.spawn_g4_search(remaining_hashes.clone(), tx.clone());
self.g4_task_handle = Some(handle);
Some(tx)
} else {
None
};
// Process search responses (including G4 if configured)
self.process_search_responses(rx, remote_leaders, &mut matched_hashes, g4_tx)
.await?;
Ok(())
}
/// Process G2Results, G3Results, and G4 results from responders.
///
/// Uses `tokio::select!` to handle both remote messages and G4 results
/// in parallel, applying first-responder-wins logic across all tiers.
async fn process_search_responses(
&mut self,
rx: &mut mpsc::Receiver<OnboardMessage>,
remote_leaders: &[InstanceId],
matched_hashes: &mut HashSet<SequenceHash>,
g4_tx: Option<mpsc::Sender<OnboardMessage>>,
) -> Result<()> {
let mut pending_g2_responses = remote_leaders.len();
let mut pending_g3_responses: HashSet<InstanceId> =
remote_leaders.iter().copied().collect();
let mut pending_search_complete: HashSet<InstanceId> =
remote_leaders.iter().copied().collect();
let mut pending_acknowledgments: HashSet<InstanceId> = HashSet::new();
// G4 state tracking
let mut pending_g4_search = self.g4_rx.is_some();
let mut pending_g4_load = false;
// Helper to check if all responses are complete
let is_complete = |pending_g2: usize,
pending_g3: &HashSet<InstanceId>,
pending_ack: &HashSet<InstanceId>,
pending_search: &HashSet<InstanceId>,
pending_g4_s: bool,
pending_g4_l: bool| {
pending_g2 == 0
&& pending_g3.is_empty()
&& pending_ack.is_empty()
&& pending_search.is_empty()
&& !pending_g4_s
&& !pending_g4_l
};
loop {
// Check completion before waiting for more messages
if is_complete(
pending_g2_responses,
&pending_g3_responses,
&pending_acknowledgments,
&pending_search_complete,
pending_g4_search,
pending_g4_load,
) {
tracing::debug!(
session_id = %self.session_id,
"All responses received (including G4), exiting search_phase"
);
break;
}
tokio::select! {
// Handle G4 messages from internal channel
g4_msg = async {
if let Some(ref mut g4_rx) = self.g4_rx {
g4_rx.recv().await
} else {
std::future::pending::<Option<OnboardMessage>>().await
}
} => {
let Some(msg) = g4_msg else {
// Channel closed unexpectedly
pending_g4_search = false;
pending_g4_load = false;
continue;
};
tracing::debug!(
session_id = %self.session_id,
msg = msg.variant_name(),
"process_search_responses received G4"
);
match msg {
OnboardMessage::G4Results { found_hashes, .. } => {
pending_g4_search = false;
// Process G4 results with first-responder-wins
let won_hashes = self.process_g4_results(found_hashes, matched_hashes);
// If G4 won any hashes, start loading them
if !won_hashes.is_empty()
&& let Some(ref tx) = g4_tx {
self.load_g4_blocks(won_hashes, tx.clone()).await?;
pending_g4_load = true;
}
}
OnboardMessage::G4LoadComplete { success, failures, blocks, .. } => {
self.handle_g4_load_complete(success, failures, blocks);
pending_g4_load = false;
}
_ => {}
}
}
// Handle remote messages
remote_msg = rx.recv() => {
let Some(msg) = remote_msg else {
// Channel closed - exit loop
break;
};
tracing::debug!(
session_id = %self.session_id,
msg = msg.variant_name(),
"process_search_responses received"
);
match msg {
OnboardMessage::G2Results {
responder,
sequence_hashes,
block_ids,
..
} => {
tracing::debug!(
session_id = %self.session_id,
responder = %responder,
num_hashes = sequence_hashes.len(),
"Processing G2Results"
);
// First-responder-wins logic using sequence hashes
let mut hold_hashes = Vec::new();
let mut drop_hashes = Vec::new();
for (seq_hash, block_id) in sequence_hashes.iter().zip(block_ids.iter()) {
if matched_hashes.insert(*seq_hash) {
hold_hashes.push(*seq_hash);
self.remote_g2_blocks
.entry(responder)
.or_default()
.push(*block_id);
// Track sequence hash in parallel for block registration after RDMA pull
self.remote_g2_hashes
.entry(responder)
.or_default()
.push(*seq_hash);
} else {
drop_hashes.push(*seq_hash);
}
}
// Send HoldBlocks decision
self.transport
.send(
responder,
OnboardMessage::HoldBlocks {
requester: self.instance_id,
session_id: self.session_id,
hold_hashes,
drop_hashes,
},
)
.await?;
pending_acknowledgments.insert(responder);
pending_g2_responses -= 1;
}
OnboardMessage::G3Results {
responder,
sequence_hashes,
..
} => {
// Store G3 sequence hashes for later staging
for seq_hash in sequence_hashes {
if matched_hashes.insert(seq_hash) {
self.remote_g3_blocks
.entry(responder)
.or_default()
.push(seq_hash);
}
}
pending_g3_responses.remove(&responder);
}
OnboardMessage::SearchComplete { responder, .. } => {
pending_search_complete.remove(&responder);
// SearchComplete means responder is done with G2 AND G3 search
pending_g3_responses.remove(&responder);
tracing::debug!(
session_id = %self.session_id,
responder = %responder,
g2_pending = pending_g2_responses,
g3_pending = pending_g3_responses.len(),
ack_pending = pending_acknowledgments.len(),
search_pending = pending_search_complete.len(),
g4_search = pending_g4_search,
g4_load = pending_g4_load,
"SearchComplete"
);
}
OnboardMessage::Acknowledged { responder, .. } => {
pending_acknowledgments.remove(&responder);
}
_ => {}
}
}
}
}
Ok(())
}
/// Apply "first hole" policy: trim results to first contiguous sequence.
///
/// This implements the policy where we only return blocks from position 0
/// up to (but not including) the first missing block. Any blocks after the
/// first hole are released.
///
/// # Arguments
/// * `sequence_hashes` - The original query hashes in order (position 0 to N)
async fn apply_find_policy(&mut self, sequence_hashes: &[SequenceHash]) -> Result<()> {
// Build set of all matched hashes (local + remote)
let mut matched_hashes: HashSet<SequenceHash> = HashSet::new();
// Local G2 blocks
for hash in self.local_g2_blocks.sequence_hashes() {
matched_hashes.insert(hash);
}
// Local G3 blocks
for hash in self.local_g3_blocks.sequence_hashes() {
matched_hashes.insert(hash);
}
// Remote G2 hashes
for hashes in self.remote_g2_hashes.values() {
for hash in hashes {
matched_hashes.insert(*hash);
}
}
// Remote G3 hashes
for hashes in self.remote_g3_blocks.values() {
for hash in hashes {
matched_hashes.insert(*hash);
}
}
// G4 won hashes (blocks successfully loaded from object storage)
for hash in &self.g4_state.won_hashes {
matched_hashes.insert(*hash);
}
// Find the first hole: count contiguous matches from start
let mut keep_count = 0;
for hash in sequence_hashes {
if matched_hashes.contains(hash) {
keep_count += 1;
} else {
// First hole found - stop here
break;
}
}
// If all hashes matched or first hole is at position 0, nothing to trim
if keep_count == sequence_hashes.len() || keep_count == matched_hashes.len() {
tracing::debug!(
session_id = %self.session_id,
matched = keep_count,
total = sequence_hashes.len(),
"apply_find_policy: no trimming needed"
);
return Ok(());
}
// Get the hashes to keep
let keep_hashes: Vec<SequenceHash> = sequence_hashes[..keep_count].to_vec();
let keep_set: HashSet<&SequenceHash> = keep_hashes.iter().collect();
tracing::debug!(
session_id = %self.session_id,
from = matched_hashes.len(),
to = keep_count,
first_hole = keep_count,
"apply_find_policy: trimming blocks"
);
// Filter local blocks
self.local_g2_blocks.retain(&keep_hashes);
self.local_g3_blocks.retain(&keep_hashes);
// Filter remote G2 block tracking and send ReleaseBlocks messages
for (remote_instance, block_ids) in &mut self.remote_g2_blocks {
let hashes = self.remote_g2_hashes.get_mut(remote_instance);
if let Some(hashes) = hashes {
// Find indices of blocks to release
let mut release_indices = Vec::new();
for (i, hash) in hashes.iter().enumerate() {
if !keep_set.contains(hash) {
release_indices.push(i);
}
}
// Collect hashes to release for ReleaseBlocks message
let release_hashes: Vec<SequenceHash> =
release_indices.iter().map(|&i| hashes[i]).collect();
// Remove from tracking (reverse order to preserve indices)
for i in release_indices.into_iter().rev() {
hashes.remove(i);
block_ids.remove(i);
}
// Send ReleaseBlocks message if any blocks need releasing
if !release_hashes.is_empty() {
tracing::debug!(
session_id = %self.session_id,
count = release_hashes.len(),
instance = %remote_instance,
"Releasing G2 blocks beyond first hole"
);
self.transport
.send(
*remote_instance,
OnboardMessage::ReleaseBlocks {
requester: self.instance_id,
session_id: self.session_id,
release_hashes,
},
)
.await?;
}
}
}
// Filter remote G3 block tracking and send ReleaseBlocks messages
for (remote_instance, hashes) in &mut self.remote_g3_blocks {
// Find hashes to release
let release_hashes: Vec<SequenceHash> = hashes
.iter()
.filter(|h| !keep_set.contains(h))
.copied()
.collect();
// Remove from tracking
hashes.retain(|h| keep_set.contains(h));
// Send ReleaseBlocks message if any blocks need releasing
if !release_hashes.is_empty() {
tracing::debug!(
session_id = %self.session_id,
count = release_hashes.len(),
instance = %remote_instance,
"Releasing G3 blocks beyond first hole"
);
self.transport
.send(
*remote_instance,
OnboardMessage::ReleaseBlocks {
requester: self.instance_id,
session_id: self.session_id,
release_hashes,
},
)
.await?;
}
}
// Filter G4 state - release allocated blocks and remove from tracking for hashes beyond first hole
let g4_release_hashes: Vec<SequenceHash> = self
.g4_state
.won_hashes
.iter()
.filter(|h| !keep_set.contains(h))
.copied()
.collect();
if !g4_release_hashes.is_empty() {
tracing::debug!(
session_id = %self.session_id,
count = g4_release_hashes.len(),
"Releasing G4 blocks beyond first hole"
);
for hash in &g4_release_hashes {
// Remove from won_hashes
self.g4_state.won_hashes.remove(hash);
// Remove from pending_load (if still loading)
self.g4_state.pending_load.remove(hash);
// Remove allocated block (will be deallocated when dropped)
self.g4_state.allocated_blocks.remove(hash);
}
}
Ok(())
}
/// Hold mode: Just hold blocks without staging.
async fn hold_mode(&mut self) -> Result<()> {
let local_g2 = self.local_g2_blocks.count();
let local_g3 = self.local_g3_blocks.count();
let remote_g2: usize = self.remote_g2_blocks.values().map(|v| v.len()).sum();
let remote_g3: usize = self.remote_g3_blocks.values().map(|v| v.len()).sum();
// G4 state
let pending_g4 = self.g4_state.pending_load.len();
let loaded_g4 = self.g4_state.won_hashes.len();
let failed_g4 = self.g4_state.failed_hashes.len();
tracing::debug!(
session_id = %self.session_id,
local_g2,
local_g3,
remote_g2,
remote_g3,
pending_g4,
loaded_g4,
failed_g4,
"hold_mode"
);
self.status_tx
.send(OnboardingStatus::Holding {
local_g2,
local_g3,
remote_g2,
remote_g3,
pending_g4,
loaded_g4,
failed_g4,
})
.ok();
tracing::debug!(session_id = %self.session_id, "Sent Holding status");
Ok(())
}
/// Send StageBlocks to all remotes with G3 blocks and wait for BlocksReady responses.
///
/// After sending StageBlocks, waits for each remote to respond with BlocksReady,
/// which updates `remote_g2_blocks` and `remote_g2_hashes` with the newly staged blocks.
async fn send_stage_and_wait_for_ready(
&mut self,
rx: &mut mpsc::Receiver<OnboardMessage>,
) -> Result<()> {
if self.remote_g3_blocks.is_empty() {
return Ok(());
}
// Send StageBlocks to remotes for their G3 sequence hashes
let remotes_with_g3: Vec<(InstanceId, Vec<SequenceHash>)> = self
.remote_g3_blocks
.iter()
.map(|(k, v)| (*k, v.clone()))
.collect();
for (remote, stage_hashes) in &remotes_with_g3 {
self.transport
.send(
*remote,
OnboardMessage::StageBlocks {
requester: self.instance_id,
session_id: self.session_id,
stage_hashes: stage_hashes.clone(),
},
)
.await?;
}
// Wait for BlocksReady from all remotes that had G3 blocks
let mut pending: HashSet<InstanceId> = remotes_with_g3.iter().map(|(k, _)| *k).collect();
while !pending.is_empty() {
match rx.recv().await {
Some(OnboardMessage::BlocksReady {
responder,
sequence_hashes,
block_ids,
..
}) => {
tracing::debug!(
session_id = %self.session_id,
responder = %responder,
count = block_ids.len(),
"Received BlocksReady"
);
self.remote_g2_blocks
.entry(responder)
.or_default()
.extend(block_ids);
self.remote_g2_hashes
.entry(responder)
.or_default()
.extend(sequence_hashes);
pending.remove(&responder);
}
Some(other) => {
tracing::warn!(
session_id = %self.session_id,
msg = other.variant_name(),
"Unexpected message while waiting for BlocksReady"
);
}
None => {
tracing::warn!(
session_id = %self.session_id,
"Channel closed while waiting for BlocksReady"
);
break;
}
}
}
Ok(())
}
/// Prepare mode: Stage all G3→G2 but keep session alive.
async fn prepare_mode(&mut self, rx: &mut mpsc::Receiver<OnboardMessage>) -> Result<()> {
// Stage local G3→G2
self.stage_local_g3_to_g2().await?;
// Send StageBlocks to remotes and wait for BlocksReady
self.send_stage_and_wait_for_ready(rx).await?;
let local_g2 = self.local_g2_blocks.count();
let remote_g2: usize = self.remote_g2_blocks.values().map(|v| v.len()).sum();
self.status_tx
.send(OnboardingStatus::Prepared {
local_g2,
remote_g2,
})
.ok();
Ok(())
}
/// Full mode: Stage G3→G2 + pull remote G2→local G2.
async fn full_mode(&mut self, rx: &mut mpsc::Receiver<OnboardMessage>) -> Result<()> {
// Stage local G3→G2
self.stage_local_g3_to_g2().await?;
// Send StageBlocks to remotes and wait for BlocksReady before pulling
self.send_stage_and_wait_for_ready(rx).await?;
// Pull remote G2→local G2 via RDMA (both original G2 and newly staged from G3)
self.pull_remote_blocks().await?;
// Consolidate all blocks
self.consolidate_blocks().await;
// Send CloseSession to all remotes
let all_remotes: HashSet<InstanceId> = self
.remote_g2_blocks
.keys()
.chain(self.remote_g3_blocks.keys())
.copied()
.collect();
for remote in all_remotes {
self.transport
.send(
remote,
OnboardMessage::CloseSession {
requester: self.instance_id,
session_id: self.session_id,
},
)
.await?;
}
Ok(())
}
/// Stage local G3→G2.
async fn stage_local_g3_to_g2(&mut self) -> Result<()> {
if self.local_g3_blocks.is_empty() {
return Ok(());
}
let parallel_worker = self
.parallel_worker
.as_ref()
.ok_or_else(|| anyhow::anyhow!("ParallelWorker required for G3→G2 staging"))?;
let result =
staging::stage_g3_to_g2(&self.local_g3_blocks, &self.g2_manager, &**parallel_worker)
.await?;
let _ = self.local_g3_blocks.take_all();
self.local_g2_blocks.extend(result.new_g2_blocks);
Ok(())
}
/// Pull remote G2→local G2 via RDMA.
///
/// This method:
/// 1. Imports remote metadata for each instance (if not already imported)
/// 2. Allocates local G2 blocks as destinations
/// 3. Executes RDMA transfer via worker
/// 4. Registers pulled blocks with their sequence hashes
async fn pull_remote_blocks(&mut self) -> Result<()> {
let parallel_worker = self
.parallel_worker
.as_ref()
.ok_or_else(|| anyhow::anyhow!("ParallelWorker required for RDMA pull"))?;
// Process each remote instance that has G2 blocks to pull
for (remote_instance, block_ids) in self.remote_g2_blocks.clone() {
// Skip if no blocks to pull
if block_ids.is_empty() {
continue;
}
// Get the parallel sequence hashes for registration
let seq_hashes = self
.remote_g2_hashes
.get(&remote_instance)
.cloned()
.unwrap_or_default();
if seq_hashes.len() != block_ids.len() {
anyhow::bail!(
"Mismatch between block_ids ({}) and seq_hashes ({}) for instance {}",
block_ids.len(),
seq_hashes.len(),
remote_instance
);
}
// Sort (block_id, seq_hash) pairs by position to ensure correct transfer order
// This is a safety net in case responder sent blocks in wrong order
let mut pairs: Vec<(BlockId, SequenceHash)> =
block_ids.into_iter().zip(seq_hashes.into_iter()).collect();
pairs.sort_by_key(|(_, hash)| hash.position());
let block_ids: Vec<BlockId> = pairs.iter().map(|(id, _)| *id).collect();
let seq_hashes: Vec<SequenceHash> = pairs.iter().map(|(_, hash)| *hash).collect();
// Step 1: Import remote metadata if not already done
if !parallel_worker.has_remote_metadata(remote_instance) {
tracing::debug!(
session_id = %self.session_id,
instance = %remote_instance,
"Requesting metadata from instance"
);
let metadata = self.transport.request_metadata(remote_instance).await?;
parallel_worker
.connect_remote(remote_instance, metadata)?
.await?;
tracing::debug!(
session_id = %self.session_id,
instance = %remote_instance,
"Metadata imported for instance"
);
}
// Step 2: Allocate local G2 blocks as destinations
let dst_blocks = self
.g2_manager
.allocate_blocks(block_ids.len())
.ok_or_else(|| {
anyhow::anyhow!("Failed to allocate {} G2 blocks", block_ids.len())
})?;
let dst_ids: Vec<BlockId> = dst_blocks.iter().map(|b| b.block_id()).collect();
tracing::debug!(
session_id = %self.session_id,
count = block_ids.len(),
instance = %remote_instance,
"Pulling blocks via RDMA"
);
// Step 3: Execute RDMA transfer
// Uses execute_remote_onboard_for_instance which looks up the stored handle mapping
let notification = parallel_worker.execute_remote_onboard_for_instance(
remote_instance,
LogicalLayoutHandle::G2, // source is remote G2
block_ids,
LogicalLayoutHandle::G2, // destination is local G2
Arc::from(dst_ids),
TransferOptions::default(),
)?;
notification.await?;
tracing::debug!(
session_id = %self.session_id,
instance = %remote_instance,
"RDMA transfer complete"
);
// Step 4: Register pulled blocks with their sequence hashes
// We stage each block with the sequence hash from the remote,
// then register it to produce an immutable block.
let new_g2_blocks: Vec<ImmutableBlock<G2>> = dst_blocks
.into_iter()
.zip(seq_hashes.iter())
.map(|(dst, seq_hash)| {
let complete = dst
.stage(*seq_hash, self.g2_manager.block_size())
.expect("block size mismatch");
self.g2_manager.register_block(complete)
})
.collect();
// Add to local G2 blocks
self.local_g2_blocks.extend(new_g2_blocks);
}
Ok(())
}
/// Consolidate all G2 blocks into shared storage.
///
/// This method sorts blocks by sequence_hash position to ensure correct
/// positional correspondence for G2→G1 transfer. This is critical because
/// blocks from different sources (local G2, G3→G2, remote G2, G4) may arrive
/// in different orders, but the consumer expects them sorted by position.
async fn consolidate_blocks(&mut self) {
let mut all_blocks = self.local_g2_blocks.take_all();
// Sort blocks by sequence_hash position (lowest to highest)
// This ensures correct positional correspondence for G2→G1 transfer
all_blocks.sort_by_key(|b| b.sequence_hash().position());
// Validate contiguous positions - catches ordering bugs before data corruption.
// If validation fails, we still proceed with sorted blocks because:
// 1. Sorted order is strictly safer than unsorted for G2→G1 transfer
// 2. Non-contiguous positions indicate an upstream aggregation bug, not a
// sorting bug — failing here would discard valid cached data
// 3. The consumer (G1 transfer) handles sparse blocks correctly
let seq_hashes: Vec<SequenceHash> = all_blocks.iter().map(|b| b.sequence_hash()).collect();
if let Err(e) = validate_contiguous_positions(&seq_hashes) {
tracing::warn!(
session_id = %self.session_id,
error = %e,
"Block positions are not contiguous — proceeding with sorted order"
);
}
let matched_blocks = all_blocks.len();
*self.all_g2_blocks.lock().await = Some(all_blocks);
self.status_tx
.send(OnboardingStatus::Complete { matched_blocks })
.ok();
}
/// Wait for control commands (Hold/Prepare modes).
async fn await_commands(&mut self, mut rx: mpsc::Receiver<OnboardMessage>) -> Result<()> {
loop {
tokio::select! {
Some(cmd) = self.control_rx.recv() => {
match cmd {
SessionControl::Prepare => {
if self.mode == StagingMode::Hold {
self.prepare_mode(&mut rx).await?;
self.mode = StagingMode::Prepare;
}
}
SessionControl::Pull => {
if self.mode == StagingMode::Prepare {
self.pull_remote_blocks().await?;
self.consolidate_blocks().await;
// Send CloseSession to all remotes
let all_remotes: HashSet<InstanceId> = self
.remote_g2_blocks
.keys()
.chain(self.remote_g3_blocks.keys())
.copied()
.collect();
for remote in all_remotes {
self.transport.send(remote, OnboardMessage::CloseSession {
requester: self.instance_id,
session_id: self.session_id,
}).await?;
}
break;
}
}
SessionControl::Cancel => {
// Release all blocks and exit
let all_remotes: HashSet<InstanceId> = self
.remote_g2_blocks
.keys()
.chain(self.remote_g3_blocks.keys())
.copied()
.collect();
for remote in all_remotes {
self.transport.send(remote, OnboardMessage::CloseSession {
requester: self.instance_id,
session_id: self.session_id,
}).await?;
}
break;
}
SessionControl::Shutdown => {
break;
}
}
}
// Also drain any remaining messages from responders
Some(_msg) = rx.recv() => {
// Process any late messages if needed
}
}
}
Ok(())
}
// =========================================================================
// G4/Object Storage Methods
// =========================================================================
/// Spawn a G4 search task that runs in parallel with remote G2/G3 search.
///
/// This task calls `has_blocks` via parallel_worker which fans out to workers.
/// Workers use rank-prefixed keys, so we must query through them (not directly to S3).
fn spawn_g4_search(
&self,
sequence_hashes: Vec<SequenceHash>,
tx: mpsc::Sender<OnboardMessage>,
) -> JoinHandle<()> {
let session_id = self.session_id;
// Use parallel_worker for has_blocks - it fans out to workers who use rank-prefixed keys
let parallel_worker = self.parallel_worker.clone();
tokio::spawn(async move {
let Some(worker) = parallel_worker else {
// No parallel worker configured, send empty results
let _ = tx
.send(OnboardMessage::G4Results {
session_id,
found_hashes: vec![],
})
.await;
return;
};
// Call has_blocks via parallel_worker (fans out to workers with rank-prefixed keys)
let results = worker.has_blocks(sequence_hashes).await;
// Filter to only blocks that exist (Some(size))
let found_hashes: Vec<(SequenceHash, usize)> = results
.into_iter()
.filter_map(|(hash, size_opt)| size_opt.map(|size| (hash, size)))
.collect();
tracing::debug!(
session_id = %session_id,
count = found_hashes.len(),
"G4 search: found blocks in object storage"
);
// Send results back to initiator
let _ = tx
.send(OnboardMessage::G4Results {
session_id,
found_hashes,
})
.await;
})
}
/// Process G4 search results with first-responder-wins logic.
///
/// Returns the hashes that G4 won (not already claimed by G2/G3/remote).
fn process_g4_results(
&mut self,
found_hashes: Vec<(SequenceHash, usize)>,
matched_hashes: &mut HashSet<SequenceHash>,
) -> Vec<SequenceHash> {
let mut won_hashes = Vec::new();
for (hash, _size) in found_hashes {
// First-responder-wins: only claim if not already matched
if matched_hashes.insert(hash) {
won_hashes.push(hash);
self.g4_state.won_hashes.insert(hash);
}
}
tracing::debug!(
session_id = %self.session_id,
won_count = won_hashes.len(),
"G4 won hashes (first-responder-wins)"
);
won_hashes
}
/// Load G4 blocks into local G2 via workers.
///
/// Allocates G2 destination blocks and coordinates workers to download
/// from object storage via `get_blocks`. After successful download, blocks
/// are registered with the G2 manager and returned via G4LoadComplete message.
async fn load_g4_blocks(
&mut self,
won_hashes: Vec<SequenceHash>,
g4_tx: mpsc::Sender<OnboardMessage>,
) -> Result<()> {
if won_hashes.is_empty() {
return Ok(());
}
let parallel_worker = self
.parallel_worker
.as_ref()
.ok_or_else(|| anyhow::anyhow!("ParallelWorkers required for G4 load"))?;
// Mark hashes as pending load
for hash in &won_hashes {
self.g4_state.pending_load.insert(*hash);
}
// Allocate G2 destination blocks
let dst_blocks = self
.g2_manager
.allocate_blocks(won_hashes.len())
.ok_or_else(|| {
anyhow::anyhow!(
"Failed to allocate {} G2 blocks for G4 load",
won_hashes.len()
)
})?;
let dst_ids: Vec<BlockId> = dst_blocks.iter().map(|b| b.block_id()).collect();
// Track allocated blocks (for cleanup on failure)
for (hash, block_id) in won_hashes.iter().zip(dst_ids.iter()) {
self.g4_state.allocated_blocks.insert(*hash, *block_id);
}
tracing::debug!(
session_id = %self.session_id,
count = won_hashes.len(),
"Loading G4 blocks via workers"
);
// Clone values for the spawned task
let session_id = self.session_id;
let hashes = won_hashes.clone();
let parallel_worker = parallel_worker.clone();
let g2_manager = self.g2_manager.clone();
// Spawn load task so we can continue processing other messages
// IMPORTANT: dst_blocks is moved into the task to keep them alive during download
tokio::spawn(async move {
// Execute get_blocks via parallel worker
let results = parallel_worker
.get_blocks(hashes.clone(), LogicalLayoutHandle::G2, dst_ids.clone())
.await;
// Separate successes and failures, register successful blocks
let mut success = Vec::new();
let mut failures = Vec::new();
let mut blocks = Vec::new();
// Iterate over results alongside the dst_blocks and hashes
for ((result, dst_block), seq_hash) in results
.into_iter()
.zip(dst_blocks.into_iter())
.zip(hashes.iter())
{
match result {
Ok(hash) => {
// Register the block with its sequence hash
// This adds it to the BlockRegistry for presence filtering
let complete = dst_block
.stage(*seq_hash, g2_manager.block_size())
.expect("block size mismatch");
let immutable = g2_manager.register_block(complete);
blocks.push(immutable);
success.push(hash);
}
Err(hash) => {
// Block will be returned to pool when dst_block is dropped
failures.push((hash, "Failed to download block".to_string()));
}
}
}
tracing::debug!(
session_id = %session_id,
success_count = success.len(),
failure_count = failures.len(),
"G4 load complete"
);
// Send completion message with registered blocks
let _ = g4_tx
.send(OnboardMessage::G4LoadComplete {
session_id,
success,
failures,
blocks: std::sync::Arc::new(blocks),
})
.await;
});
Ok(())
}
/// Handle G4 load completion, updating state and adding blocks to local_g2_blocks.
///
/// The blocks have already been registered with the G2 manager in the spawned task,
/// so they are now visible in the BlockRegistry for presence filtering.
fn handle_g4_load_complete(
&mut self,
success: Vec<SequenceHash>,
failures: Vec<(SequenceHash, String)>,
blocks: Arc<Vec<ImmutableBlock<G2>>>,
) {
// Process successful loads - update state tracking
for hash in &success {
self.g4_state.pending_load.remove(hash);
// Remove from allocated_blocks since we now have registered ImmutableBlocks
self.g4_state.allocated_blocks.remove(hash);
}
// Unwrap the Arc to get the Vec (this is the only owner since the message was just received)
let blocks =
Arc::try_unwrap(blocks).expect("G4LoadComplete should be the sole owner of blocks");
// Add the registered G4 blocks to local_g2_blocks
// These blocks are now registered in the BlockRegistry and will be
// detected by the PresenceFilter during G1→G2 offloading
self.local_g2_blocks.extend(blocks);
// Process failures
for (hash, error) in failures {
self.g4_state.pending_load.remove(&hash);
self.g4_state.failed_hashes.insert(hash, error);
// Remove from allocated_blocks on failure (block was already dropped)
self.g4_state.allocated_blocks.remove(&hash);
// Also remove from won_hashes since it failed to load
self.g4_state.won_hashes.remove(&hash);
}
tracing::debug!(
session_id = %self.session_id,
won = self.g4_state.won_hashes.len(),
pending = self.g4_state.pending_load.len(),
failed = self.g4_state.failed_hashes.len(),
local_g2 = self.local_g2_blocks.count(),
"G4 load complete, blocks added to local_g2_blocks"
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::ops::Range;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::{BlockId, G2, InstanceId, SequenceHash};
use kvbm_logical::blocks::ImmutableBlock;
use kvbm_physical::manager::LayoutHandle;
use super::SessionId;
/// Messages exchanged between leaders during onboarding sessions.
///
/// Phase 2 protocol (G2-only):
/// 1. Initiator sends CreateSession to multiple responders
/// 2. Each responder searches local G2 and sends G2Results back
/// 3. Initiator applies first-responder-wins and sends HoldBlocks to each
/// 4. Responders send Acknowledged after releasing unwanted blocks
///
/// Phase 3 protocol (G3 staging):
/// 5. Responders search G3 and send G3Results
/// 6. Initiator sends StageBlocks with blocks to stage G3->G2
/// 7. Responders stage blocks and send BlocksReady when complete
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OnboardMessage {
/// Initiator creates a new onboarding session.
CreateSession {
requester: InstanceId,
session_id: SessionId,
sequence_hashes: Vec<SequenceHash>,
},
/// Responder signals local search (G2 and G3) is complete.
SearchComplete {
responder: InstanceId,
session_id: SessionId,
},
/// Responder reports G2 search results.
/// - sequence_hashes: ordered list of matched sequence hashes
/// - block_ids: parallel list of block IDs (can be zipped with sequence_hashes)
G2Results {
responder: InstanceId,
session_id: SessionId,
sequence_hashes: Vec<SequenceHash>,
block_ids: Vec<BlockId>,
},
/// Responder reports G3 search results.
/// - sequence_hashes: ordered list of matched sequence hashes (no block IDs)
G3Results {
responder: InstanceId,
session_id: SessionId,
sequence_hashes: Vec<SequenceHash>,
},
/// Initiator tells responder which sequence hashes to hold/drop.
/// Works across G2 and G3 tiers.
HoldBlocks {
requester: InstanceId,
session_id: SessionId,
hold_hashes: Vec<SequenceHash>,
drop_hashes: Vec<SequenceHash>,
},
/// Initiator tells responder which G3 sequence hashes to stage to G2.
/// Any G3 blocks with these hashes should be staged to G2.
StageBlocks {
requester: InstanceId,
session_id: SessionId,
stage_hashes: Vec<SequenceHash>,
},
/// Responder reports newly staged blocks are ready in G2 (after G3->G2 staging).
/// Only reports blocks that were just staged, not all G2 blocks.
/// - sequence_hashes: newly staged blocks
/// - block_ids: parallel to sequence_hashes
BlocksReady {
responder: InstanceId,
session_id: SessionId,
sequence_hashes: Vec<SequenceHash>,
block_ids: Vec<BlockId>,
},
/// Responder acknowledges hold/drop request.
Acknowledged {
responder: InstanceId,
session_id: SessionId,
},
/// Initiator tells responder to release specific sequence hashes that weren't selected.
/// Works across G2 and G3 tiers.
ReleaseBlocks {
requester: InstanceId,
session_id: SessionId,
release_hashes: Vec<SequenceHash>,
},
/// Initiator tells responder session is complete, responder can cleanup.
CloseSession {
requester: InstanceId,
session_id: SessionId,
},
// =========================================================================
// G4/Object Storage Messages (Internal - not sent over network)
// =========================================================================
/// G4 search results from object storage `has_blocks`.
///
/// Internal message sent via mpsc channel from the G4 search task
/// to the initiator session. Contains hashes found in object storage
/// with their sizes.
G4Results {
session_id: SessionId,
/// Hashes found in G4 with their sizes in bytes
found_hashes: Vec<(SequenceHash, usize)>,
},
/// G4 load completion results from object storage `get_blocks`.
///
/// Internal message sent via mpsc channel from the G4 load task
/// to the initiator session. Contains per-block success/failure.
G4LoadComplete {
session_id: SessionId,
/// Successfully loaded hashes
success: Vec<SequenceHash>,
/// Failed hashes with error messages
failures: Vec<(SequenceHash, String)>,
/// Successfully loaded and registered G2 blocks.
/// These are ready to be added to local_g2_blocks.
/// Wrapped in Arc for Clone derivation (internal message only).
#[serde(skip)]
blocks: Arc<Vec<ImmutableBlock<G2>>>,
},
// TODO: Add heartbeat/TTL mechanism for handling unresponsive initiators
// Heartbeat {
// requester: InstanceId,
// session_id: SessionId,
// timestamp: u64,
// },
// TTL resets with each heartbeat. If TTL expires:
// - Responder releases all held blocks
// - Responder cleans up session state
// - Session task exits
}
/// Represents a block match found during search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockMatch {
pub sequence_hash: SequenceHash,
pub block_id: BlockId,
}
impl OnboardMessage {
/// Extract the session ID from any message variant.
pub fn session_id(&self) -> SessionId {
match self {
OnboardMessage::CreateSession { session_id, .. }
| OnboardMessage::SearchComplete { session_id, .. }
| OnboardMessage::G2Results { session_id, .. }
| OnboardMessage::G3Results { session_id, .. }
| OnboardMessage::HoldBlocks { session_id, .. }
| OnboardMessage::StageBlocks { session_id, .. }
| OnboardMessage::BlocksReady { session_id, .. }
| OnboardMessage::Acknowledged { session_id, .. }
| OnboardMessage::ReleaseBlocks { session_id, .. }
| OnboardMessage::CloseSession { session_id, .. }
| OnboardMessage::G4Results { session_id, .. }
| OnboardMessage::G4LoadComplete { session_id, .. } => *session_id,
}
}
/// Extract the requester/responder instance ID from the message.
///
/// # Panics
/// Panics if called on G4 messages (internal only, no instance ID).
pub fn instance_id(&self) -> InstanceId {
match self {
OnboardMessage::CreateSession { requester, .. }
| OnboardMessage::HoldBlocks { requester, .. }
| OnboardMessage::StageBlocks { requester, .. }
| OnboardMessage::ReleaseBlocks { requester, .. }
| OnboardMessage::CloseSession { requester, .. } => *requester,
OnboardMessage::SearchComplete { responder, .. }
| OnboardMessage::G2Results { responder, .. }
| OnboardMessage::G3Results { responder, .. }
| OnboardMessage::BlocksReady { responder, .. }
| OnboardMessage::Acknowledged { responder, .. } => *responder,
OnboardMessage::G4Results { .. } | OnboardMessage::G4LoadComplete { .. } => {
panic!("G4 messages are internal and do not have an instance ID")
}
}
}
/// Get the variant name as a string for logging.
pub fn variant_name(&self) -> &'static str {
match self {
OnboardMessage::CreateSession { .. } => "CreateSession",
OnboardMessage::SearchComplete { .. } => "SearchComplete",
OnboardMessage::G2Results { .. } => "G2Results",
OnboardMessage::G3Results { .. } => "G3Results",
OnboardMessage::HoldBlocks { .. } => "HoldBlocks",
OnboardMessage::StageBlocks { .. } => "StageBlocks",
OnboardMessage::BlocksReady { .. } => "BlocksReady",
OnboardMessage::Acknowledged { .. } => "Acknowledged",
OnboardMessage::ReleaseBlocks { .. } => "ReleaseBlocks",
OnboardMessage::CloseSession { .. } => "CloseSession",
OnboardMessage::G4Results { .. } => "G4Results",
OnboardMessage::G4LoadComplete { .. } => "G4LoadComplete",
}
}
}
// =============================================================================
// Unified Session Protocol
// =============================================================================
//
// These types support the unified session model where sessions can dynamically
// transition between control roles.
use super::state::{ControlRole, SessionPhase};
/// Unified session message protocol.
///
/// Unified, bidirectional protocol that supports dynamic control transfer.
///
/// # Protocol Overview
///
/// 1. **Connection**: `Attach`/`Detach` for peer management
/// 2. **Control Transfer**: `YieldControl`/`AcquireControl` for bidirectional role changes
/// 3. **Block Operations**: Commands from controller to controllee
/// 4. **State Sync**: Responses from controllee to controller
/// 5. **Lifecycle**: `Close`/`Error` for termination
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SessionMessage {
// =========================================================================
// Connection Management
// =========================================================================
/// Attach to a session as a specific role.
///
/// Sent by a peer to establish the session relationship.
Attach {
/// The instance ID of the attaching peer.
peer: InstanceId,
/// The session to attach to.
session_id: SessionId,
/// The role this peer will assume (typically `Controllee` or `Controller`).
as_role: ControlRole,
},
/// Detach from a session.
///
/// Graceful disconnection. The session may continue if other peers are attached.
Detach {
/// The instance ID of the detaching peer.
peer: InstanceId,
/// The session to detach from.
session_id: SessionId,
},
// =========================================================================
// Control Transfer (Bidirectional)
// =========================================================================
/// Yield control to peer.
///
/// The sender transitions from `Controller` to `Neutral`.
/// The receiver (if in `Controllee`) can then `AcquireControl` or remain passive.
YieldControl {
/// The instance ID of the yielding peer.
peer: InstanceId,
/// The session.
session_id: SessionId,
},
/// Acquire control from peer.
///
/// The sender attempts to become `Controller`.
/// Valid when sender is `Neutral` or `Controllee` and peer is `Neutral`.
AcquireControl {
/// The instance ID of the peer acquiring control.
peer: InstanceId,
/// The session.
session_id: SessionId,
},
// =========================================================================
// Block Operations (Controller → Controllee)
// =========================================================================
/// Trigger staging of blocks (e.g., G3→G2).
TriggerStaging {
/// The session.
session_id: SessionId,
},
/// Request that specific blocks be held (kept alive).
HoldBlocks {
/// The session.
session_id: SessionId,
/// Sequence hashes of blocks to hold.
hold_hashes: Vec<SequenceHash>,
},
/// Release specific blocks (they can now be evicted).
ReleaseBlocks {
/// The session.
session_id: SessionId,
/// Sequence hashes of blocks to release.
release_hashes: Vec<SequenceHash>,
},
/// Notify that blocks have been pulled via RDMA.
///
/// The controllee can release these blocks from its hold.
BlocksPulled {
/// The session.
session_id: SessionId,
/// Sequence hashes of blocks that were pulled.
pulled_hashes: Vec<SequenceHash>,
},
// =========================================================================
// State Synchronization (Controllee → Controller)
// =========================================================================
/// Full state snapshot.
///
/// Sent after attachment and periodically on state changes.
StateResponse {
/// The session.
session_id: SessionId,
/// Complete state snapshot.
state: SessionStateSnapshot,
},
/// Notification that blocks have been staged.
///
/// This message supports layerwise transfer by optionally specifying
/// which layer range is ready. When `layer_range` is `None`, all layers
/// of the staged blocks are ready for transfer.
BlocksStaged {
/// The session.
session_id: SessionId,
/// Newly staged blocks (now in target tier).
staged_blocks: Vec<BlockInfo>,
/// Count of blocks remaining to stage.
remaining: usize,
/// Layer range that is ready for transfer.
///
/// - `None`: All layers are ready (default behavior)
/// - `Some(0..1)`: Only layer 0 is ready
/// - `Some(0..60)`: Layers 0-59 are ready
///
/// This enables layerwise streaming where the sender computes
/// layer-by-layer and notifies the receiver as each layer completes.
layer_range: Option<Range<usize>>,
},
// =========================================================================
// Lifecycle
// =========================================================================
/// Close the session gracefully.
Close {
/// The session.
session_id: SessionId,
},
/// Report an error.
Error {
/// The session.
session_id: SessionId,
/// Error description.
message: String,
},
}
impl SessionMessage {
/// Extract the session ID from any message variant.
pub fn session_id(&self) -> SessionId {
match self {
SessionMessage::Attach { session_id, .. }
| SessionMessage::Detach { session_id, .. }
| SessionMessage::YieldControl { session_id, .. }
| SessionMessage::AcquireControl { session_id, .. }
| SessionMessage::TriggerStaging { session_id, .. }
| SessionMessage::HoldBlocks { session_id, .. }
| SessionMessage::ReleaseBlocks { session_id, .. }
| SessionMessage::BlocksPulled { session_id, .. }
| SessionMessage::StateResponse { session_id, .. }
| SessionMessage::BlocksStaged { session_id, .. }
| SessionMessage::Close { session_id, .. }
| SessionMessage::Error { session_id, .. } => *session_id,
}
}
/// Extract the peer instance ID if present.
pub fn peer(&self) -> Option<InstanceId> {
match self {
SessionMessage::Attach { peer, .. }
| SessionMessage::Detach { peer, .. }
| SessionMessage::YieldControl { peer, .. }
| SessionMessage::AcquireControl { peer, .. } => Some(*peer),
_ => None,
}
}
/// Check if this is a control command (sent by controller).
pub fn is_control_command(&self) -> bool {
matches!(
self,
SessionMessage::TriggerStaging { .. }
| SessionMessage::HoldBlocks { .. }
| SessionMessage::ReleaseBlocks { .. }
| SessionMessage::BlocksPulled { .. }
)
}
/// Check if this is a state response (sent by controllee).
pub fn is_state_response(&self) -> bool {
matches!(
self,
SessionMessage::StateResponse { .. } | SessionMessage::BlocksStaged { .. }
)
}
/// Get the variant name as a string for logging.
pub fn variant_name(&self) -> &'static str {
match self {
SessionMessage::Attach { .. } => "Attach",
SessionMessage::Detach { .. } => "Detach",
SessionMessage::YieldControl { .. } => "YieldControl",
SessionMessage::AcquireControl { .. } => "AcquireControl",
SessionMessage::TriggerStaging { .. } => "TriggerStaging",
SessionMessage::HoldBlocks { .. } => "HoldBlocks",
SessionMessage::ReleaseBlocks { .. } => "ReleaseBlocks",
SessionMessage::BlocksPulled { .. } => "BlocksPulled",
SessionMessage::StateResponse { .. } => "StateResponse",
SessionMessage::BlocksStaged { .. } => "BlocksStaged",
SessionMessage::Close { .. } => "Close",
SessionMessage::Error { .. } => "Error",
}
}
}
/// Complete session state snapshot.
///
/// Sent in `SessionMessage::StateResponse` to provide the controller
/// with full visibility into the controllee's state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionStateSnapshot {
/// Current session phase.
pub phase: SessionPhase,
/// Current control role of the sender.
pub control_role: ControlRole,
/// Blocks currently in G2 (ready for RDMA pull).
pub g2_blocks: Vec<BlockInfo>,
/// Count of blocks pending staging to G2.
pub g3_pending: usize,
/// Layer range that is ready for transfer.
///
/// - `None`: All layers are ready (or not applicable)
/// - `Some(0..1)`: Only layer 0 is ready
/// - `Some(0..60)`: Layers 0-59 are ready
///
/// This is updated when receiving `BlocksStaged` messages with `layer_range`.
/// The controller can use this to know which layers can be pulled.
#[serde(default)]
pub ready_layer_range: Option<Range<usize>>,
}
/// Block information for session messages.
///
/// Contains the metadata needed to identify and transfer a block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockInfo {
/// Physical block ID in the layout.
pub block_id: BlockId,
/// Logical sequence hash.
pub sequence_hash: SequenceHash,
/// Layout handle for RDMA operations.
pub layout_handle: LayoutHandle,
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! # Session Module
//!
//! This module provides session management for distributed block transfers.
//!
//! ## Core Building Blocks
//!
//! Composable building blocks for session management:
//!
//! - `BlockHolder<T>`: RAII container for holding blocks during sessions
//! - `SessionEndpoint`: Point-to-point session primitive with state machine
//! - `SessionHandle`: Unified handle for controlling remote sessions
//! - `SessionMessage`: Unified message protocol with bidirectional control
//! - `SessionPhase`, `ControlRole`, `AttachmentState`: State machine types
//!
//! ## Session Implementations
//!
//! - `ServerSession`: Server-side session (merges former EndpointSession + ControllableSession)
//! - `InitiatorSession`: Multi-peer search orchestrator (OnboardMessage)
//! - `ResponderSession`: Responds to search requests (OnboardMessage)
// Core session building blocks
mod blocks;
mod endpoint;
mod handle;
mod server_session;
mod staging;
mod state;
// Session implementations
mod initiator;
mod messages;
mod responder;
pub mod transport;
// =============================================================================
// Core Building Blocks
// =============================================================================
/// RAII container for holding blocks during sessions.
pub use blocks::BlockHolder;
/// Point-to-point session endpoint with state machine.
pub use endpoint::{SessionEndpoint, SessionMessageTx, session_message_channel};
/// Server-side session (unified replacement for EndpointSession + ControllableSession).
pub use server_session::{
ServerSession, ServerSessionCommand, ServerSessionHandle, ServerSessionOptions,
create_server_session,
};
// Backwards-compatible aliases for the former EndpointSession types.
pub use server_session::ServerSessionCommand as EndpointSessionCommand;
pub use server_session::ServerSessionHandle as EndpointSessionHandle;
/// Unified handle for controlling remote sessions.
pub use handle::{SessionHandle, SessionHandleStateTx, session_handle_state_channel};
/// State machine types for the unified session model.
pub use state::{AttachmentState, ControlRole, SessionPhase};
/// Unified session message protocol.
pub use messages::{BlockInfo, SessionMessage, SessionStateSnapshot};
// =============================================================================
// Session Implementations
// =============================================================================
/// Session implementations for initiator and responder patterns.
pub use initiator::InitiatorSession;
pub use responder::ResponderSession;
/// Backwards-compatible re-exports (ControllableSessionResult is still used externally).
pub use server_session::ServerSessionOptions as ControllableSessionOptions;
/// Result of creating a controllable/server session.
#[derive(Debug, Clone)]
pub struct ControllableSessionResult {
/// The unique session ID.
pub session_id: super::SessionId,
/// Number of G2 blocks found.
pub local_g2_count: usize,
/// Number of G3 blocks found.
pub local_g3_count: usize,
}
/// Message types for session communication.
pub use messages::{BlockMatch, OnboardMessage};
/// Transport types.
pub use transport::{LocalTransport, MessageTransport, VeloTransport};
use anyhow::Result;
use dashmap::DashMap;
use tokio::sync::mpsc;
pub type SessionId = uuid::Uuid;
pub type OnboardSessionTx = mpsc::Sender<OnboardMessage>;
/// Route an [`OnboardMessage`] to its per-session task channel.
///
/// Looks up the session ID in the `DashMap` registry and forwards the message
/// through the session's mpsc sender. Each session processes messages serially
/// via its channel, so ordering is preserved per-session.
pub async fn dispatch_onboard_message(
sessions: &DashMap<SessionId, OnboardSessionTx>,
message: OnboardMessage,
) -> Result<()> {
let session_id = message.session_id();
let sender = sessions.get(&session_id).map(|entry| entry.value().clone());
if let Some(sender) = sender {
sender
.send(message)
.await
.map_err(|e| anyhow::anyhow!("failed to send to session {session_id}: {e}"))?;
return Ok(());
}
anyhow::bail!("no session task registered for session {session_id}");
}
/// Route a unified [`SessionMessage`] to its session task.
///
/// All message variants are routed through a single `DashMap<SessionId, SessionMessageTx>`
/// registry.
pub async fn dispatch_session_message(
sessions: &DashMap<SessionId, SessionMessageTx>,
message: SessionMessage,
) -> Result<()> {
let session_id = message.session_id();
let sender = sessions.get(&session_id).map(|entry| entry.value().clone());
if let Some(sender) = sender {
sender
.send(message)
.await
.map_err(|e| anyhow::anyhow!("failed to send to session {session_id}: {e}"))?;
return Ok(());
}
anyhow::bail!("no session registered for session {session_id}");
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_dispatch_onboard_message() {
let sessions: DashMap<SessionId, OnboardSessionTx> = DashMap::new();
let session_id = SessionId::new_v4();
let (tx, mut rx) = mpsc::channel(16);
sessions.insert(session_id, tx);
let msg = OnboardMessage::CloseSession {
requester: crate::InstanceId::new_v4(),
session_id,
};
dispatch_onboard_message(&sessions, msg).await.unwrap();
let received = rx.recv().await.unwrap();
assert_eq!(received.session_id(), session_id);
}
#[tokio::test]
async fn test_dispatch_session_message() {
let sessions: DashMap<SessionId, SessionMessageTx> = DashMap::new();
let session_id = SessionId::new_v4();
let (tx, mut rx) = mpsc::channel(16);
sessions.insert(session_id, tx);
let msg = SessionMessage::Close { session_id };
dispatch_session_message(&sessions, msg).await.unwrap();
let received = rx.recv().await.unwrap();
assert_eq!(received.session_id(), session_id);
}
#[tokio::test]
async fn test_dispatch_missing_onboard_session() {
let sessions: DashMap<SessionId, OnboardSessionTx> = DashMap::new();
let session_id = SessionId::new_v4();
let msg = OnboardMessage::CloseSession {
requester: crate::InstanceId::new_v4(),
session_id,
};
let result = dispatch_onboard_message(&sessions, msg).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_dispatch_missing_session_message() {
let sessions: DashMap<SessionId, SessionMessageTx> = DashMap::new();
let session_id = SessionId::new_v4();
let msg = SessionMessage::Close { session_id };
let result = dispatch_session_message(&sessions, msg).await;
assert!(result.is_err());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use tokio::sync::mpsc;
use std::collections::HashSet;
use std::sync::Arc;
use crate::{BlockId, G2, G3, InstanceId, SequenceHash, worker::group::ParallelWorkers};
use kvbm_logical::manager::BlockManager;
use super::{BlockHolder, SessionId, messages::OnboardMessage, transport::MessageTransport};
/// Responder-side session for handling block onboarding requests.
///
/// Lifecycle:
/// 1. Spawned when receiving CreateSession
/// 2. Searches local G2 for matches
/// 3. Holds `ImmutableBlock<G2>` references (RAII)
/// 4. Sends G2Results immediately
/// 5. Searches local G3 for remaining matches (if G3 available)
/// 6. Sends G3Results
/// 7. Receives HoldBlocks and filters held G2 blocks
/// 8. Receives StageBlocks and executes G3->G2 transfers
/// 9. Sends BlocksReady when staging completes
/// 10. Sends Acknowledged
/// 11. Completes and drops (releases blocks)
pub struct ResponderSession {
session_id: SessionId,
instance_id: InstanceId,
requester: InstanceId,
g2_manager: Arc<BlockManager<G2>>,
g3_manager: Option<Arc<BlockManager<G3>>>,
parallel_worker: Option<Arc<dyn ParallelWorkers>>,
transport: Arc<MessageTransport>,
// Held blocks using BlockHolder for RAII semantics
// Blocks are automatically released when the session drops
held_g2_blocks: BlockHolder<G2>,
held_g3_blocks: BlockHolder<G3>,
}
impl ResponderSession {
/// Create a new responder session.
pub fn new(
session_id: SessionId,
instance_id: InstanceId,
requester: InstanceId,
g2_manager: Arc<BlockManager<G2>>,
g3_manager: Option<Arc<BlockManager<G3>>>,
parallel_worker: Option<Arc<dyn ParallelWorkers>>,
transport: Arc<MessageTransport>,
) -> Self {
Self {
session_id,
instance_id,
requester,
g2_manager,
g3_manager,
parallel_worker,
transport,
held_g2_blocks: BlockHolder::empty(),
held_g3_blocks: BlockHolder::empty(),
}
}
/// Run the responder session task.
///
/// This is the main session loop that processes messages from the channel.
pub async fn run(
mut self,
mut rx: mpsc::Receiver<OnboardMessage>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<()> {
// Phase 1: Immediate G2 search
// Use scan_matches instead of match_blocks to find all matching blocks
// without stopping on first miss (supports partial sequence matching)
let g2_matches_map = self.g2_manager.scan_matches(&sequence_hashes, true);
let mut g2_matches: Vec<_> = g2_matches_map.into_values().collect();
// Sort by position to ensure G2Results are in position order
// HashMap iteration order is arbitrary, so we must sort explicitly
g2_matches.sort_by_key(|block| block.sequence_hash().position());
// Hold the G2 blocks using BlockHolder (RAII semantics)
self.held_g2_blocks = BlockHolder::new(g2_matches);
// Send G2 results immediately (fire-and-forget) with parallel arrays
let g2_sequence_hashes: Vec<SequenceHash> = self.held_g2_blocks.sequence_hashes();
let g2_block_ids: Vec<BlockId> = self
.held_g2_blocks
.blocks()
.iter()
.map(|b| b.block_id())
.collect();
let g2_msg = OnboardMessage::G2Results {
responder: self.instance_id,
session_id: self.session_id,
sequence_hashes: g2_sequence_hashes,
block_ids: g2_block_ids,
};
self.transport.send(self.requester, g2_msg).await?;
// Phase 2: Search G3 for remaining hashes (if G3 available)
let g2_matched_hashes: HashSet<SequenceHash> =
self.held_g2_blocks.sequence_hashes().into_iter().collect();
let remaining_hashes: Vec<SequenceHash> = sequence_hashes
.iter()
.filter(|h| !g2_matched_hashes.contains(h))
.copied()
.collect();
if !remaining_hashes.is_empty()
&& let Some(ref g3_manager) = self.g3_manager
{
// Use scan_matches instead of match_blocks to find all matching blocks
// without stopping on first miss (supports partial sequence matching)
let g3_matches_map = g3_manager.scan_matches(&remaining_hashes, true);
let mut g3_matches: Vec<_> = g3_matches_map.into_values().collect();
// Sort by position to ensure G3Results are in position order
g3_matches.sort_by_key(|block| block.sequence_hash().position());
if !g3_matches.is_empty() {
// Hold the G3 blocks using BlockHolder
self.held_g3_blocks = BlockHolder::new(g3_matches);
// Send G3 results (sequence hashes only, keep order)
let g3_sequence_hashes: Vec<SequenceHash> = self.held_g3_blocks.sequence_hashes();
let g3_msg = OnboardMessage::G3Results {
responder: self.instance_id,
session_id: self.session_id,
sequence_hashes: g3_sequence_hashes,
};
self.transport.send(self.requester, g3_msg).await?;
}
}
// Send SearchComplete to signal we're done searching
let complete_msg = OnboardMessage::SearchComplete {
responder: self.instance_id,
session_id: self.session_id,
};
self.transport.send(self.requester, complete_msg).await?;
// Phase 3: Process incoming messages
while let Some(msg) = rx.recv().await {
match msg {
OnboardMessage::HoldBlocks {
hold_hashes,
drop_hashes: _,
..
} => {
// Filter by sequence hash - BlockHolder's retain keeps only matching hashes
self.held_g2_blocks.retain(&hold_hashes);
self.held_g3_blocks.retain(&hold_hashes);
// Send acknowledgment
let ack = OnboardMessage::Acknowledged {
responder: self.instance_id,
session_id: self.session_id,
};
self.transport.send(self.requester, ack).await?;
// Always wait for CloseSession, even if no G3 blocks
// This ensures proper session lifecycle and avoids race conditions
// where initiator sends CloseSession after we've already exited
}
OnboardMessage::StageBlocks { stage_hashes, .. } => {
// Filter G3 blocks to only keep blocks to be staged
// BlockHolder's retain keeps only matching hashes
self.held_g3_blocks.retain(&stage_hashes);
if !self.held_g3_blocks.is_empty() {
if self.parallel_worker.is_some() {
// Execute G3->G2 transfer
self.stage_g3_to_g2().await?;
} else {
tracing::warn!(
session_id = %self.session_id,
g3_blocks = self.held_g3_blocks.count(),
"G3 blocks cannot be staged: no parallel worker configured"
);
}
}
// Don't exit - wait for CloseSession in Hold/Prepare modes
}
OnboardMessage::ReleaseBlocks { release_hashes, .. } => {
// Release specific blocks by sequence hash
// BlockHolder's release removes blocks with given hashes
self.held_g2_blocks.release(&release_hashes);
self.held_g3_blocks.release(&release_hashes);
}
// todo: how does close session drop the session from the dashmap?
// todo: do we need to handle this in the handler rather than the session responder loop?
OnboardMessage::CloseSession { .. } => {
// Session complete - release all blocks and exit
// take_all() explicitly releases the blocks
let _ = self.held_g2_blocks.take_all();
let _ = self.held_g3_blocks.take_all();
break;
}
OnboardMessage::CreateSession { .. } => {
// Duplicate CreateSession - ignore
}
// todo: be explicit about what messages are expected and what messages are unexpected
// on the responder session - avoid using the wildcard match
_ => {
// Unexpected message - log and ignore
tracing::warn!(
session_id = %self.session_id,
msg = ?msg,
"ResponderSession: unexpected message"
);
}
}
// TODO: Add heartbeat/TTL timeout handling
// If no message received within TTL duration:
// - Release all held blocks
// - Exit session
// Implementation:
// tokio::select! {
// msg = rx.recv() => { /* process message */ }
// _ = tokio::time::sleep_until(ttl_deadline) => {
// eprintln!("Session {} TTL expired, releasing blocks", self.session_id);
// break;
// }
// }
}
Ok(())
}
/// Stage G3 blocks to G2.
async fn stage_g3_to_g2(&mut self) -> Result<()> {
let parallel_worker = self
.parallel_worker
.as_ref()
.ok_or_else(|| anyhow::anyhow!("ParallelWorker required for G3->G2 staging"))?;
let result = super::staging::stage_g3_to_g2(
&self.held_g3_blocks,
&self.g2_manager,
&**parallel_worker,
)
.await?;
// Extract sequence hashes and block IDs for newly staged blocks
let new_sequence_hashes: Vec<SequenceHash> = result
.new_g2_blocks
.iter()
.map(|b| b.sequence_hash())
.collect();
let new_block_ids: Vec<BlockId> =
result.new_g2_blocks.iter().map(|b| b.block_id()).collect();
// Release G3 blocks (take_all releases them) and hold new G2 blocks
let _ = self.held_g3_blocks.take_all();
self.held_g2_blocks.extend(result.new_g2_blocks);
// Send BlocksReady with only newly staged blocks
let ready_msg = OnboardMessage::BlocksReady {
responder: self.instance_id,
session_id: self.session_id,
sequence_hashes: new_sequence_hashes,
block_ids: new_block_ids,
};
self.transport.send(self.requester, ready_msg).await?;
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! ServerSession: Merged server-side session for both G2-only and G2+G3 staging modes.
//!
//! Unifies `EndpointSession` (G2-only, Direct layout handles) and
//! `ControllableSession` (G2+G3 staging, RoundRobin layout handles) into a
//! single type that uses `SessionEndpoint` for the state machine.
//!
//! # Modes
//!
//! - **G2-only**: Blocks are already in G2 with pre-assigned layout handles.
//! Created via `ServerSession::new_g2_only()` with `Direct` metadata.
//! `TriggerStaging` is a no-op.
//!
//! - **Staging**: G3 blocks need to be staged to G2. Layout handles are
//! assigned round-robin across workers. Created with `RoundRobin` metadata
//! and optional `auto_stage`.
//!
//! # Lifecycle
//!
//! 1. Created with G2 blocks (and optionally G3 blocks)
//! 2. If `auto_stage=true`, immediately stages G3→G2
//! 3. Waits for peer to `Attach`
//! 4. Sends `StateResponse` with block info
//! 5. Responds to `TriggerStaging`, `BlocksPulled`, `Detach`, etc.
//! 6. Completes when all blocks pulled or session closed
use std::collections::HashMap;
use std::ops::Range;
use std::sync::Arc;
use anyhow::Result;
use tokio::sync::mpsc;
use tracing::{debug, warn};
use kvbm_physical::manager::LayoutHandle;
use super::SessionId;
use super::blocks::BlockHolder;
use super::endpoint::SessionEndpoint;
use super::messages::{BlockInfo, SessionMessage, SessionStateSnapshot};
use super::staging;
use super::state::{ControlRole, SessionPhase};
use super::transport::MessageTransport;
use crate::{G2, G3, InstanceId, SequenceHash, worker::group::ParallelWorkers};
use kvbm_logical::manager::BlockManager;
/// Block metadata strategy for mapping blocks to layout handles.
///
/// Unifies the two approaches from the former EndpointSession (Direct)
/// and ControllableSession (RoundRobin).
pub enum BlockMetadataMap {
/// Pre-assigned layout handles keyed by sequence hash.
/// Used for G2-only mode where the caller knows exactly which handle
/// each block should use.
Direct(HashMap<SequenceHash, LayoutHandle>),
/// Worker layout handles for round-robin assignment.
/// Used for staging mode where blocks are distributed across workers.
RoundRobin(Vec<LayoutHandle>),
}
impl BlockMetadataMap {
/// Build `BlockInfo` list from the current G2 blocks.
fn build_block_infos(&self, g2_blocks: &BlockHolder<G2>) -> Vec<BlockInfo> {
match self {
BlockMetadataMap::Direct(map) => g2_blocks
.blocks()
.iter()
.filter_map(|block| {
let hash = block.sequence_hash();
map.get(&hash).map(|&layout_handle| BlockInfo {
block_id: block.block_id(),
sequence_hash: hash,
layout_handle,
})
})
.collect(),
BlockMetadataMap::RoundRobin(handles) => {
if handles.is_empty() {
return g2_blocks
.blocks()
.iter()
.map(|b| BlockInfo {
block_id: b.block_id(),
sequence_hash: b.sequence_hash(),
layout_handle: LayoutHandle::new(0, 0),
})
.collect();
}
g2_blocks
.blocks()
.iter()
.enumerate()
.map(|(i, b)| BlockInfo {
block_id: b.block_id(),
sequence_hash: b.sequence_hash(),
layout_handle: handles[i % handles.len()],
})
.collect()
}
}
}
/// Assign a layout handle for a newly staged block at the given index.
fn assign_handle(&self, index: usize) -> LayoutHandle {
match self {
BlockMetadataMap::Direct(_) => {
// Direct mode shouldn't be staging, but provide a fallback
LayoutHandle::new(0, 0)
}
BlockMetadataMap::RoundRobin(handles) => {
if handles.is_empty() {
LayoutHandle::new(0, 0)
} else {
handles[index % handles.len()]
}
}
}
}
/// Remove entries for the given sequence hashes (Direct mode only).
fn remove_all(&mut self, hashes: &[SequenceHash]) {
if let BlockMetadataMap::Direct(map) = self {
for hash in hashes {
map.remove(hash);
}
}
}
}
/// Options for server session creation.
#[derive(Debug, Clone)]
pub struct ServerSessionOptions {
/// If true (default), immediately start G3→G2 staging.
/// If false, wait for controller to call trigger_staging().
pub auto_stage: bool,
}
impl Default for ServerSessionOptions {
fn default() -> Self {
Self { auto_stage: true }
}
}
/// Server-side session that holds blocks and exposes them for remote RDMA pull.
///
/// Merges the functionality of the former `EndpointSession` and `ControllableSession`.
pub struct ServerSession {
/// State machine for the session protocol.
endpoint: SessionEndpoint,
/// G2 blocks held for RDMA pull (RAII - released on drop).
g2_blocks: BlockHolder<G2>,
/// Block metadata mapping (Direct or RoundRobin).
block_metadata: BlockMetadataMap,
/// G3 blocks pending staging (empty in G2-only mode).
g3_blocks: BlockHolder<G3>,
/// G2 manager for staging (only needed when G3 blocks present).
g2_manager: Option<Arc<BlockManager<G2>>>,
/// Parallel worker for G3→G2 staging.
parallel_worker: Option<Arc<dyn ParallelWorkers>>,
/// Channel for receiving local commands.
cmd_rx: mpsc::Receiver<ServerSessionCommand>,
/// Session options.
options: ServerSessionOptions,
/// Staging state tracking.
staging_started: bool,
staging_complete: bool,
}
/// Handle for local caller to control a ServerSession.
///
/// Used to send layer notifications or close the session.
/// When dropped, the session continues until peer detaches or channel closes.
#[derive(Clone)]
pub struct ServerSessionHandle {
session_id: SessionId,
local_instance: InstanceId,
cmd_tx: mpsc::Sender<ServerSessionCommand>,
}
/// Commands that can be sent to a ServerSession via its handle.
#[derive(Debug)]
pub enum ServerSessionCommand {
/// Notify that specific layers are ready for transfer.
NotifyLayersReady { layer_range: Range<usize> },
/// Close the session gracefully.
Close,
}
impl ServerSession {
/// Create a new ServerSession for G2-only mode.
///
/// Blocks are already in G2 with pre-assigned layout handles.
pub fn new_g2_only(
endpoint: SessionEndpoint,
g2_blocks: BlockHolder<G2>,
block_metadata: HashMap<SequenceHash, LayoutHandle>,
cmd_rx: mpsc::Receiver<ServerSessionCommand>,
) -> Self {
Self {
endpoint,
g2_blocks,
block_metadata: BlockMetadataMap::Direct(block_metadata),
g3_blocks: BlockHolder::empty(),
g2_manager: None,
parallel_worker: None,
cmd_rx,
options: ServerSessionOptions { auto_stage: false },
staging_started: false,
staging_complete: false,
}
}
/// Create a new ServerSession with G3→G2 staging capability.
#[allow(clippy::too_many_arguments)]
pub fn new_with_staging(
endpoint: SessionEndpoint,
g2_blocks: BlockHolder<G2>,
g3_blocks: BlockHolder<G3>,
worker_handles: Vec<LayoutHandle>,
g2_manager: Arc<BlockManager<G2>>,
parallel_worker: Option<Arc<dyn ParallelWorkers>>,
cmd_rx: mpsc::Receiver<ServerSessionCommand>,
options: ServerSessionOptions,
) -> Self {
Self {
endpoint,
g2_blocks,
block_metadata: BlockMetadataMap::RoundRobin(worker_handles),
g3_blocks,
g2_manager: Some(g2_manager),
parallel_worker,
cmd_rx,
options,
staging_started: false,
staging_complete: false,
}
}
/// Run the session message loop.
pub async fn run(mut self) -> Result<()> {
debug!(
session_id = %self.endpoint.session_id(),
g2 = self.g2_blocks.count(),
g3 = self.g3_blocks.count(),
"ServerSession starting"
);
// Set initial phase
if self.g2_blocks.count() > 0 || self.g3_blocks.count() > 0 {
self.endpoint.set_phase(SessionPhase::Holding);
}
// Auto-stage if enabled and we have G3 blocks
if self.options.auto_stage && !self.g3_blocks.is_empty() && self.parallel_worker.is_some() {
self.endpoint.set_phase(SessionPhase::Staging);
self.staging_started = true;
self.execute_staging().await?;
}
self.update_phase();
loop {
tokio::select! {
msg = self.endpoint.recv() => {
match msg {
Some(msg) => {
if !self.handle_message(msg).await? {
break;
}
}
None => {
debug!(
session_id = %self.endpoint.session_id(),
"Message channel closed"
);
break;
}
}
}
cmd = self.cmd_rx.recv() => {
match cmd {
Some(cmd) => {
if !self.handle_command(cmd).await? {
break;
}
}
None => {
debug!(
session_id = %self.endpoint.session_id(),
"Command channel closed"
);
}
}
}
}
}
debug!(
session_id = %self.endpoint.session_id(),
phase = ?self.endpoint.phase(),
"ServerSession completed"
);
Ok(())
}
/// Handle an incoming SessionMessage.
///
/// Returns `true` to continue, `false` to exit the loop.
async fn handle_message(&mut self, msg: SessionMessage) -> Result<bool> {
match msg {
SessionMessage::Attach { peer, as_role, .. } => {
debug!(
session_id = %self.endpoint.session_id(),
peer = %peer,
role = ?as_role,
"Peer attached"
);
self.endpoint.accept_attachment(peer, as_role.opposite());
// Update phase for attach
if self.endpoint.phase() == SessionPhase::Searching
|| self.endpoint.phase() == SessionPhase::Holding
{
self.update_phase();
}
// Send current state
self.send_state_response(None).await?;
}
SessionMessage::TriggerStaging { .. } => {
self.handle_trigger_staging().await?;
}
SessionMessage::BlocksPulled { pulled_hashes, .. } => {
debug!(
session_id = %self.endpoint.session_id(),
count = pulled_hashes.len(),
"Blocks pulled"
);
self.block_metadata.remove_all(&pulled_hashes);
self.g2_blocks.release(&pulled_hashes);
if self.g2_blocks.is_empty() && self.g3_blocks.is_empty() {
self.endpoint.set_phase(SessionPhase::Complete);
return Ok(false);
}
}
SessionMessage::YieldControl { peer, .. } => {
debug!(
session_id = %self.endpoint.session_id(),
peer = %peer,
"Peer yielded control"
);
self.endpoint.set_control_role(ControlRole::Neutral);
}
SessionMessage::AcquireControl { peer, .. } => {
debug!(
session_id = %self.endpoint.session_id(),
peer = %peer,
"Peer acquiring control"
);
self.endpoint.set_control_role(ControlRole::Controllee);
}
SessionMessage::Detach { peer, .. } => {
debug!(
session_id = %self.endpoint.session_id(),
peer = %peer,
"Peer detached"
);
self.endpoint.detach();
self.endpoint.set_phase(SessionPhase::Complete);
return Ok(false);
}
SessionMessage::Close { .. } => {
debug!(
session_id = %self.endpoint.session_id(),
"Session closed"
);
self.endpoint.set_phase(SessionPhase::Complete);
return Ok(false);
}
SessionMessage::Error { message, .. } => {
warn!(
session_id = %self.endpoint.session_id(),
error = %message,
"Received error"
);
self.endpoint.set_phase(SessionPhase::Failed);
return Ok(false);
}
// Ignore outbound-only messages
SessionMessage::StateResponse { .. }
| SessionMessage::BlocksStaged { .. }
| SessionMessage::HoldBlocks { .. }
| SessionMessage::ReleaseBlocks { .. } => {}
}
Ok(true)
}
/// Handle a local command.
///
/// Returns `true` to continue, `false` to exit the loop.
async fn handle_command(&mut self, cmd: ServerSessionCommand) -> Result<bool> {
match cmd {
ServerSessionCommand::NotifyLayersReady { layer_range } => {
debug!(
session_id = %self.endpoint.session_id(),
layer_range = ?layer_range,
"Notifying layers ready"
);
self.send_blocks_staged(Some(layer_range)).await?;
}
ServerSessionCommand::Close => {
debug!(
session_id = %self.endpoint.session_id(),
"Local close requested"
);
self.endpoint.set_phase(SessionPhase::Complete);
if self.endpoint.is_attached() {
let msg = SessionMessage::Close {
session_id: self.endpoint.session_id(),
};
self.endpoint.send(msg).await?;
}
return Ok(false);
}
}
Ok(true)
}
/// Handle trigger staging request (idempotent).
async fn handle_trigger_staging(&mut self) -> Result<()> {
if self.staging_started {
return Ok(());
}
if self.g3_blocks.is_empty() {
// No-op for G2-only mode
debug!(
session_id = %self.endpoint.session_id(),
"TriggerStaging ignored (no G3 blocks)"
);
return Ok(());
}
if self.parallel_worker.is_none() {
if self.endpoint.is_attached() {
let error_msg = SessionMessage::Error {
session_id: self.endpoint.session_id(),
message: "No parallel worker available for G3->G2 staging".to_string(),
};
self.endpoint.send(error_msg).await?;
}
return Ok(());
}
self.endpoint.set_phase(SessionPhase::Staging);
self.staging_started = true;
let staged_info = self.execute_staging().await?;
self.update_phase();
// Notify peer of newly staged blocks (if attached)
if self.endpoint.is_attached() {
let msg = SessionMessage::BlocksStaged {
session_id: self.endpoint.session_id(),
staged_blocks: staged_info,
remaining: self.g3_blocks.count(),
layer_range: None,
};
self.endpoint.send(msg).await?;
}
Ok(())
}
/// Execute G3→G2 staging.
///
/// Returns BlockInfo for newly staged blocks.
async fn execute_staging(&mut self) -> Result<Vec<BlockInfo>> {
let parallel_worker = self
.parallel_worker
.as_ref()
.ok_or_else(|| anyhow::anyhow!("ParallelWorkers required for G3→G2 staging"))?;
let g2_manager = self
.g2_manager
.as_ref()
.ok_or_else(|| anyhow::anyhow!("G2 manager required for staging"))?;
if self.g3_blocks.is_empty() {
self.staging_complete = true;
return Ok(Vec::new());
}
let result =
staging::stage_g3_to_g2(&self.g3_blocks, g2_manager, &**parallel_worker).await?;
// Build BlockInfo for newly staged blocks
let starting_index = self.g2_blocks.count();
let staged_info: Vec<BlockInfo> = result
.new_g2_blocks
.iter()
.enumerate()
.map(|(i, b)| BlockInfo {
block_id: b.block_id(),
sequence_hash: b.sequence_hash(),
layout_handle: self.block_metadata.assign_handle(starting_index + i),
})
.collect();
// Clear G3, extend G2
let _ = self.g3_blocks.take_all();
self.g2_blocks.extend(result.new_g2_blocks);
self.staging_complete = true;
Ok(staged_info)
}
/// Update phase based on current state.
fn update_phase(&mut self) {
if self.endpoint.phase() == SessionPhase::Complete
|| self.endpoint.phase() == SessionPhase::Failed
{
return;
}
if self.g3_blocks.is_empty() && (self.staging_complete || !self.staging_started) {
self.endpoint.set_phase(SessionPhase::Ready);
} else if self.staging_started && !self.staging_complete {
self.endpoint.set_phase(SessionPhase::Staging);
}
}
/// Send a StateResponse to the attached peer.
async fn send_state_response(&self, layer_range: Option<Range<usize>>) -> Result<()> {
let state = self.build_state_snapshot(layer_range);
let msg = SessionMessage::StateResponse {
session_id: self.endpoint.session_id(),
state,
};
self.endpoint.send(msg).await
}
/// Send a BlocksStaged message with optional layer range.
async fn send_blocks_staged(&self, layer_range: Option<Range<usize>>) -> Result<()> {
let blocks = self.block_metadata.build_block_infos(&self.g2_blocks);
let msg = SessionMessage::BlocksStaged {
session_id: self.endpoint.session_id(),
staged_blocks: blocks,
remaining: 0,
layer_range,
};
self.endpoint.send(msg).await
}
/// Build a state snapshot.
fn build_state_snapshot(&self, layer_range: Option<Range<usize>>) -> SessionStateSnapshot {
SessionStateSnapshot {
phase: self.endpoint.phase(),
control_role: self.endpoint.control_role(),
g2_blocks: self.block_metadata.build_block_infos(&self.g2_blocks),
g3_pending: self.g3_blocks.count(),
ready_layer_range: layer_range,
}
}
/// Get the session ID.
pub fn session_id(&self) -> SessionId {
self.endpoint.session_id()
}
}
impl ServerSessionHandle {
/// Create a new server session handle.
pub fn new(
session_id: SessionId,
local_instance: InstanceId,
cmd_tx: mpsc::Sender<ServerSessionCommand>,
) -> Self {
Self {
session_id,
local_instance,
cmd_tx,
}
}
/// Get the session ID.
pub fn session_id(&self) -> SessionId {
self.session_id
}
/// Get the local instance ID.
pub fn local_instance(&self) -> InstanceId {
self.local_instance
}
/// Notify attached controller that layers are ready.
pub async fn notify_layers_ready(&self, layer_range: Range<usize>) -> Result<()> {
self.cmd_tx
.send(ServerSessionCommand::NotifyLayersReady { layer_range })
.await
.map_err(|_| anyhow::anyhow!("Session command channel closed"))
}
/// Close the session gracefully.
pub async fn close(&self) -> Result<()> {
self.cmd_tx
.send(ServerSessionCommand::Close)
.await
.map_err(|_| anyhow::anyhow!("Session command channel closed"))
}
}
/// Create a ServerSession in G2-only mode with its handle.
///
/// This is the replacement for `create_endpoint_session`.
pub fn create_server_session(
session_id: SessionId,
instance_id: InstanceId,
blocks: BlockHolder<G2>,
layout_handles: Vec<LayoutHandle>,
sequence_hashes: Vec<SequenceHash>,
transport: Arc<MessageTransport>,
msg_rx: mpsc::Receiver<SessionMessage>,
) -> (ServerSession, ServerSessionHandle) {
let (cmd_tx, cmd_rx) = mpsc::channel(16);
let block_metadata: HashMap<SequenceHash, LayoutHandle> =
sequence_hashes.into_iter().zip(layout_handles).collect();
let endpoint = SessionEndpoint::new(session_id, instance_id, transport, msg_rx);
let session = ServerSession::new_g2_only(endpoint, blocks, block_metadata, cmd_rx);
let handle = ServerSessionHandle::new(session_id, instance_id, cmd_tx);
(session, handle)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::leader::session::SessionMessageTx;
use dashmap::DashMap;
use tokio::sync::mpsc;
fn create_test_transport() -> Arc<MessageTransport> {
Arc::new(MessageTransport::local(
Arc::new(DashMap::new()),
Arc::new(DashMap::new()),
))
}
#[tokio::test]
async fn test_handle_creation() {
let (cmd_tx, _cmd_rx) = mpsc::channel(16);
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let handle = ServerSessionHandle::new(session_id, instance_id, cmd_tx);
assert_eq!(handle.session_id(), session_id);
assert_eq!(handle.local_instance(), instance_id);
}
#[tokio::test]
async fn test_notify_layers_ready() {
let (cmd_tx, mut cmd_rx) = mpsc::channel(16);
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let handle = ServerSessionHandle::new(session_id, instance_id, cmd_tx);
handle.notify_layers_ready(0..1).await.unwrap();
let cmd = cmd_rx.recv().await.unwrap();
match cmd {
ServerSessionCommand::NotifyLayersReady { layer_range } => {
assert_eq!(layer_range, 0..1);
}
_ => panic!("Unexpected command"),
}
}
#[tokio::test]
async fn test_handle_close() {
let (cmd_tx, mut cmd_rx) = mpsc::channel(16);
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let handle = ServerSessionHandle::new(session_id, instance_id, cmd_tx);
handle.close().await.unwrap();
let cmd = cmd_rx.recv().await.unwrap();
assert!(matches!(cmd, ServerSessionCommand::Close));
}
#[tokio::test]
async fn test_create_server_session() {
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let transport = create_test_transport();
let (_msg_tx, msg_rx) = mpsc::channel(16);
let blocks = BlockHolder::empty();
let (_session, handle) = create_server_session(
session_id,
instance_id,
blocks,
vec![],
vec![],
transport,
msg_rx,
);
assert_eq!(handle.session_id(), session_id);
assert_eq!(handle.local_instance(), instance_id);
}
#[tokio::test]
async fn test_attach_sends_state_response() {
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let peer_id = InstanceId::new_v4();
// Create transport with a session channel to capture responses
let session_sessions: Arc<DashMap<SessionId, SessionMessageTx>> = Arc::new(DashMap::new());
let transport = Arc::new(MessageTransport::local(
Arc::new(DashMap::new()),
session_sessions.clone(),
));
// Register a receiver for the peer's session (where StateResponse is sent)
let peer_session_id = SessionId::new_v4(); // peer's session
let (peer_tx, mut peer_rx) = mpsc::channel::<SessionMessage>(16);
session_sessions.insert(session_id, peer_tx);
let (msg_tx, msg_rx) = mpsc::channel(16);
let (_cmd_tx, cmd_rx) = mpsc::channel(16);
let endpoint = SessionEndpoint::new(session_id, instance_id, transport, msg_rx);
let session =
ServerSession::new_g2_only(endpoint, BlockHolder::empty(), HashMap::new(), cmd_rx);
// Spawn session
let session_task = tokio::spawn(session.run());
// Send attach message
msg_tx
.send(SessionMessage::Attach {
peer: peer_id,
session_id,
as_role: ControlRole::Controller,
})
.await
.unwrap();
// Read the StateResponse
let response = tokio::time::timeout(std::time::Duration::from_secs(1), peer_rx.recv())
.await
.expect("timeout")
.expect("channel closed");
match response {
SessionMessage::StateResponse { state, .. } => {
assert_eq!(state.phase, SessionPhase::Ready);
assert_eq!(state.control_role, ControlRole::Controllee);
}
other => panic!("Expected StateResponse, got {:?}", other),
}
// Close session
msg_tx
.send(SessionMessage::Close { session_id })
.await
.unwrap();
let _ = tokio::time::timeout(std::time::Duration::from_secs(1), session_task).await;
let _ = peer_session_id;
}
#[tokio::test]
async fn test_g2_only_ready_on_attach() {
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let peer_id = InstanceId::new_v4();
let session_sessions: Arc<DashMap<SessionId, SessionMessageTx>> = Arc::new(DashMap::new());
let transport = Arc::new(MessageTransport::local(
Arc::new(DashMap::new()),
session_sessions.clone(),
));
let (peer_tx, mut peer_rx) = mpsc::channel::<SessionMessage>(16);
session_sessions.insert(session_id, peer_tx);
let (msg_tx, msg_rx) = mpsc::channel(16);
let (_cmd_tx, cmd_rx) = mpsc::channel(16);
let endpoint = SessionEndpoint::new(session_id, instance_id, transport, msg_rx);
// G2-only mode, no G3 blocks
let session =
ServerSession::new_g2_only(endpoint, BlockHolder::empty(), HashMap::new(), cmd_rx);
let session_task = tokio::spawn(session.run());
msg_tx
.send(SessionMessage::Attach {
peer: peer_id,
session_id,
as_role: ControlRole::Controller,
})
.await
.unwrap();
let response = tokio::time::timeout(std::time::Duration::from_secs(1), peer_rx.recv())
.await
.expect("timeout")
.expect("channel closed");
// G2-only with no blocks → Ready phase immediately
match response {
SessionMessage::StateResponse { state, .. } => {
assert_eq!(state.phase, SessionPhase::Ready);
assert_eq!(state.g3_pending, 0);
}
other => panic!("Expected StateResponse, got {:?}", other),
}
msg_tx
.send(SessionMessage::Close { session_id })
.await
.unwrap();
let _ = tokio::time::timeout(std::time::Duration::from_secs(1), session_task).await;
}
#[tokio::test]
async fn test_trigger_staging_no_g3_noop() {
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let peer_id = InstanceId::new_v4();
let session_sessions: Arc<DashMap<SessionId, SessionMessageTx>> = Arc::new(DashMap::new());
let transport = Arc::new(MessageTransport::local(
Arc::new(DashMap::new()),
session_sessions.clone(),
));
let (peer_tx, mut peer_rx) = mpsc::channel::<SessionMessage>(16);
session_sessions.insert(session_id, peer_tx);
let (msg_tx, msg_rx) = mpsc::channel(16);
let (_cmd_tx, cmd_rx) = mpsc::channel(16);
let endpoint = SessionEndpoint::new(session_id, instance_id, transport, msg_rx);
let session =
ServerSession::new_g2_only(endpoint, BlockHolder::empty(), HashMap::new(), cmd_rx);
let session_task = tokio::spawn(session.run());
// Attach
msg_tx
.send(SessionMessage::Attach {
peer: peer_id,
session_id,
as_role: ControlRole::Controller,
})
.await
.unwrap();
// Consume StateResponse
let _ = tokio::time::timeout(std::time::Duration::from_secs(1), peer_rx.recv())
.await
.expect("timeout");
// Send TriggerStaging - should be no-op (no G3 blocks)
msg_tx
.send(SessionMessage::TriggerStaging { session_id })
.await
.unwrap();
// Close and check no extra messages were sent
msg_tx
.send(SessionMessage::Close { session_id })
.await
.unwrap();
let _ = tokio::time::timeout(std::time::Duration::from_secs(1), session_task).await;
}
#[tokio::test]
async fn test_detach_completes_session() {
let session_id = SessionId::new_v4();
let instance_id = InstanceId::new_v4();
let peer_id = InstanceId::new_v4();
let session_sessions: Arc<DashMap<SessionId, SessionMessageTx>> = Arc::new(DashMap::new());
let transport = Arc::new(MessageTransport::local(
Arc::new(DashMap::new()),
session_sessions.clone(),
));
let (peer_tx, mut _peer_rx) = mpsc::channel::<SessionMessage>(16);
session_sessions.insert(session_id, peer_tx);
let (msg_tx, msg_rx) = mpsc::channel(16);
let (_cmd_tx, cmd_rx) = mpsc::channel(16);
let endpoint = SessionEndpoint::new(session_id, instance_id, transport, msg_rx);
let session =
ServerSession::new_g2_only(endpoint, BlockHolder::empty(), HashMap::new(), cmd_rx);
let session_task = tokio::spawn(session.run());
// Attach then detach
msg_tx
.send(SessionMessage::Attach {
peer: peer_id,
session_id,
as_role: ControlRole::Controller,
})
.await
.unwrap();
msg_tx
.send(SessionMessage::Detach {
peer: peer_id,
session_id,
})
.await
.unwrap();
// Session should complete
let result = tokio::time::timeout(std::time::Duration::from_secs(1), session_task)
.await
.expect("timeout")
.expect("task panicked");
assert!(result.is_ok());
}
#[test]
fn test_block_metadata_direct_build_infos() {
let hash1 = SequenceHash::new(1, None, 100);
let hash2 = SequenceHash::new(2, None, 200);
let mut map = HashMap::new();
map.insert(hash1, LayoutHandle::new(0, 1));
map.insert(hash2, LayoutHandle::new(0, 2));
let metadata = BlockMetadataMap::Direct(map);
// Empty holder
let holder = BlockHolder::<G2>::empty();
let infos = metadata.build_block_infos(&holder);
assert!(infos.is_empty());
}
#[test]
fn test_block_metadata_round_robin_empty_handles() {
let metadata = BlockMetadataMap::RoundRobin(vec![]);
let holder = BlockHolder::<G2>::empty();
let infos = metadata.build_block_infos(&holder);
assert!(infos.is_empty());
}
#[test]
fn test_block_metadata_assign_handle() {
let h0 = LayoutHandle::new(0, 10);
let h1 = LayoutHandle::new(1, 20);
let metadata = BlockMetadataMap::RoundRobin(vec![h0, h1]);
assert_eq!(metadata.assign_handle(0), h0);
assert_eq!(metadata.assign_handle(1), h1);
assert_eq!(metadata.assign_handle(2), h0); // wraps around
}
#[test]
fn test_block_metadata_remove_all() {
let hash1 = SequenceHash::new(1, None, 100);
let hash2 = SequenceHash::new(2, None, 200);
let mut map = HashMap::new();
map.insert(hash1, LayoutHandle::new(0, 1));
map.insert(hash2, LayoutHandle::new(0, 2));
let mut metadata = BlockMetadataMap::Direct(map);
metadata.remove_all(&[hash1]);
// Verify hash1 was removed
if let BlockMetadataMap::Direct(ref inner) = metadata {
assert!(!inner.contains_key(&hash1));
assert!(inner.contains_key(&hash2));
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Shared G3→G2 staging logic.
//!
//! Extracts the common staging kernel used by InitiatorSession, ResponderSession,
//! and ServerSession. Each caller handles its own post-staging bookkeeping
//! (updating holders, sending messages, etc.).
use std::sync::Arc;
use anyhow::Result;
use crate::{BlockId, G2, G3, worker::group::ParallelWorkers};
use kvbm_common::LogicalLayoutHandle;
use kvbm_logical::{blocks::ImmutableBlock, manager::BlockManager};
use kvbm_physical::transfer::TransferOptions;
use super::blocks::BlockHolder;
/// Result of staging G3 blocks to G2.
pub struct StagingResult {
/// Newly created G2 blocks (registered with the G2 manager).
pub new_g2_blocks: Vec<ImmutableBlock<G2>>,
}
/// Stage G3 blocks to G2.
///
/// Core staging kernel: allocate G2 destinations → execute local transfer (G3→G2)
/// → register new G2 blocks with the source sequence hashes → return new blocks.
///
/// The caller is responsible for:
/// - Clearing the G3 holder (`take_all()`)
/// - Adding new blocks to the G2 holder (`extend()`)
/// - Sending any notifications to peers
pub async fn stage_g3_to_g2(
g3_blocks: &BlockHolder<G3>,
g2_manager: &BlockManager<G2>,
parallel_worker: &dyn ParallelWorkers,
) -> Result<StagingResult> {
if g3_blocks.is_empty() {
return Ok(StagingResult {
new_g2_blocks: Vec::new(),
});
}
let src_ids: Vec<BlockId> = g3_blocks.blocks().iter().map(|b| b.block_id()).collect();
// Allocate destination G2 blocks
let dst_blocks = g2_manager
.allocate_blocks(src_ids.len())
.ok_or_else(|| anyhow::anyhow!("Failed to allocate G2 blocks"))?;
let dst_ids: Vec<BlockId> = dst_blocks.iter().map(|b| b.block_id()).collect();
// Execute transfer
let notification = parallel_worker.execute_local_transfer(
LogicalLayoutHandle::G3,
LogicalLayoutHandle::G2,
Arc::from(src_ids),
Arc::from(dst_ids),
TransferOptions::default(),
)?;
// Wait for transfer to complete
notification.await?;
// Register new G2 blocks using the G3 blocks' metadata (sequence hashes)
let new_g2_blocks: Vec<ImmutableBlock<G2>> = dst_blocks
.into_iter()
.zip(g3_blocks.blocks().iter())
.map(|(dst, src)| {
let complete = dst
.stage(src.sequence_hash(), g2_manager.block_size())
.expect("block size mismatch");
g2_manager.register_block(complete)
})
.collect();
Ok(StagingResult { new_g2_blocks })
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Session state types for the unified session model.
//!
//! This module provides the core state machine types:
//! - [`ControlRole`]: Whether this session is controller, controllee, or neutral
//! - [`AttachmentState`]: Whether a peer is attached
//! - [`SessionPhase`]: The current operational phase of the session
use serde::{Deserialize, Serialize};
use crate::InstanceId;
/// Control role in a session relationship.
///
/// Sessions can dynamically transition between roles:
/// - Start as `Neutral` (independent, can initiate in either direction)
/// - Become `Controller` when issuing commands to a peer
/// - Become `Controllee` when executing commands from a peer
///
/// Control can be transferred bidirectionally via `YieldControl`/`AcquireControl`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum ControlRole {
/// Independent - can initiate control in either direction.
/// This is the initial state and the state after yielding control.
#[default]
Neutral,
/// Currently controlling a peer session (issues commands).
Controller,
/// Currently being controlled by a peer session (executes commands).
Controllee,
}
impl ControlRole {
/// Check if this role can issue control commands.
pub fn can_command(&self) -> bool {
matches!(self, ControlRole::Controller)
}
/// Check if this role should respond to control commands.
pub fn responds_to_commands(&self) -> bool {
matches!(self, ControlRole::Controllee)
}
/// Check if this role is neutral (can transition either way).
pub fn is_neutral(&self) -> bool {
matches!(self, ControlRole::Neutral)
}
/// Get the opposite role.
///
/// - `Controller` ↔ `Controllee`
/// - `Neutral` → `Neutral` (no opposite)
pub fn opposite(&self) -> ControlRole {
match self {
ControlRole::Controller => ControlRole::Controllee,
ControlRole::Controllee => ControlRole::Controller,
ControlRole::Neutral => ControlRole::Neutral,
}
}
}
/// Attachment state - whether a peer is connected.
///
/// Valid state combinations:
/// - `Neutral + Unattached`: Initial state, waiting for connection
/// - `Neutral + Attached`: Post-yield state, peer still connected
/// - `Controllee + Unattached`: Waiting for controller to attach
/// - `Controllee + Attached`: Being actively controlled
/// - `Controller + Attached`: Actively controlling (Controller + Unattached is invalid)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AttachmentState {
/// No peer attached.
#[default]
Unattached,
/// Peer attached with the given instance ID.
Attached { peer: InstanceId },
}
impl AttachmentState {
/// Check if a peer is attached.
pub fn is_attached(&self) -> bool {
matches!(self, AttachmentState::Attached { .. })
}
/// Get the attached peer's instance ID if attached.
pub fn peer(&self) -> Option<InstanceId> {
match self {
AttachmentState::Attached { peer } => Some(*peer),
AttachmentState::Unattached => None,
}
}
}
/// Operational phase of a session.
///
/// Represents the lifecycle of block operations within a session:
/// 1. `Searching` - Initial discovery/search phase
/// 2. `Holding` - Blocks found and held, no staging yet
/// 3. `Staging` - Transfer in progress (G3→G2, G4→G2, etc.)
/// 4. `Ready` - All blocks in target tier, ready for transfer
/// 5. `Complete` - Session completed successfully
/// 6. `Failed` - Session failed or cancelled
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum SessionPhase {
/// Initial search/discovery phase.
#[default]
Searching,
/// Blocks found and held, no staging started.
Holding,
/// Transfer/staging in progress (any direction).
Staging,
/// All blocks in target tier, ready for RDMA pull.
Ready,
/// Session completed successfully.
Complete,
/// Session failed or was cancelled.
Failed,
}
impl SessionPhase {
/// Check if the session is in a terminal state.
pub fn is_terminal(&self) -> bool {
matches!(self, SessionPhase::Complete | SessionPhase::Failed)
}
/// Check if the session is active (not terminal).
pub fn is_active(&self) -> bool {
!self.is_terminal()
}
/// Check if blocks are ready for transfer.
pub fn is_ready(&self) -> bool {
matches!(self, SessionPhase::Ready)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_control_role_transitions() {
let role = ControlRole::Neutral;
assert!(role.is_neutral());
assert!(!role.can_command());
assert!(!role.responds_to_commands());
let role = ControlRole::Controller;
assert!(!role.is_neutral());
assert!(role.can_command());
assert!(!role.responds_to_commands());
let role = ControlRole::Controllee;
assert!(!role.is_neutral());
assert!(!role.can_command());
assert!(role.responds_to_commands());
}
#[test]
fn test_attachment_state() {
let state = AttachmentState::Unattached;
assert!(!state.is_attached());
assert!(state.peer().is_none());
let peer_id = InstanceId::new_v4();
let state = AttachmentState::Attached { peer: peer_id };
assert!(state.is_attached());
assert_eq!(state.peer(), Some(peer_id));
}
#[test]
fn test_session_phase() {
assert!(!SessionPhase::Searching.is_terminal());
assert!(!SessionPhase::Holding.is_terminal());
assert!(!SessionPhase::Staging.is_terminal());
assert!(!SessionPhase::Ready.is_terminal());
assert!(SessionPhase::Complete.is_terminal());
assert!(SessionPhase::Failed.is_terminal());
assert!(SessionPhase::Ready.is_ready());
assert!(!SessionPhase::Staging.is_ready());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use ::velo::Messenger;
use anyhow::Result;
use bytes::Bytes;
use dashmap::DashMap;
use std::sync::Arc;
use crate::InstanceId;
use kvbm_physical::manager::SerializedLayout;
use super::{
OnboardSessionTx, SessionId, SessionMessageTx, dispatch_onboard_message,
messages::{OnboardMessage, SessionMessage},
};
/// Transport abstraction for sending onboarding messages without boxing futures.
///
/// This enum allows sessions to work with different transport mechanisms:
/// - Velo (distributed): Uses Velo active messages
/// - Local (testing): Direct channel dispatch
pub enum MessageTransport {
Velo(VeloTransport),
Local(LocalTransport),
}
impl MessageTransport {
pub fn velo(messenger: Arc<Messenger>) -> Self {
Self::Velo(VeloTransport::new(messenger))
}
pub fn local(
sessions: Arc<DashMap<SessionId, OnboardSessionTx>>,
session_sessions: Arc<DashMap<SessionId, SessionMessageTx>>,
) -> Self {
Self::Local(LocalTransport::new(sessions, session_sessions))
}
/// Send an OnboardMessage to a target instance.
pub async fn send(&self, target: InstanceId, message: OnboardMessage) -> Result<()> {
match self {
MessageTransport::Velo(transport) => transport.send(target, message).await,
MessageTransport::Local(transport) => transport.send(target, message).await,
}
}
/// Request worker metadata from a remote leader for RDMA transfers.
///
/// This makes a synchronous RPC call to the remote leader's export_metadata
/// handler and returns the `Vec<SerializedLayout>` from all remote workers.
pub async fn request_metadata(&self, target: InstanceId) -> Result<Vec<SerializedLayout>> {
match self {
MessageTransport::Velo(transport) => transport.request_metadata(target).await,
MessageTransport::Local(_) => {
anyhow::bail!("request_metadata not supported for local transport")
}
}
}
/// Send a SessionMessage to a target instance.
///
/// This is the unified session message protocol used for all session communication.
pub async fn send_session(&self, target: InstanceId, message: SessionMessage) -> Result<()> {
match self {
MessageTransport::Velo(transport) => transport.send_session(target, message).await,
MessageTransport::Local(transport) => transport.send_session(target, message).await,
}
}
}
/// Velo-based transport using active messages (fire-and-forget).
pub struct VeloTransport {
messenger: Arc<Messenger>,
}
impl VeloTransport {
pub fn new(messenger: Arc<Messenger>) -> Self {
Self { messenger }
}
pub async fn send(&self, target: InstanceId, message: OnboardMessage) -> Result<()> {
tracing::debug!(
msg = message.variant_name(),
target = %target,
"Sending message"
);
let bytes = Bytes::from(serde_json::to_vec(&message)?);
self.messenger
.am_send("kvbm.leader.onboard")?
.raw_payload(bytes)
.instance(target)
.send()
.await?;
tracing::debug!(target = %target, "Successfully sent");
Ok(())
}
/// Request worker metadata from a remote leader for RDMA transfers.
///
/// Makes a unary RPC call to get `Vec<SerializedLayout>` from
/// the remote leader's workers.
pub async fn request_metadata(&self, target: InstanceId) -> Result<Vec<SerializedLayout>> {
tracing::debug!(target = %target, "Requesting metadata from instance");
let response: Bytes = self
.messenger
.unary("kvbm.leader.export_metadata")?
.instance(target)
.send()
.await?;
// Deserialize the response
let metadata: Vec<SerializedLayout> = serde_json::from_slice(&response)?;
tracing::debug!(
count = metadata.len(),
target = %target,
"Received metadata entries"
);
Ok(metadata)
}
/// Send a SessionMessage to a target instance.
///
/// Uses the unified "kvbm.leader.session" handler.
pub async fn send_session(&self, target: InstanceId, message: SessionMessage) -> Result<()> {
tracing::debug!(
msg = message.variant_name(),
target = %target,
"Sending Session"
);
let bytes = Bytes::from(serde_json::to_vec(&message)?);
self.messenger
.am_send("kvbm.leader.session")?
.raw_payload(bytes)
.instance(target)
.send()
.await?;
tracing::debug!(target = %target, "Successfully sent session msg");
Ok(())
}
}
/// Local transport for testing or same-instance communication.
///
/// Directly dispatches messages to session channels without network overhead.
pub struct LocalTransport {
sessions: Arc<DashMap<SessionId, OnboardSessionTx>>,
/// Unified session message receivers.
session_sessions: Arc<DashMap<SessionId, SessionMessageTx>>,
}
impl LocalTransport {
pub fn new(
sessions: Arc<DashMap<SessionId, OnboardSessionTx>>,
session_sessions: Arc<DashMap<SessionId, SessionMessageTx>>,
) -> Self {
Self {
sessions,
session_sessions,
}
}
pub async fn send(&self, _target: InstanceId, message: OnboardMessage) -> Result<()> {
dispatch_onboard_message(&self.sessions, message).await
}
/// Send a SessionMessage (unified protocol).
///
/// Routes to session_sessions by session ID.
pub async fn send_session(&self, _target: InstanceId, message: SessionMessage) -> Result<()> {
let session_id = message.session_id();
let sender = self
.session_sessions
.get(&session_id)
.map(|entry| entry.value().clone());
if let Some(sender) = sender {
sender
.send(message)
.await
.map_err(|e| anyhow::anyhow!("failed to send to session {session_id}: {e}"))?;
return Ok(());
}
anyhow::bail!("no session registered for session {session_id}");
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! LeaderState - Coordination layer for managing workers.
//!
//! This module provides the leader's coordination state, including:
//! - Worker registration and rank mapping
//! - Remote leader tracking for cross-leader transfers
//! - Routing strategies for asymmetric TP configurations
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use ::velo::Messenger;
use anyhow::Result;
use crate::InstanceId;
use crate::worker::{CoordinatedWorker, Worker};
use kvbm_physical::manager::SerializedLayout;
/// Info about a remote leader and its workers.
#[derive(Debug)]
pub struct RemoteLeaderInfo {
/// Instance ID of the remote leader process
pub instance_id: InstanceId,
/// Number of workers under the remote leader
pub worker_count: usize,
/// Cached metadata from remote workers (rank-ordered)
pub worker_metadata: Vec<SerializedLayout>,
}
/// Leader coordination state - owns workers and routing logic.
///
/// LeaderState manages:
/// - Registration of workers during handshake phase
/// - Coordination with remote leaders for cross-leader transfers
/// - Routing strategies for asymmetric TP configurations
pub struct LeaderState {
/// This leader's instance ID
instance_id: InstanceId,
/// Nova runtime for RPC
messenger: Arc<Messenger>,
/// Workers under this leader (rank-ordered)
workers: Vec<CoordinatedWorker>,
/// Known remote leaders (by their instance ID)
remote_leaders: RwLock<HashMap<InstanceId, RemoteLeaderInfo>>,
}
impl LeaderState {
/// Create a new LeaderState.
///
/// # Arguments
/// * `instance_id` - This leader's unique identifier
/// * `nova` - Nova runtime for RPC communication
pub fn new(instance_id: InstanceId, messenger: Arc<Messenger>) -> Self {
Self {
instance_id,
messenger,
workers: Vec::new(),
remote_leaders: RwLock::new(HashMap::new()),
}
}
/// Get this leader's instance ID.
pub fn instance_id(&self) -> InstanceId {
self.instance_id
}
/// Get the Nova runtime.
pub fn nova(&self) -> &Arc<Messenger> {
&self.messenger
}
/// Register a worker during the handshake phase.
///
/// Workers should be registered in rank order (0, 1, 2, ...).
///
/// # Arguments
/// * `rank` - The worker's rank (0-indexed)
/// * `host_instance` - Instance ID of the process hosting this worker
/// * `worker` - The Worker implementation (DirectWorker or VeloWorkerClient)
pub fn register_worker(
&mut self,
rank: usize,
host_instance: InstanceId,
worker: Box<dyn Worker>,
) {
let coordinated = CoordinatedWorker::new(worker, rank, host_instance);
// Ensure rank-ordered insertion
if rank == self.workers.len() {
// Sequential append (expected path)
self.workers.push(coordinated);
} else if rank < self.workers.len() {
// Re-registration or out-of-order within existing range
self.workers[rank] = coordinated;
} else {
panic!(
"Gap in worker ranks: rank {} but only {} workers registered",
rank,
self.workers.len()
);
}
}
/// Number of workers under this leader.
pub fn worker_count(&self) -> usize {
self.workers.len()
}
/// Get a worker by rank.
pub fn worker(&self, rank: usize) -> Option<&CoordinatedWorker> {
self.workers.get(rank)
}
/// Get a mutable worker by rank.
pub fn worker_mut(&mut self, rank: usize) -> Option<&mut CoordinatedWorker> {
self.workers.get_mut(rank)
}
/// Iterate over all workers.
pub fn workers(&self) -> impl Iterator<Item = &CoordinatedWorker> {
self.workers.iter()
}
/// Connect to a remote leader and distribute its worker metadata to our workers.
///
/// This implements the routing strategy for cross-leader transfers:
/// - 1:1 mapping when TP sizes match
/// - Many-to-one when local TP > remote TP
/// - One-to-many when local TP < remote TP
///
/// # Arguments
/// * `remote_leader_id` - Instance ID of the remote leader
/// * `remote_metadata` - Metadata from each remote worker (rank-ordered)
pub async fn import_remote_leader(
&self,
remote_leader_id: InstanceId,
remote_metadata: Vec<SerializedLayout>,
) -> Result<()> {
let remote_count = remote_metadata.len();
let local_count = self.workers.len();
tracing::info!(
local_count,
remote_count,
%remote_leader_id,
"Importing remote leader metadata"
);
// Store remote leader info
{
let mut leaders = self.remote_leaders.write().unwrap();
leaders.insert(
remote_leader_id,
RemoteLeaderInfo {
instance_id: remote_leader_id,
worker_count: remote_count,
worker_metadata: remote_metadata.clone(),
},
);
}
// Distribute metadata based on routing strategy
for (local_rank, worker) in self.workers.iter().enumerate() {
let target_remote_ranks = route_local_to_remote(local_rank, local_count, remote_count);
for remote_rank in target_remote_ranks {
tracing::debug!(
local_rank,
remote_rank,
%remote_leader_id,
"Importing remote metadata for local worker"
);
worker
.import_remote_metadata(
remote_leader_id,
remote_rank,
remote_metadata[remote_rank].clone(),
)
.await?;
}
}
Ok(())
}
/// Export this leader's workers' metadata for another leader to import.
///
/// Returns metadata from each worker in rank order.
pub async fn export_worker_metadata(&self) -> Result<Vec<SerializedLayout>> {
let mut metadata = Vec::with_capacity(self.workers.len());
for worker in &self.workers {
let response = worker.inner().export_metadata()?;
metadata.push(response.await?);
}
Ok(metadata)
}
/// Check if we have imported metadata from a remote leader.
pub fn has_remote_leader(&self, remote_leader_id: InstanceId) -> bool {
self.remote_leaders
.read()
.unwrap()
.contains_key(&remote_leader_id)
}
/// Get info about a remote leader if known.
pub fn remote_leader_info(&self, remote_leader_id: InstanceId) -> Option<RemoteLeaderInfo> {
self.remote_leaders
.read()
.unwrap()
.get(&remote_leader_id)
.map(|info| RemoteLeaderInfo {
instance_id: info.instance_id,
worker_count: info.worker_count,
worker_metadata: info.worker_metadata.clone(),
})
}
}
/// Routing strategy: which local ranks receive from which remote ranks.
///
/// This function determines how metadata/transfers are routed when
/// the local and remote TP sizes differ.
///
/// # Examples
/// - TP=4 local, TP=4 remote: 1:1 mapping (rank 0→0, 1→1, 2→2, 3→3)
/// - TP=4 local, TP=2 remote: 0→0, 1→0, 2→1, 3→1 (many-to-one)
/// - TP=2 local, TP=4 remote: 0→\[0,1\], 1→\[2,3\] (one-to-many)
pub fn route_local_to_remote(
local_rank: usize,
local_count: usize,
remote_count: usize,
) -> Vec<usize> {
if local_count == remote_count {
// 1:1 mapping
vec![local_rank]
} else if local_count > remote_count {
// Many local → few remote: multiple locals share a remote
vec![local_rank % remote_count]
} else {
// Few local → many remote: each local gets multiple remotes
let remotes_per_local = remote_count / local_count;
let start = local_rank * remotes_per_local;
// Last local rank absorbs any remainder from non-divisible ratios
let end = if local_rank == local_count - 1 {
remote_count
} else {
start + remotes_per_local
};
(start..end).collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_route_1_to_1() {
// Same TP size
assert_eq!(route_local_to_remote(0, 4, 4), vec![0]);
assert_eq!(route_local_to_remote(1, 4, 4), vec![1]);
assert_eq!(route_local_to_remote(2, 4, 4), vec![2]);
assert_eq!(route_local_to_remote(3, 4, 4), vec![3]);
}
#[test]
fn test_route_many_to_one() {
// Local TP=4, Remote TP=2
assert_eq!(route_local_to_remote(0, 4, 2), vec![0]);
assert_eq!(route_local_to_remote(1, 4, 2), vec![1]);
assert_eq!(route_local_to_remote(2, 4, 2), vec![0]);
assert_eq!(route_local_to_remote(3, 4, 2), vec![1]);
}
#[test]
fn test_route_one_to_many() {
// Local TP=2, Remote TP=4
assert_eq!(route_local_to_remote(0, 2, 4), vec![0, 1]);
assert_eq!(route_local_to_remote(1, 2, 4), vec![2, 3]);
}
#[test]
fn test_route_4_to_8() {
// Local TP=4, Remote TP=8
assert_eq!(route_local_to_remote(0, 4, 8), vec![0, 1]);
assert_eq!(route_local_to_remote(1, 4, 8), vec![2, 3]);
assert_eq!(route_local_to_remote(2, 4, 8), vec![4, 5]);
assert_eq!(route_local_to_remote(3, 4, 8), vec![6, 7]);
}
#[test]
fn test_route_non_divisible_remainder() {
// Local TP=2, Remote TP=5: last local rank absorbs remainder
assert_eq!(route_local_to_remote(0, 2, 5), vec![0, 1]);
assert_eq!(route_local_to_remote(1, 2, 5), vec![2, 3, 4]);
// Local TP=3, Remote TP=7: last rank gets extras
assert_eq!(route_local_to_remote(0, 3, 7), vec![0, 1]);
assert_eq!(route_local_to_remote(1, 3, 7), vec![2, 3]);
assert_eq!(route_local_to_remote(2, 3, 7), vec![4, 5, 6]);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use futures::future::{BoxFuture, Either, Ready, ready};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, watch};
use std::sync::Arc;
use crate::G2;
use kvbm_logical::blocks::ImmutableBlock;
use super::onboarding::{OnboardingStatus, SessionHandle};
use super::session::SessionId;
/// Staging mode for matched blocks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum StagingMode {
/// Hold blocks in their current tiers (G2 and G3) without staging.
/// Session stays alive for future operations.
/// Blocks remain on their original instances (local or remote).
Hold,
/// Stage all G3→G2 on local and remote instances.
/// No RDMA pulls from remote instances.
/// Remote blocks stay in remote G2.
/// Session stays alive for future operations.
Prepare,
/// Full staging: G3→G2 everywhere, then RDMA pull remote G2→local G2.
/// Session completes after all blocks are in local G2.
#[default]
Full,
}
/// Options for find_matches operation.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct FindMatchesOptions {
/// Whether to search remote instances in addition to local search.
/// Default: false (local only)
pub search_remote: bool,
/// Staging mode controlling how blocks are staged and session lifecycle.
/// Default: StagingMode::Full
pub staging_mode: StagingMode,
}
/// Result of a find_matches operation.
///
/// This enum has two variants:
/// - `Ready`: Immediate result when no async work is needed (local search with Hold mode)
/// - `AsyncSession`: When staging or remote search is required
#[derive(Debug)]
pub enum FindMatchesResult {
/// Immediate result - blocks are held in place without staging.
///
/// Returned when `search_remote == false` AND `staging_mode == Hold`.
/// Blocks remain in their original tiers (G2 or G3) on the local instance.
Ready(ReadyResult),
/// Async session for staging and/or remote search.
///
/// Returned when:
/// - `search_remote == true` (remote searching enabled)
/// - OR `staging_mode` is `Prepare` or `Full` (local/remote staging)
AsyncSession(AsyncSessionResult),
}
/// Immediate result containing matched blocks held directly.
///
/// No session is created - blocks are owned directly by this struct (RAII).
/// Dropping this struct will release the block references.
#[derive(Debug)]
pub struct ReadyResult {
/// G2 blocks held directly via RAII
blocks: Vec<ImmutableBlock<G2>>,
}
impl ReadyResult {
/// Create a new ready result with G2 blocks.
pub fn new(blocks: Vec<ImmutableBlock<G2>>) -> Self {
Self { blocks }
}
/// Number of G2 blocks held.
pub fn g2_count(&self) -> usize {
self.blocks.len()
}
/// Take ownership of the G2 blocks.
///
/// After calling this, the ReadyResult will be empty.
pub fn take_g2_blocks(&mut self) -> Vec<ImmutableBlock<G2>> {
std::mem::take(&mut self.blocks)
}
/// Get a reference to the G2 blocks.
pub fn blocks(&self) -> &[ImmutableBlock<G2>] {
&self.blocks
}
}
/// Async session result for staging and/or remote search operations.
#[derive(Debug)]
pub struct AsyncSessionResult {
session_id: SessionId,
status_rx: watch::Receiver<OnboardingStatus>,
blocks: Arc<Mutex<Option<Vec<ImmutableBlock<G2>>>>>,
session_handle: Option<SessionHandle>,
}
impl AsyncSessionResult {
/// Create a new async session result.
pub fn new(
session_id: SessionId,
status_rx: watch::Receiver<OnboardingStatus>,
blocks: Arc<Mutex<Option<Vec<ImmutableBlock<G2>>>>>,
session_handle: Option<SessionHandle>,
) -> Self {
Self {
session_id,
status_rx,
blocks,
session_handle,
}
}
/// Get the session ID for this onboarding operation.
pub fn session_id(&self) -> SessionId {
self.session_id
}
/// Get the current status of the onboarding operation.
pub fn status(&self) -> OnboardingStatus {
self.status_rx.borrow().clone()
}
/// Get session handle for deferred operations (Hold/Prepare modes only).
///
/// Returns None for StagingMode::Full.
pub fn session_handle(&self) -> Option<&SessionHandle> {
self.session_handle.as_ref()
}
/// Non-blocking check if blocks are available.
///
/// Returns Some(count) if blocks are available, None if still in progress.
/// Use wait_for_completion() to take ownership of blocks.
pub fn get_blocks_count(&self) -> Option<usize> {
self.blocks.try_lock().ok()?.as_ref().map(|v| v.len())
}
/// Wait for the operation to complete and return the matched blocks.
///
/// For StagingMode::Full, waits for Complete status.
/// For Hold/Prepare modes, waits for terminal state (Holding/Prepared/Complete).
///
/// This method returns a future that can be used with tokio::select!.
pub fn wait_for_completion(&self) -> BoxFuture<'static, Result<()>> {
let mut status_rx = self.status_rx.clone();
Box::pin(async move {
// Wait for terminal status
status_rx
.wait_for(|status| {
matches!(
status,
OnboardingStatus::Complete { .. }
| OnboardingStatus::Holding { .. }
| OnboardingStatus::Prepared { .. }
)
})
.await
.map_err(|e| anyhow::anyhow!("failed to wait for completion: {e}"))?;
Ok(())
})
}
}
impl FindMatchesResult {
/// Check if this is a ready (immediate) result.
pub fn is_ready(&self) -> bool {
matches!(self, FindMatchesResult::Ready(_))
}
/// Check if this is an async session result.
pub fn is_async(&self) -> bool {
matches!(self, FindMatchesResult::AsyncSession(_))
}
/// Get the ready result, if this is a Ready variant.
pub fn as_ready(&self) -> Option<&ReadyResult> {
match self {
FindMatchesResult::Ready(r) => Some(r),
FindMatchesResult::AsyncSession(_) => None,
}
}
/// Get the ready result mutably, if this is a Ready variant.
pub fn as_ready_mut(&mut self) -> Option<&mut ReadyResult> {
match self {
FindMatchesResult::Ready(r) => Some(r),
FindMatchesResult::AsyncSession(_) => None,
}
}
/// Get the async session result, if this is an AsyncSession variant.
pub fn as_async(&self) -> Option<&AsyncSessionResult> {
match self {
FindMatchesResult::Ready(_) => None,
FindMatchesResult::AsyncSession(a) => Some(a),
}
}
/// Get the async session result mutably, if this is an AsyncSession variant.
pub fn as_async_mut(&mut self) -> Option<&mut AsyncSessionResult> {
match self {
FindMatchesResult::Ready(_) => None,
FindMatchesResult::AsyncSession(a) => Some(a),
}
}
/// Get the number of G2 blocks available or matched.
///
/// For Ready: returns the count of blocks held.
/// For AsyncSession: returns the count if blocks are available, 0 otherwise.
pub fn g2_count(&self) -> usize {
match self {
FindMatchesResult::Ready(r) => r.g2_count(),
FindMatchesResult::AsyncSession(a) => a.get_blocks_count().unwrap_or(0),
}
}
/// Take ownership of G2 blocks if available.
///
/// For Ready: always succeeds, returns the blocks.
/// For AsyncSession: returns Some if blocks are available and lock succeeds.
pub fn take_g2_blocks(&mut self) -> Option<Vec<ImmutableBlock<G2>>> {
match self {
FindMatchesResult::Ready(r) => Some(r.take_g2_blocks()),
FindMatchesResult::AsyncSession(a) => a.blocks.try_lock().ok()?.take(),
}
}
pub fn session_id(&self) -> Option<SessionId> {
match self {
FindMatchesResult::Ready(_) => None,
FindMatchesResult::AsyncSession(a) => Some(a.session_id()),
}
}
/// Wait for the operation to complete.
///
/// For Ready variant: returns immediately with Ok(()).
/// For AsyncSession variant: waits for terminal status (Complete/Holding/Prepared).
///
/// Returns an Either future that can be used with tokio::select!.
pub fn wait_for_completion(&self) -> Either<Ready<Result<()>>, BoxFuture<'static, Result<()>>> {
match self {
FindMatchesResult::Ready(_) => Either::Left(ready(Ok(()))),
FindMatchesResult::AsyncSession(async_session) => {
Either::Right(async_session.wait_for_completion())
}
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//! Core event storage engine backed by a generational slot system. mod service;
pub(crate) mod system; pub use service::{ExportMetadataCallback, VeloLeaderService};
pub use system::EventSystemBase;
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use ::velo::{Handler, Messenger};
use anyhow::Result;
use bytes::Bytes;
use dashmap::DashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::leader::session::{
OnboardMessage, OnboardSessionTx, SessionId, SessionMessage, SessionMessageTx,
dispatch_onboard_message, dispatch_session_message,
};
use kvbm_physical::manager::SerializedLayout;
/// Type alias for async export metadata callback.
/// Returns a boxed future that resolves to `Vec<SerializedLayout>`.
pub type ExportMetadataCallback = Arc<
dyn Fn() -> Pin<Box<dyn Future<Output = Result<Vec<SerializedLayout>>> + Send>> + Send + Sync,
>;
/// Velo leader service for handling distributed onboarding messages.
///
/// This service registers handlers for:
/// 1. OnboardMessage: Standard find_matches flow (initiator → responder)
/// 2. SessionMessage: Unified session protocol
/// 3. Export metadata RPC: Returns worker layout metadata for RDMA
pub struct VeloLeaderService {
messenger: Arc<Messenger>,
sessions: Arc<DashMap<SessionId, OnboardSessionTx>>,
/// Callback to spawn new responder sessions.
/// Takes the CreateSession message and creates a new responder task.
spawn_responder: Option<Arc<dyn Fn(OnboardMessage) -> Result<()> + Send + Sync>>,
// Unified session protocol
/// Map of unified session receivers.
session_sessions: Option<Arc<DashMap<SessionId, SessionMessageTx>>>,
// RDMA metadata export
/// Callback to export worker metadata for RDMA transfers.
export_metadata: Option<ExportMetadataCallback>,
}
impl VeloLeaderService {
pub fn new(
messenger: Arc<Messenger>,
sessions: Arc<DashMap<SessionId, OnboardSessionTx>>,
) -> Self {
Self {
messenger,
sessions,
spawn_responder: None,
session_sessions: None,
export_metadata: None,
}
}
/// Set the callback for spawning responder sessions.
pub fn with_spawn_responder<F>(mut self, f: F) -> Self
where
F: Fn(OnboardMessage) -> Result<()> + Send + Sync + 'static,
{
self.spawn_responder = Some(Arc::new(f));
self
}
/// Set the unified session sessions map.
pub fn with_session_sessions(
mut self,
sessions: Arc<DashMap<SessionId, SessionMessageTx>>,
) -> Self {
self.session_sessions = Some(sessions);
self
}
/// Set the callback for exporting worker metadata (RDMA).
///
/// This callback is invoked when a remote leader requests metadata
/// to enable RDMA transfers. The callback should return `Vec<SerializedLayout>`
/// containing metadata from all workers.
pub fn with_export_metadata(mut self, callback: ExportMetadataCallback) -> Self {
self.export_metadata = Some(callback);
self
}
/// Register all Velo handlers for leader-to-leader communication.
pub fn register_handlers(self) -> Result<()> {
self.register_onboard_handler()?;
// Register session handler if unified protocol is configured
if self.session_sessions.is_some() {
self.register_session_handler()?;
}
// Register export_metadata handler if callback is configured
if self.export_metadata.is_some() {
self.register_export_metadata_handler()?;
}
Ok(())
}
/// Register the "kvbm.leader.onboard" handler.
///
/// This handler is intentionally simple and fast:
/// - Deserializes the message
/// - If CreateSession and session doesn't exist, spawns responder
/// - Dispatches to session channel
/// - Returns immediately (< 1ms)
fn register_onboard_handler(&self) -> Result<()> {
let sessions = self.sessions.clone();
let spawn_responder = self.spawn_responder.clone();
let handler = Handler::am_handler_async("kvbm.leader.onboard", move |ctx| {
let sessions = sessions.clone();
let spawn_responder = spawn_responder.clone();
async move {
// Fast path: just deserialize and dispatch
let message: OnboardMessage = serde_json::from_slice(&ctx.payload)
.map_err(|e| anyhow::anyhow!("failed to deserialize OnboardMessage: {e}"))?;
let session_id = message.session_id();
tracing::debug!(
variant = message.variant_name(),
%session_id,
"Received onboard message"
);
// If this is a CreateSession and no session exists, spawn responder
if matches!(message, OnboardMessage::CreateSession { .. })
&& !sessions.contains_key(&session_id)
{
tracing::debug!(%session_id, "Spawning new ResponderSession");
if let Some(ref spawner) = spawn_responder {
spawner(message.clone()).ok(); // Best-effort spawn
}
}
// Dispatch to session channel (will create if needed by spawner above)
tracing::debug!(%session_id, "Dispatching message to session");
dispatch_onboard_message(&sessions, message).await?;
Ok(())
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
/// Register the "kvbm.leader.session" handler.
///
/// This handler supports the unified session protocol.
/// Routes SessionMessages to the appropriate session endpoint.
fn register_session_handler(&self) -> Result<()> {
let session_sessions = self
.session_sessions
.clone()
.expect("session_sessions required for handler registration");
let handler = Handler::am_handler_async("kvbm.leader.session", move |ctx| {
let session_sessions = session_sessions.clone();
async move {
let message: SessionMessage = serde_json::from_slice(&ctx.payload)
.map_err(|e| anyhow::anyhow!("failed to deserialize SessionMessage: {e}"))?;
let session_id = message.session_id();
tracing::debug!(
variant = message.variant_name(),
%session_id,
"Received session message"
);
// Dispatch to session endpoint
dispatch_session_message(&session_sessions, message).await?;
Ok(())
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
/// Register the "kvbm.leader.export_metadata" handler.
///
/// This handler returns `Vec<SerializedLayout>` containing metadata from all workers.
/// Used by remote leaders to enable RDMA transfers.
fn register_export_metadata_handler(&self) -> Result<()> {
let export_metadata = self
.export_metadata
.clone()
.expect("export_metadata callback required for handler registration");
let handler = Handler::unary_handler_async("kvbm.leader.export_metadata", move |_ctx| {
let export_metadata = export_metadata.clone();
async move {
tracing::debug!("Received export_metadata request");
// Call the async callback to get metadata from all workers
let metadata_vec = export_metadata().await?;
// Serialize the Vec<SerializedLayout> for transport
let serialized = serde_json::to_vec(&metadata_vec)?;
tracing::debug!(
count = metadata_vec.len(),
"Returning worker metadata entries"
);
Ok(Some(Bytes::from(serialized)))
}
})
.build();
self.messenger.register_handler(handler)?;
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#![doc = include_str!("../docs/architecture.md")]
pub use kvbm_common::{BlockId, LogicalLayoutHandle, SequenceHash};
pub use velo::{InstanceId, PeerInfo, WorkerAddress};
/// GPU/device tier -- HBM KV cache. Fastest access, smallest capacity.
/// Blocks here are actively used by attention kernels.
#[derive(Clone, Copy, Debug)]
pub struct G1;
/// CPU/host tier -- pinned DRAM cache. Microsecond-latency staging area
/// for RDMA transfers and G3/G4 promotion.
#[derive(Clone, Copy, Debug)]
pub struct G2;
/// Disk tier -- NVMe/SSD cache. Millisecond-latency persistent storage
/// for warm blocks.
#[derive(Clone, Copy, Debug)]
pub struct G3;
/// Object store tier -- S3/MinIO. Highest latency but unlimited capacity
/// for cold/archival blocks.
#[derive(Clone, Copy, Debug)]
pub struct G4;
#[cfg(feature = "collectives")]
pub mod collectives;
#[doc = include_str!("../docs/leader.md")]
pub mod leader;
#[doc = include_str!("../docs/object.md")]
pub mod object;
#[doc = include_str!("../docs/offload.md")]
pub mod offload;
pub mod pubsub;
#[doc = include_str!("../docs/runtime.md")]
pub mod runtime;
#[doc = include_str!("../docs/worker.md")]
pub mod worker;
#[cfg(feature = "testing")]
pub mod testing;
pub use runtime::{KvbmRuntime, KvbmRuntimeBuilder, RuntimeHandle};
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Object storage module for distributed block management.
//!
//! This module provides traits and implementations for storing KV cache blocks
//! in object storage systems like S3/MinIO.
//!
//! # Architecture
//!
//! Traits are defined here; implementations are in feature-gated submodules:
//! - [`ObjectBlockOps`](crate::object::ObjectBlockOps) - High-level block operations (put, get, has)
//! - [`ObjectLockManager`](crate::object::ObjectLockManager) - Distributed locking for coordinated offloads
//!
//! Consumers should use factory functions to obtain trait objects without
//! depending on specific feature flags.
use std::sync::Arc;
use anyhow::Result;
use futures::future::BoxFuture;
use crate::{BlockId, SequenceHash};
use kvbm_common::LogicalLayoutHandle;
use kvbm_physical::layout::LayoutConfig;
use kvbm_physical::transfer::PhysicalLayout;
#[cfg(feature = "s3")]
pub mod s3;
// ============================================================================
// Key Formatting
// ============================================================================
/// Trait for converting SequenceHash to object storage keys.
///
/// Implementations can embed rank, namespace, or any other prefix/suffix
/// to ensure key uniqueness across SPMD workers or other contexts.
pub trait KeyFormatter: Send + Sync {
/// Convert a sequence hash to an object storage key string.
fn format_key(&self, hash: &SequenceHash) -> String;
}
/// Default key formatter - uses Display representation of PositionalLineageHash.
///
/// Produces keys like: `0:abc123` or `5:abc123:def456` (position:current\[:parent\])
/// using base58 encoding for hash fragments.
/// Suitable for single-worker scenarios or testing.
#[derive(Debug, Clone, Default)]
pub struct DefaultKeyFormatter;
impl KeyFormatter for DefaultKeyFormatter {
fn format_key(&self, hash: &SequenceHash) -> String {
hash.to_string()
}
}
/// Rank-prefixed key formatter for SPMD workers.
///
/// Formats keys as `{rank}/{display_hash}` to ensure uniqueness across workers
/// writing the same logical blocks. The hash uses the Display representation
/// (e.g., `0/5:abc123:def456`).
#[derive(Debug, Clone)]
pub struct RankPrefixedKeyFormatter {
rank: usize,
}
impl RankPrefixedKeyFormatter {
/// Create a new rank-prefixed formatter.
pub fn new(rank: usize) -> Self {
Self { rank }
}
/// Get the rank.
pub fn rank(&self) -> usize {
self.rank
}
}
impl KeyFormatter for RankPrefixedKeyFormatter {
fn format_key(&self, hash: &SequenceHash) -> String {
format!("{}/{}", self.rank, hash)
}
}
/// Create a key formatter appropriate for the given rank.
///
/// Returns a `RankPrefixedKeyFormatter` if rank is provided,
/// otherwise returns a `DefaultKeyFormatter`.
pub fn create_key_formatter(rank: Option<usize>) -> Arc<dyn KeyFormatter> {
match rank {
Some(r) => Arc::new(RankPrefixedKeyFormatter::new(r)),
None => Arc::new(DefaultKeyFormatter),
}
}
/// Extension methods for LayoutConfig to support object storage operations.
pub trait LayoutConfigExt {
/// Compute the size of a single block in bytes.
fn block_size_bytes(&self) -> usize;
/// Compute the size of a single memory region in bytes.
fn region_size(&self) -> usize;
}
impl LayoutConfigExt for LayoutConfig {
fn block_size_bytes(&self) -> usize {
self.num_layers
.saturating_mul(self.outer_dim)
.saturating_mul(self.page_size)
.saturating_mul(self.inner_dim)
.saturating_mul(self.dtype_width_bytes)
}
fn region_size(&self) -> usize {
self.page_size
.saturating_mul(self.inner_dim)
.saturating_mul(self.dtype_width_bytes)
}
}
/// Low-level object storage client trait.
pub trait ObjectClient: Send + Sync {
/// Check if an object exists.
fn has_object(&self, key: &[u8]) -> anyhow::Result<bool>;
/// Put an object.
fn put_object(&self, key: &[u8], data: &[&[u8]]) -> anyhow::Result<()>;
/// Get an object.
fn get_object(&self, key: &[u8], data: &mut [&mut [u8]]) -> anyhow::Result<()>;
}
/// Unified object block operations trait.
///
/// This trait provides high-level operations for storing and retrieving
/// KV cache blocks in object storage (e.g., S3, MinIO).
///
/// Uses `LogicalLayoutHandle` to identify source/destination layouts. In distributed
/// mode, workers resolve the logical handle to their own physical layouts. This allows
/// the leader (which doesn't have physical layouts) to use the same trait.
///
/// Uses `'static` BoxFuture for runtime flexibility - implementations clone/Arc
/// what they need from self. Takes owned Vecs for simplicity; keys are returned
/// in results so callers can correlate success/failure.
///
/// Implemented by:
/// - `S3ObjectBlockClient` - direct S3 operations (has_blocks only; put/get require physical layout)
/// - `DirectWorker` - resolves logical handle to physical layout, then delegates
/// - `CoordinatedWorker` - delegates to inner worker
/// - `LeaderObjectClient` - coordinates workers for distributed uploads
pub trait ObjectBlockOps: Send + Sync {
/// Check if blocks exist in object storage.
///
/// Returns a vector of (hash, size_option) pairs where:
/// - Some(size) indicates the block exists with the given size in bytes
/// - None indicates the block does not exist or an error occurred
fn has_blocks(
&self,
keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<(SequenceHash, Option<usize>)>>;
/// Put blocks to object storage.
///
/// # Arguments
/// * `keys` - Sequence hashes identifying each block
/// * `src_layout` - Logical layout handle identifying the source (workers resolve to physical)
/// * `block_ids` - Block IDs within the layout to upload
///
/// Returns a vector of results for each block:
/// - Ok(hash) indicates the block was successfully stored
/// - Err(hash) indicates the block failed to store
///
/// # Note
/// For `S3ObjectBlockClient`, this will error - use `put_blocks_with_layout` instead.
/// Workers should resolve the logical handle to their physical layout first.
fn put_blocks(
&self,
keys: Vec<SequenceHash>,
src_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>>;
/// Get blocks from object storage.
///
/// # Arguments
/// * `keys` - Sequence hashes identifying each block
/// * `dst_layout` - Logical layout handle identifying the destination (workers resolve to physical)
/// * `block_ids` - Block IDs within the layout to download into
///
/// Returns a vector of results for each block:
/// - Ok(hash) indicates the block was successfully retrieved
/// - Err(hash) indicates the block failed to retrieve
///
/// # Note
/// For `S3ObjectBlockClient`, this will error - use `get_blocks_with_layout` instead.
/// Workers should resolve the logical handle to their physical layout first.
fn get_blocks(
&self,
keys: Vec<SequenceHash>,
dst_layout: LogicalLayoutHandle,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>>;
// =========================================================================
// Physical Layout Methods (for workers that can resolve handles)
// =========================================================================
/// Put blocks to object storage using a resolved physical layout.
///
/// This method is called by workers after resolving a logical handle to
/// their physical layout. The default implementation errors; storage backends
/// like `S3ObjectBlockClient` override this with actual upload logic.
///
/// # Arguments
/// * `keys` - Sequence hashes identifying each block
/// * `layout` - Physical layout containing the block data
/// * `block_ids` - Block IDs within the layout to upload
fn put_blocks_with_layout(
&self,
keys: Vec<SequenceHash>,
_layout: PhysicalLayout,
_block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
Box::pin(async move { keys.into_iter().map(Err).collect() })
}
/// Get blocks from object storage into a resolved physical layout.
///
/// This method is called by workers after resolving a logical handle to
/// their physical layout. The default implementation errors; storage backends
/// like `S3ObjectBlockClient` override this with actual download logic.
///
/// # Arguments
/// * `keys` - Sequence hashes identifying each block
/// * `layout` - Physical layout to write the block data into
/// * `block_ids` - Block IDs within the layout to download into
fn get_blocks_with_layout(
&self,
keys: Vec<SequenceHash>,
_layout: PhysicalLayout,
_block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
Box::pin(async move { keys.into_iter().map(Err).collect() })
}
}
// ============================================================================
// Object Lock Manager Trait
// ============================================================================
/// Lock file content structure for distributed locking.
///
/// The lock file is stored as JSON in object storage at `{sequence_hash}.lock`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LockFileContent {
/// Unique identifier of the instance that holds the lock
pub instance_id: String,
/// When the lock was acquired (ISO 8601 timestamp)
pub acquired_at: String,
/// When the lock expires (ISO 8601 timestamp)
pub deadline: String,
}
/// Object lock manager trait for distributed locking in object storage.
///
/// This trait provides the locking semantics for the object offload pipeline:
/// 1. Check `.meta` file to see if block is already offloaded
/// 2. Try to acquire `.lock` file with conditional PUT
/// 3. Create `.meta` file after successful offload
/// 4. Release `.lock` file after completion
///
/// # Locking Flow
///
/// ```text
/// has_meta() -> false -> try_acquire_lock() -> true -> execute transfer -> create_meta() -> release_lock()
/// -> false -> skip (another instance owns it)
/// -> true -> skip (already offloaded)
/// ```
pub trait ObjectLockManager: Send + Sync {
/// Check if meta file exists (block already offloaded).
///
/// Returns `true` if `{hash}.meta` exists, meaning the block has been
/// successfully offloaded and should be skipped.
fn has_meta(&self, hash: SequenceHash) -> BoxFuture<'static, Result<bool>>;
/// Try to acquire a lock for the given block.
///
/// This method:
/// 1. Attempts conditional PUT of `{hash}.lock` with `If-None-Match: *`
/// 2. If lock exists, reads it to check deadline
/// 3. If deadline is breached (> timeout), overwrites the lock
///
/// Returns:
/// - `Ok(true)` if lock was acquired or overwritten
/// - `Ok(false)` if another instance owns a valid lock
/// - `Err(...)` for other errors
fn try_acquire_lock(&self, hash: SequenceHash) -> BoxFuture<'static, Result<bool>>;
/// Create the meta file after successful offload.
///
/// This marks the block as successfully offloaded by creating `{hash}.meta`.
fn create_meta(&self, hash: SequenceHash) -> BoxFuture<'static, Result<()>>;
/// Release the lock by deleting the lock file.
///
/// Deletes `{hash}.lock` after the transfer is complete.
fn release_lock(&self, hash: SequenceHash) -> BoxFuture<'static, Result<()>>;
}
// ============================================================================
// Factory Functions
// ============================================================================
/// Create an object client from configuration.
///
/// Returns a trait object so consumers don't need to depend on the `s3` feature.
/// The implementation is selected based on the configuration type.
///
/// # Arguments
/// * `config` - Object storage configuration
/// * `rank` - Optional worker rank for key prefixing (None for leader)
///
/// # Errors
/// Returns an error if the object client cannot be initialized or if the
/// required feature is not enabled.
#[cfg(feature = "s3")]
pub async fn create_object_client(
config: &kvbm_config::ObjectConfig,
rank: Option<usize>,
) -> Result<Arc<dyn ObjectBlockOps>> {
use kvbm_config::ObjectClientConfig;
use s3::{S3Config, S3ObjectBlockClient};
let key_formatter = create_key_formatter(rank);
match &config.client {
ObjectClientConfig::S3(s3_config) => {
let config = S3Config::from_object_config(s3_config);
let client = S3ObjectBlockClient::with_key_formatter(config, key_formatter).await?;
Ok(Arc::new(client))
}
ObjectClientConfig::Nixl(_nixl_config) => {
anyhow::bail!("Nixl object storage backend not yet implemented")
}
}
}
/// Fallback when S3 feature is disabled.
#[cfg(not(feature = "s3"))]
pub async fn create_object_client(
_config: &kvbm_config::ObjectConfig,
_rank: Option<usize>,
) -> Result<Arc<dyn ObjectBlockOps>> {
anyhow::bail!("Object storage requires the 's3' feature to be enabled")
}
/// Create a lock manager from configuration.
///
/// Returns a trait object so consumers don't need to depend on the `s3` feature.
///
/// # Arguments
/// * `config` - Object storage configuration
/// * `instance_id` - Unique identifier for this instance (used in lock files)
///
/// # Errors
/// Returns an error if the lock manager cannot be initialized or if the
/// required feature is not enabled.
#[cfg(feature = "s3")]
pub async fn create_lock_manager(
config: &kvbm_config::ObjectConfig,
instance_id: String,
) -> Result<Arc<dyn ObjectLockManager>> {
use kvbm_config::ObjectClientConfig;
use s3::{S3Config, S3LockManager, S3ObjectBlockClient};
match &config.client {
ObjectClientConfig::S3(s3_config) => {
let config = S3Config::from_object_config(s3_config);
// Lock manager uses default key formatter (no rank prefix for lock/meta files)
let client = Arc::new(S3ObjectBlockClient::new(config).await?);
let manager = S3LockManager::new(client, instance_id);
Ok(Arc::new(manager))
}
ObjectClientConfig::Nixl(_nixl_config) => {
anyhow::bail!("Nixl object storage backend not yet implemented")
}
}
}
/// Fallback when S3 feature is disabled.
#[cfg(not(feature = "s3"))]
pub async fn create_lock_manager(
_config: &kvbm_config::ObjectConfig,
_instance_id: String,
) -> Result<Arc<dyn ObjectLockManager>> {
anyhow::bail!("Object storage requires the 's3' feature to be enabled")
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! S3-compatible object storage client for block management.
//!
//! This module provides an implementation of [`ObjectBlockOps`] using the AWS S3 SDK.
//! It supports S3-compatible storage services including MinIO.
use anyhow::{Result, anyhow};
use aws_sdk_s3::Client;
use aws_sdk_s3::error::ProvideErrorMetadata;
use aws_sdk_s3::primitives::ByteStream;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::stream::StreamExt;
use crate::object::{DefaultKeyFormatter, KeyFormatter, LayoutConfigExt, ObjectBlockOps};
use crate::{BlockId, SequenceHash};
use kvbm_common::LogicalLayoutHandle;
use kvbm_physical::transfer::PhysicalLayout;
use std::sync::Arc;
/// Configuration for S3 object storage client.
#[derive(Debug, Clone)]
pub struct S3Config {
/// Custom endpoint URL for S3-compatible services (e.g., MinIO).
/// If None, uses the default AWS S3 endpoint.
pub endpoint_url: Option<String>,
/// S3 bucket name for storing blocks.
pub bucket: String,
/// AWS region.
pub region: String,
/// Use path-style URLs instead of virtual-hosted-style.
/// Required for MinIO and some S3-compatible services.
pub force_path_style: bool,
/// Maximum number of concurrent S3 requests.
pub max_concurrent_requests: usize,
}
impl Default for S3Config {
/// Returns a default configuration suitable for local MinIO testing.
fn default() -> Self {
Self {
endpoint_url: Some("http://localhost:9000".into()),
bucket: "kvbm-blocks".into(),
region: "us-east-1".into(),
force_path_style: true,
max_concurrent_requests: 16,
}
}
}
impl S3Config {
/// Create a new S3Config for AWS S3 (not MinIO).
pub fn aws(bucket: String, region: String) -> Self {
Self {
endpoint_url: None,
bucket,
region,
force_path_style: false,
max_concurrent_requests: 16,
}
}
/// Create a new S3Config for MinIO.
pub fn minio(endpoint_url: String, bucket: String) -> Self {
Self {
endpoint_url: Some(endpoint_url),
bucket,
region: "us-east-1".into(),
force_path_style: true,
max_concurrent_requests: 16,
}
}
/// Create from kvbm-config's S3ObjectConfig.
pub fn from_object_config(config: &kvbm_config::S3ObjectConfig) -> Self {
Self {
endpoint_url: config.endpoint_url.clone(),
bucket: config.bucket.clone(),
region: config.region.clone(),
force_path_style: config.force_path_style,
max_concurrent_requests: config.max_concurrent_requests,
}
}
/// Set the maximum number of concurrent requests.
pub fn with_max_concurrent_requests(mut self, max: usize) -> Self {
self.max_concurrent_requests = max;
self
}
}
/// S3-compatible object storage client for block operations.
///
/// This client implements [`ObjectBlockOps`] using the AWS S3 SDK.
/// It supports parallel block operations and uses rayon for CPU-bound memory copies.
///
/// # Key Formatting
///
/// Uses a [`KeyFormatter`] to convert `SequenceHash` to object keys. The formatter
/// can embed rank, namespace, or other prefixes for key uniqueness across workers.
pub struct S3ObjectBlockClient {
/// AWS S3 client
client: Client,
/// S3 configuration
config: S3Config,
/// Key formatter for converting SequenceHash to object keys.
key_formatter: Arc<dyn KeyFormatter>,
}
impl S3ObjectBlockClient {
/// Create a new S3ObjectBlockClient with default key formatting.
///
/// # Arguments
/// * `config` - S3 configuration
///
/// # Errors
/// Returns an error if the S3 client cannot be initialized.
pub async fn new(config: S3Config) -> Result<Self> {
let client = build_s3_client(&config).await?;
Ok(Self {
client,
config,
key_formatter: Arc::new(DefaultKeyFormatter),
})
}
/// Create a new S3ObjectBlockClient with a custom key formatter.
///
/// # Arguments
/// * `config` - S3 configuration
/// * `key_formatter` - Custom key formatter for SequenceHash → String conversion
///
/// # Errors
/// Returns an error if the S3 client cannot be initialized.
pub async fn with_key_formatter(
config: S3Config,
key_formatter: Arc<dyn KeyFormatter>,
) -> Result<Self> {
let client = build_s3_client(&config).await?;
Ok(Self {
client,
config,
key_formatter,
})
}
/// Create from an existing AWS S3 client with default key formatting.
pub fn from_client(client: Client, config: S3Config) -> Self {
Self {
client,
config,
key_formatter: Arc::new(DefaultKeyFormatter),
}
}
/// Create from an existing AWS S3 client with a custom key formatter.
pub fn from_client_with_formatter(
client: Client,
config: S3Config,
key_formatter: Arc<dyn KeyFormatter>,
) -> Self {
Self {
client,
config,
key_formatter,
}
}
/// Get a reference to the S3 client.
pub fn client(&self) -> &Client {
&self.client
}
/// Get a reference to the configuration.
pub fn config(&self) -> &S3Config {
&self.config
}
/// Get a reference to the key formatter.
pub fn key_formatter(&self) -> &Arc<dyn KeyFormatter> {
&self.key_formatter
}
/// Get a reference to the bucket name.
pub fn bucket(&self) -> &str {
&self.config.bucket
}
/// Ensure the bucket exists, creating it if necessary.
pub async fn ensure_bucket_exists(&self) -> Result<()> {
match self
.client
.head_bucket()
.bucket(&self.config.bucket)
.send()
.await
{
Ok(_) => Ok(()),
Err(_) => {
// Bucket doesn't exist, create it
self.client
.create_bucket()
.bucket(&self.config.bucket)
.send()
.await
.map_err(|e| {
anyhow!("failed to create bucket '{}': {}", self.config.bucket, e)
})?;
Ok(())
}
}
}
/// Put an object with a conditional check (If-None-Match: *).
///
/// This performs an atomic write that only succeeds if the object does not
/// already exist. Returns:
/// - `Ok(true)` if the object was created successfully
/// - `Ok(false)` if the object already exists (PreconditionFailed)
/// - `Err(...)` for other errors
///
/// # Arguments
/// * `key` - Object key
/// * `data` - Object data to write
pub async fn put_if_not_exists(&self, key: &str, data: Bytes) -> Result<bool> {
match self
.client
.put_object()
.bucket(&self.config.bucket)
.key(key)
.if_none_match("*")
.body(ByteStream::from(data))
.send()
.await
{
Ok(_) => Ok(true),
Err(e) => {
// Check if this is a precondition failed error (HTTP 412)
let service_error = e.into_service_error();
if service_error.code() == Some("PreconditionFailed") {
Ok(false)
} else {
Err(anyhow!(
"S3 conditional put failed for key '{}': {}",
key,
service_error
))
}
}
}
}
/// Get an object's raw bytes.
///
/// # Arguments
/// * `key` - Object key
///
/// # Returns
/// - `Ok(Some(bytes))` if the object exists
/// - `Ok(None)` if the object does not exist
/// - `Err(...)` for other errors
pub async fn get_object(&self, key: &str) -> Result<Option<Bytes>> {
match self
.client
.get_object()
.bucket(&self.config.bucket)
.key(key)
.send()
.await
{
Ok(resp) => {
let data = resp
.body
.collect()
.await
.map_err(|e| anyhow!("failed to collect S3 response body: {}", e))?
.into_bytes();
Ok(Some(data))
}
Err(e) => {
let service_error = e.into_service_error();
if service_error.code() == Some("NoSuchKey") {
Ok(None)
} else {
Err(anyhow!(
"S3 get_object failed for key '{}': {}",
key,
service_error
))
}
}
}
}
/// Delete an object.
///
/// # Arguments
/// * `key` - Object key
///
/// # Returns
/// - `Ok(true)` if the object was deleted
/// - `Ok(false)` if the object did not exist
/// - `Err(...)` for other errors
pub async fn delete_object(&self, key: &str) -> Result<bool> {
match self
.client
.delete_object()
.bucket(&self.config.bucket)
.key(key)
.send()
.await
{
Ok(_) => Ok(true),
Err(e) => {
let service_error = e.into_service_error();
if service_error.code() == Some("NoSuchKey") {
Ok(false)
} else {
Err(anyhow!(
"S3 delete_object failed for key '{}': {}",
key,
service_error
))
}
}
}
}
/// Check if an object exists (HEAD request).
///
/// # Arguments
/// * `key` - Object key
///
/// # Returns
/// - `Ok(true)` if the object exists
/// - `Ok(false)` if the object does not exist
/// - `Err(...)` for other errors
pub async fn has_object(&self, key: &str) -> Result<bool> {
match self
.client
.head_object()
.bucket(&self.config.bucket)
.key(key)
.send()
.await
{
Ok(_) => Ok(true),
Err(e) => {
let service_error = e.into_service_error();
// HeadObject returns "NotFound" when object doesn't exist
if service_error.code() == Some("NotFound") {
Ok(false)
} else {
Err(anyhow!(
"S3 head_object failed for key '{}': {}",
key,
service_error
))
}
}
}
}
/// Put an object unconditionally (overwrite if exists).
///
/// # Arguments
/// * `key` - Object key
/// * `data` - Object data to write
pub async fn put_object(&self, key: &str, data: Bytes) -> Result<()> {
self.client
.put_object()
.bucket(&self.config.bucket)
.key(key)
.body(ByteStream::from(data))
.send()
.await
.map_err(|e| anyhow!("S3 put_object failed for key '{}': {}", key, e))?;
Ok(())
}
/// Put blocks to object storage using a physical layout.
///
/// This is the internal implementation that workers call after resolving
/// the logical layout handle to a physical layout.
///
/// # Arguments
/// * `keys` - Sequence hashes identifying each block
/// * `layout` - Physical layout containing the block data
/// * `block_ids` - Block IDs within the layout to upload
///
/// Returns a vector of results for each block:
/// - Ok(hash) indicates the block was successfully stored
/// - Err(hash) indicates the block failed to store
pub fn put_blocks_with_layout(
&self,
keys: Vec<SequenceHash>,
layout: PhysicalLayout,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
let config = layout.layout().config();
let block_size = config.block_size_bytes();
let region_size = config.region_size();
let is_contiguous = layout.layout().is_fully_contiguous();
let max_concurrent = self.config.max_concurrent_requests;
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let formatter = self.key_formatter.clone();
Box::pin(async move {
let work_items: Vec<_> = keys.into_iter().zip(block_ids).collect();
let tasks = work_items.into_iter().map(|(key, block_id)| {
let client = client.clone();
let bucket = bucket.clone();
let key_str = formatter.format_key(&key);
let layout = layout.clone();
async move {
let result: Result<(), anyhow::Error> = async {
// Copy block data to bytes on rayon thread pool
let data = tokio_rayon::spawn(move || {
copy_block_to_bytes(
&layout,
block_id,
block_size,
region_size,
is_contiguous,
)
})
.await?;
// Upload to S3
client
.put_object()
.bucket(&bucket)
.key(&key_str)
.body(ByteStream::from(data))
.send()
.await
.map_err(|e| anyhow!("S3 put_object failed: {}", e))?;
Ok(())
}
.await;
match result {
Ok(()) => Ok(key),
Err(e) => {
tracing::warn!(key = %key, error = %e, "put block transfer failed");
Err(key)
}
}
}
});
futures::stream::iter(tasks)
.buffer_unordered(max_concurrent)
.collect()
.await
})
}
/// Get blocks from object storage into a physical layout.
///
/// This is the internal implementation that workers call after resolving
/// the logical layout handle to a physical layout.
///
/// # Arguments
/// * `keys` - Sequence hashes identifying each block
/// * `layout` - Physical layout to write the block data into
/// * `block_ids` - Block IDs within the layout to download into
///
/// Returns a vector of results for each block:
/// - Ok(hash) indicates the block was successfully retrieved
/// - Err(hash) indicates the block failed to retrieve
pub fn get_blocks_with_layout(
&self,
keys: Vec<SequenceHash>,
layout: PhysicalLayout,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
let config = layout.layout().config();
let block_size = config.block_size_bytes();
let region_size = config.region_size();
let is_contiguous = layout.layout().is_fully_contiguous();
let max_concurrent = self.config.max_concurrent_requests;
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let formatter = self.key_formatter.clone();
Box::pin(async move {
let work_items: Vec<_> = keys.into_iter().zip(block_ids).collect();
let tasks = work_items.into_iter().map(|(key, block_id)| {
let client = client.clone();
let bucket = bucket.clone();
let key_str = formatter.format_key(&key);
let layout = layout.clone();
async move {
let result: Result<(), anyhow::Error> = async {
// Download from S3
let resp = client
.get_object()
.bucket(&bucket)
.key(&key_str)
.send()
.await
.map_err(|e| anyhow!("S3 get_object failed: {}", e))?;
let data = resp
.body
.collect()
.await
.map_err(|e| anyhow!("failed to collect S3 response body: {}", e))?
.into_bytes();
// Copy bytes to block on rayon thread pool
tokio_rayon::spawn(move || {
copy_bytes_to_block(
&data,
&layout,
block_id,
block_size,
region_size,
is_contiguous,
)
})
.await?;
Ok(())
}
.await;
match result {
Ok(()) => Ok(key),
Err(e) => {
tracing::warn!(key = %key, error = %e, "get block transfer failed");
Err(key)
}
}
}
});
futures::stream::iter(tasks)
.buffer_unordered(max_concurrent)
.collect()
.await
})
}
/// Get an object's raw bytes along with its ETag.
///
/// Used for conditional updates (CAS-style operations) where the caller
/// needs the current ETag to perform a conditional PUT.
///
/// # Returns
/// - `Ok(Some((bytes, etag)))` if the object exists
/// - `Ok(None)` if the object does not exist
/// - `Err(...)` for other errors
pub async fn get_object_with_etag(&self, key: &str) -> Result<Option<(Bytes, Option<String>)>> {
match self
.client
.get_object()
.bucket(&self.config.bucket)
.key(key)
.send()
.await
{
Ok(resp) => {
let etag = resp.e_tag().map(|s| s.to_string());
let data = resp
.body
.collect()
.await
.map_err(|e| anyhow!("failed to collect S3 response body: {}", e))?
.into_bytes();
Ok(Some((data, etag)))
}
Err(e) => {
let service_error = e.into_service_error();
if service_error.code() == Some("NoSuchKey") {
Ok(None)
} else {
Err(anyhow!(
"S3 get_object failed for key '{}': {}",
key,
service_error
))
}
}
}
}
/// Put an object with an ETag precondition (If-Match).
///
/// This performs a conditional write that only succeeds if the object's current
/// ETag matches the provided value. Used for CAS-style atomic updates.
///
/// # Returns
/// - `Ok(true)` if the write succeeded (ETag matched)
/// - `Ok(false)` if the ETag did not match (412 PreconditionFailed — lost the race)
/// - `Err(...)` for other errors
pub async fn put_object_if_match(&self, key: &str, data: Bytes, etag: &str) -> Result<bool> {
match self
.client
.put_object()
.bucket(&self.config.bucket)
.key(key)
.if_match(etag)
.body(ByteStream::from(data))
.send()
.await
{
Ok(_) => Ok(true),
Err(e) => {
let service_error = e.into_service_error();
if service_error.code() == Some("PreconditionFailed") {
Ok(false)
} else {
Err(anyhow!(
"S3 conditional put (if-match) failed for key '{}': {}",
key,
service_error
))
}
}
}
}
}
/// Build an S3 client from configuration.
async fn build_s3_client(config: &S3Config) -> Result<Client> {
let sdk_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
.region(aws_sdk_s3::config::Region::new(config.region.clone()))
.load()
.await;
let mut s3_config_builder = aws_sdk_s3::config::Builder::from(&sdk_config);
if let Some(endpoint) = &config.endpoint_url {
s3_config_builder = s3_config_builder.endpoint_url(endpoint);
}
if config.force_path_style {
s3_config_builder = s3_config_builder.force_path_style(true);
}
let s3_config = s3_config_builder.build();
Ok(Client::from_conf(s3_config))
}
/// Copy block data from a layout to a Bytes buffer.
///
/// For fully contiguous layouts, this is a single memcpy.
/// For layer-separate layouts, this iterates over all regions.
fn copy_block_to_bytes(
layout: &PhysicalLayout,
block_id: BlockId,
block_size: usize,
region_size: usize,
is_contiguous: bool,
) -> Result<Bytes> {
if is_contiguous {
// Fast path: single contiguous region — the layout guarantees that
// block_size bytes are contiguous from region(block_id, 0, 0).addr().
let region = layout.memory_region(block_id, 0, 0)?;
let slice = unsafe { std::slice::from_raw_parts(region.addr() as *const u8, block_size) };
Ok(Bytes::copy_from_slice(slice))
} else {
// Slow path: iterate over all regions
let mut buf = Vec::with_capacity(block_size);
let inner_layout = layout.layout();
for layer_id in 0..inner_layout.num_layers() {
for outer_id in 0..inner_layout.outer_dim() {
let region = layout.memory_region(block_id, layer_id, outer_id)?;
if region.size() < region_size {
return Err(anyhow!(
"memory region too small: got {} bytes, need {}",
region.size(),
region_size
));
}
let slice =
unsafe { std::slice::from_raw_parts(region.addr() as *const u8, region_size) };
buf.extend_from_slice(slice);
}
}
Ok(Bytes::from(buf))
}
}
/// Copy data from a Bytes buffer to a layout.
///
/// For fully contiguous layouts, this is a single memcpy.
/// For layer-separate layouts, this iterates over all regions.
fn copy_bytes_to_block(
data: &[u8],
layout: &PhysicalLayout,
block_id: BlockId,
block_size: usize,
region_size: usize,
is_contiguous: bool,
) -> Result<()> {
if is_contiguous {
// Fast path: single contiguous region — the layout guarantees that
// block_size bytes are contiguous from region(block_id, 0, 0).addr().
if data.len() < block_size {
return Err(anyhow!(
"S3 data too short: got {} bytes, expected {}",
data.len(),
block_size
));
}
let region = layout.memory_region(block_id, 0, 0)?;
unsafe {
std::ptr::copy_nonoverlapping(data.as_ptr(), region.addr() as *mut u8, block_size);
}
} else {
// Slow path: iterate over all regions
let mut offset = 0;
let inner_layout = layout.layout();
for layer_id in 0..inner_layout.num_layers() {
for outer_id in 0..inner_layout.outer_dim() {
if offset + region_size > data.len() {
return Err(anyhow!(
"S3 data too short at offset {}: need {} more bytes, only {} remain",
offset,
region_size,
data.len().saturating_sub(offset)
));
}
let region = layout.memory_region(block_id, layer_id, outer_id)?;
if region.size() < region_size {
return Err(anyhow!(
"memory region too small: got {} bytes, need {}",
region.size(),
region_size
));
}
unsafe {
std::ptr::copy_nonoverlapping(
data[offset..].as_ptr(),
region.addr() as *mut u8,
region_size,
);
}
offset += region_size;
}
}
}
Ok(())
}
impl ObjectBlockOps for S3ObjectBlockClient {
fn has_blocks(
&self,
keys: Vec<SequenceHash>,
) -> BoxFuture<'static, Vec<(SequenceHash, Option<usize>)>> {
let max_concurrent = self.config.max_concurrent_requests;
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let formatter = self.key_formatter.clone();
Box::pin(async move {
let tasks = keys.into_iter().map(|key| {
let client = client.clone();
let bucket = bucket.clone();
let key_str = formatter.format_key(&key);
async move {
match client
.head_object()
.bucket(&bucket)
.key(&key_str)
.send()
.await
{
Ok(resp) => (key, resp.content_length().map(|l| l as usize)),
Err(e) => {
tracing::warn!(key = %key, error = %e, "head_object failed, treating as missing");
(key, None)
}
}
}
});
futures::stream::iter(tasks)
.buffer_unordered(max_concurrent)
.collect()
.await
})
}
fn put_blocks(
&self,
keys: Vec<SequenceHash>,
_src_layout: LogicalLayoutHandle,
_block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// S3ObjectBlockClient cannot resolve LogicalLayoutHandle to PhysicalLayout.
// Workers should use put_blocks_with_layout() instead after resolving the handle.
tracing::error!(
"S3ObjectBlockClient::put_blocks called with LogicalLayoutHandle - \
use put_blocks_with_layout() via DirectWorker instead"
);
Box::pin(async move { keys.into_iter().map(Err).collect() })
}
fn get_blocks(
&self,
keys: Vec<SequenceHash>,
_dst_layout: LogicalLayoutHandle,
_block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// S3ObjectBlockClient cannot resolve LogicalLayoutHandle to PhysicalLayout.
// Workers should use get_blocks_with_layout() instead after resolving the handle.
tracing::error!(
"S3ObjectBlockClient::get_blocks called with LogicalLayoutHandle - \
use get_blocks_with_layout() via DirectWorker instead"
);
Box::pin(async move { keys.into_iter().map(Err).collect() })
}
fn put_blocks_with_layout(
&self,
keys: Vec<SequenceHash>,
layout: PhysicalLayout,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// Delegate to the inherent method
S3ObjectBlockClient::put_blocks_with_layout(self, keys, layout, block_ids)
}
fn get_blocks_with_layout(
&self,
keys: Vec<SequenceHash>,
layout: PhysicalLayout,
block_ids: Vec<BlockId>,
) -> BoxFuture<'static, Vec<Result<SequenceHash, SequenceHash>>> {
// Delegate to the inherent method
S3ObjectBlockClient::get_blocks_with_layout(self, keys, layout, block_ids)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_s3_config_default() {
let config = S3Config::default();
assert_eq!(config.endpoint_url, Some("http://localhost:9000".into()));
assert_eq!(config.bucket, "kvbm-blocks");
assert_eq!(config.region, "us-east-1");
assert!(config.force_path_style);
assert_eq!(config.max_concurrent_requests, 16);
}
#[test]
fn test_s3_config_aws() {
let config = S3Config::aws("my-bucket".into(), "us-west-2".into());
assert_eq!(config.endpoint_url, None);
assert_eq!(config.bucket, "my-bucket");
assert_eq!(config.region, "us-west-2");
assert!(!config.force_path_style);
}
#[test]
fn test_s3_config_minio() {
let config = S3Config::minio("http://minio:9000".into(), "test-bucket".into());
assert_eq!(config.endpoint_url, Some("http://minio:9000".into()));
assert_eq!(config.bucket, "test-bucket");
assert!(config.force_path_style);
}
}
#[cfg(all(test, feature = "testing"))]
mod bounds_check_tests {
use super::*;
use crate::object::LayoutConfigExt;
use kvbm_physical::testing::{create_fc_layout, create_lw_layout, create_test_agent};
use kvbm_physical::transfer::StorageKind;
#[test]
fn test_copy_bytes_to_block_rejects_short_data_contiguous() {
let agent = create_test_agent("test_short_data_fc");
let layout = create_fc_layout(agent, StorageKind::System, 2);
let config = layout.layout().config();
let block_size = config.block_size_bytes();
let region_size = config.region_size();
// Data is one byte short
let short_data = vec![0u8; block_size - 1];
let err = copy_bytes_to_block(&short_data, &layout, 0, block_size, region_size, true)
.expect_err("should reject short data");
assert!(
err.to_string().contains("S3 data too short"),
"unexpected error: {}",
err
);
}
#[test]
fn test_copy_bytes_to_block_rejects_short_data_non_contiguous() {
let agent = create_test_agent("test_short_data_lw");
let layout = create_lw_layout(agent, StorageKind::System, 2);
let config = layout.layout().config();
let block_size = config.block_size_bytes();
let region_size = config.region_size();
// Data is one region short
let short_data = vec![0u8; block_size - region_size];
let err = copy_bytes_to_block(&short_data, &layout, 0, block_size, region_size, false)
.expect_err("should reject short data in non-contiguous path");
assert!(
err.to_string().contains("S3 data too short"),
"unexpected error: {}",
err
);
}
#[test]
fn test_copy_bytes_to_block_accepts_exact_size() {
let agent = create_test_agent("test_exact_fc");
let layout = create_fc_layout(agent, StorageKind::System, 2);
let config = layout.layout().config();
let block_size = config.block_size_bytes();
let region_size = config.region_size();
let data = vec![42u8; block_size];
copy_bytes_to_block(&data, &layout, 0, block_size, region_size, true)
.expect("exact-size data should succeed");
}
#[test]
fn test_copy_block_to_bytes_roundtrip_contiguous() {
let agent = create_test_agent("test_roundtrip_fc");
let layout = create_fc_layout(agent, StorageKind::System, 2);
let config = layout.layout().config();
let block_size = config.block_size_bytes();
let region_size = config.region_size();
// Write known data
let data = vec![0xAB_u8; block_size];
copy_bytes_to_block(&data, &layout, 0, block_size, region_size, true).unwrap();
// Read it back
let out = copy_block_to_bytes(&layout, 0, block_size, region_size, true).unwrap();
assert_eq!(out.as_ref(), &data[..]);
}
#[test]
fn test_copy_block_to_bytes_roundtrip_non_contiguous() {
let agent = create_test_agent("test_roundtrip_lw");
let layout = create_lw_layout(agent, StorageKind::System, 2);
let config = layout.layout().config();
let block_size = config.block_size_bytes();
let region_size = config.region_size();
let data = vec![0xCD_u8; block_size];
copy_bytes_to_block(&data, &layout, 0, block_size, region_size, false).unwrap();
let out = copy_block_to_bytes(&layout, 0, block_size, region_size, false).unwrap();
assert_eq!(out.as_ref(), &data[..]);
}
}
#[cfg(all(test, feature = "testing-s3"))]
pub mod s3_integration {
use super::*;
/// Create an S3ObjectBlockClient connected to the test MinIO instance.
///
/// Reads `S3_TEST_ENDPOINT` from the environment (set by `test-s3.sh`).
/// Falls back to `http://localhost:9876`.
pub async fn create_test_client(bucket: &str) -> S3ObjectBlockClient {
let endpoint =
std::env::var("S3_TEST_ENDPOINT").unwrap_or_else(|_| "http://localhost:9876".into());
let config = S3Config::minio(endpoint, bucket.to_string());
let client = S3ObjectBlockClient::new(config).await.unwrap();
client.ensure_bucket_exists().await.unwrap();
client
}
#[tokio::test]
async fn test_put_get_roundtrip() {
let client = create_test_client("test-roundtrip").await;
let key = format!("roundtrip-{}", uuid::Uuid::new_v4());
let payload = Bytes::from("hello world");
client.put_object(&key, payload.clone()).await.unwrap();
let result = client.get_object(&key).await.unwrap();
assert_eq!(result, Some(payload));
// Cleanup
client.delete_object(&key).await.unwrap();
}
#[tokio::test]
async fn test_put_object_if_match_rejects_stale_etag() {
let client = create_test_client("test-if-match").await;
let key = format!("if-match-{}", uuid::Uuid::new_v4());
// Write initial object
client
.put_object(&key, Bytes::from("version1"))
.await
.unwrap();
// Get with ETag
let (_, etag) = client
.get_object_with_etag(&key)
.await
.unwrap()
.expect("object should exist");
let etag = etag.expect("should have etag");
// Overwrite the object to change its ETag
client
.put_object(&key, Bytes::from("version2"))
.await
.unwrap();
// Conditional put with stale ETag should fail
let won = client
.put_object_if_match(&key, Bytes::from("version3"), &etag)
.await
.unwrap();
assert!(!won, "conditional put with stale ETag should return false");
// Verify the object still has version2
let data = client.get_object(&key).await.unwrap().unwrap();
assert_eq!(data, Bytes::from("version2"));
// Cleanup
client.delete_object(&key).await.unwrap();
}
#[tokio::test]
async fn test_put_object_if_match_accepts_current_etag() {
let client = create_test_client("test-if-match-ok").await;
let key = format!("if-match-ok-{}", uuid::Uuid::new_v4());
client
.put_object(&key, Bytes::from("version1"))
.await
.unwrap();
let (_, etag) = client
.get_object_with_etag(&key)
.await
.unwrap()
.expect("object should exist");
let etag = etag.expect("should have etag");
// Conditional put with current ETag should succeed
let won = client
.put_object_if_match(&key, Bytes::from("version2"), &etag)
.await
.unwrap();
assert!(won, "conditional put with current ETag should succeed");
let data = client.get_object(&key).await.unwrap().unwrap();
assert_eq!(data, Bytes::from("version2"));
// Cleanup
client.delete_object(&key).await.unwrap();
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment