Commit 3998fdcb (unverified), authored by jthomson04, committed by GitHub
Browse files

feat: KVBM V2 Initial Migration (#3861)


Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
parent e64d2f09
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer context.
use std::sync::Arc;
use crate::block_manager::v2::kernels::OperationalCopyBackend;
use anyhow::Result;
use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
use derive_builder::Builder;
use nixl_sys::XferRequest;
use tokio::sync::{mpsc, oneshot};
use uuid::Uuid;
use super::nixl_agent::{NixlAgent, NixlBackendConfig};
use crate::block_manager::v2::physical::manager::TransportManager;
// Notifications module is declared in ../mod.rs
// Re-export for convenience
use super::TransferCapabilities;
pub use super::notifications;
pub use super::notifications::TransferCompleteNotification;
/// Settings collected by `TransferConfigBuilder` before a `TransferContext` is created.
///
/// `derive_builder` generates `TransferConfigBuilder` from this struct using the owned
/// pattern. The generated build function is private and named `build_internal`; the
/// hand-written `build()` methods in this file call it and then construct the context.
#[derive(Debug, Clone, Builder)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"), public)]
#[allow(dead_code)] // Fields are used in build() but derive macros confuse dead code analysis
pub(crate) struct TransferConfig {
/// Worker identifier. Forwarded to `TransferContext::new` and used to derive the
/// default NIXL agent name. This field has no builder default.
worker_id: u64,
/// Optional custom name for the NIXL agent. If not provided, defaults to "worker-{worker_id}"
#[builder(default = "None", setter(strip_option))]
nixl_agent_name: Option<String>,
/// Backend configuration for NIXL backends to enable
#[builder(default = "NixlBackendConfig::new()")]
nixl_backend_config: NixlBackendConfig,
/// Device ordinal handed to `CudaContext::new`; defaults to 0.
#[builder(default = "0")]
cuda_device_id: usize,
/// Tokio runtime used for the background notification tasks; defaults to
/// `get_tokio_runtime()`.
#[builder(default = "get_tokio_runtime()")]
tokio_runtime: TokioRuntime,
/// Transfer capabilities stored on the context; defaults to
/// `TransferCapabilities::default()`.
#[builder(default = "TransferCapabilities::default()")]
capabilities: TransferCapabilities,
/// Copy backend stored on the context; defaults to `OperationalCopyBackend::Auto`.
#[builder(default = "OperationalCopyBackend::Auto")]
operational_backend: OperationalCopyBackend,
}
impl TransferConfigBuilder {
/// Directly provide a pre-configured wrapped NIXL agent (mainly for testing).
///
/// This bypasses the agent creation and backend initialization logic,
/// using the provided agent directly. Useful for tests that need full
/// control over agent configuration.
pub fn nixl_agent(self, agent: NixlAgent) -> TransferConfigBuilderWithAgent {
TransferConfigBuilderWithAgent {
builder: self,
agent,
}
}
/// Add a NIXL backend to enable (uses default plugin parameters).
pub fn nixl_backend(mut self, backend: impl Into<String>) -> Self {
let config = self
.nixl_backend_config
.get_or_insert_with(NixlBackendConfig::new);
*config = config.clone().with_backend(backend);
self
}
/// Load NIXL backend configuration from environment variables.
///
/// This merges environment-based configuration with any backends already
/// configured via the builder.
pub fn with_env_backends(mut self) -> Result<Self> {
let env_config = NixlBackendConfig::from_env()?;
let config = self
.nixl_backend_config
.get_or_insert_with(NixlBackendConfig::new);
*config = config.clone().merge(env_config);
Ok(self)
}
pub fn build(self) -> Result<TransportManager> {
let mut config = self.build_internal()?;
// Merge environment backends if not explicitly configured
if config.nixl_backend_config.backends().is_empty() {
config.nixl_backend_config = NixlBackendConfig::from_env()?;
}
// Derive agent name from worker_id if not provided
let agent_name = config
.nixl_agent_name
.unwrap_or_else(|| format!("worker-{}", config.worker_id));
// Create wrapped NIXL agent with configured backends
let backend_names: Vec<&str> = config
.nixl_backend_config
.backends()
.iter()
.map(|s| s.as_str())
.collect();
let nixl_agent = if backend_names.is_empty() {
// No backends configured - create agent without backends
NixlAgent::new_with_backends(&agent_name, &[])?
} else {
// Create agent with requested backends
NixlAgent::new_with_backends(&agent_name, &backend_names)?
};
let cuda_context = CudaContext::new(config.cuda_device_id)?;
let context = TransferContext::new(
config.worker_id,
nixl_agent,
cuda_context,
config.tokio_runtime,
config.capabilities,
config.operational_backend,
)?;
Ok(TransportManager::from_context(context))
}
}
/// Builder that already has a pre-configured NIXL agent.
///
/// This is generally used for testing when you want to pass in an agent directly
/// rather than having it created by the builder.
///
/// Produced by `TransferConfigBuilder::nixl_agent`.
pub struct TransferConfigBuilderWithAgent {
// Underlying config builder; the proxy setters on this type forward to it.
builder: TransferConfigBuilder,
// Agent handed to `TransferContext::new` unchanged by `build()`.
agent: NixlAgent,
}
impl TransferConfigBuilderWithAgent {
    /// Construct the `TransportManager` around the agent supplied up front.
    ///
    /// # Errors
    /// Fails if a required builder field is missing or if CUDA context or
    /// transfer context creation fails.
    pub fn build(self) -> Result<TransportManager> {
        let Self { builder, agent } = self;
        let config = builder.build_internal()?;
        let cuda_context = CudaContext::new(config.cuda_device_id)?;
        TransferContext::new(
            config.worker_id,
            agent,
            cuda_context,
            config.tokio_runtime,
            config.capabilities,
            config.operational_backend,
        )
        .map(TransportManager::from_context)
    }

    // Proxy methods that forward to the wrapped builder.

    /// Set the worker id on the wrapped builder.
    pub fn worker_id(self, worker_id: u64) -> Self {
        Self {
            builder: self.builder.worker_id(worker_id),
            agent: self.agent,
        }
    }

    /// Set the CUDA device id on the wrapped builder.
    pub fn cuda_device_id(self, cuda_device_id: usize) -> Self {
        Self {
            builder: self.builder.cuda_device_id(cuda_device_id),
            agent: self.agent,
        }
    }
}
/// Pick the Tokio runtime used as the builder default.
///
/// Reuses the runtime of the calling thread when there is one; otherwise builds a
/// multi-threaded runtime (2 worker threads, at most 4 blocking threads).
///
/// # Panics
/// Panics if the fallback runtime cannot be built.
fn get_tokio_runtime() -> TokioRuntime {
    if let Ok(current) = tokio::runtime::Handle::try_current() {
        return TokioRuntime::Handle(current);
    }

    let owned = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .max_blocking_threads(4)
        .worker_threads(2)
        .build()
        .expect("failed to build tokio runtime");
    TokioRuntime::Shared(Arc::new(owned))
}
/// Tokio runtime access held by the transfer configuration and context.
///
/// Either variant can produce a `tokio::runtime::Handle` via `TokioRuntime::handle`.
#[derive(Debug, Clone)]
pub(crate) enum TokioRuntime {
/// A handle to an existing runtime (what `get_tokio_runtime` returns when
/// `Handle::try_current()` succeeds).
Handle(tokio::runtime::Handle),
/// A runtime kept alive through a shared `Arc` (what `get_tokio_runtime` builds when
/// no runtime is current).
Shared(Arc<tokio::runtime::Runtime>),
}
impl TokioRuntime {
    /// Borrow a runtime handle, regardless of which variant holds the runtime.
    pub fn handle(&self) -> &tokio::runtime::Handle {
        match self {
            Self::Shared(owned) => owned.handle(),
            Self::Handle(existing) => existing,
        }
    }
}
/// Shared state needed to execute transfers: the NIXL agent, the CUDA context and
/// streams, the Tokio runtime, and the senders that feed the background
/// notification handlers spawned in `TransferContext::new`.
#[derive(Debug, Clone)]
pub struct TransferContext {
// Identifier of the worker this context belongs to.
worker_id: u64,
// NIXL agent; its raw agent is cloned into the NIXL notification handlers.
nixl_agent: NixlAgent,
#[allow(dead_code)]
// CUDA context the two streams below were created from.
cuda_context: Arc<CudaContext>,
// Stream returned by `d2h_stream()`.
d2h_stream: Arc<CudaStream>,
// Stream returned by `h2d_stream()`.
h2d_stream: Arc<CudaStream>,
#[allow(dead_code)]
// Runtime on which the background handlers are spawned.
tokio_runtime: TokioRuntime,
capabilities: TransferCapabilities,
operational_backend: OperationalCopyBackend,
// Channels for background notification handlers
tx_nixl_status:
mpsc::Sender<notifications::RegisterPollingNotification<notifications::NixlStatusChecker>>,
tx_cuda_event:
mpsc::Sender<notifications::RegisterPollingNotification<notifications::CudaEventChecker>>,
#[allow(dead_code)]
tx_nixl_events: mpsc::Sender<notifications::RegisterNixlNotification>,
}
impl TransferContext {
pub fn builder() -> TransferConfigBuilder {
TransferConfigBuilder::default()
}
pub(crate) fn new(
worker_id: u64,
nixl_agent: NixlAgent,
cuda_context: Arc<CudaContext>,
tokio_runtime: TokioRuntime,
capabilities: TransferCapabilities,
operational_backend: OperationalCopyBackend,
) -> Result<Self> {
unsafe { cuda_context.disable_event_tracking() };
// Create channels for background notification handlers
let (tx_nixl_status, rx_nixl_status) = mpsc::channel(64);
let (tx_cuda_event, rx_cuda_event) = mpsc::channel(64);
let (tx_nixl_events, rx_nixl_events) = mpsc::channel(64);
// Spawn background handlers
let handle = tokio_runtime.handle();
// Spawn NIXL status polling handler
handle.spawn(notifications::process_polling_notifications(rx_nixl_status));
// Spawn CUDA event polling handler
handle.spawn(notifications::process_polling_notifications(rx_cuda_event));
// Spawn NIXL notification events handler
handle.spawn(notifications::process_nixl_notification_events(
nixl_agent.raw_agent().clone(),
rx_nixl_events,
));
Ok(Self {
worker_id,
nixl_agent,
cuda_context: cuda_context.clone(),
d2h_stream: cuda_context.new_stream()?,
h2d_stream: cuda_context.new_stream()?,
tokio_runtime,
capabilities,
operational_backend,
tx_nixl_status,
tx_cuda_event,
tx_nixl_events,
})
}
pub(crate) fn nixl_agent(&self) -> &NixlAgent {
&self.nixl_agent
}
#[allow(dead_code)]
pub(crate) fn cuda_context(&self) -> &Arc<CudaContext> {
&self.cuda_context
}
pub(crate) fn d2h_stream(&self) -> &Arc<CudaStream> {
&self.d2h_stream
}
pub(crate) fn h2d_stream(&self) -> &Arc<CudaStream> {
&self.h2d_stream
}
#[allow(dead_code)]
pub(crate) fn tokio(&self) -> &tokio::runtime::Handle {
self.tokio_runtime.handle()
}
pub(crate) fn capabilities(&self) -> &TransferCapabilities {
&self.capabilities
}
pub(crate) fn operational_backend(&self) -> OperationalCopyBackend {
self.operational_backend
}
/// Register a NIXL transfer request for status polling completion.
///
/// This method enqueues the transfer request to be polled for completion
/// using `agent.get_xfer_status()`. Returns a notification object that
/// can be awaited for completion.
pub(crate) fn register_nixl_status(
&self,
xfer_req: XferRequest,
) -> TransferCompleteNotification {
let (done_tx, done_rx) = oneshot::channel();
let notification = notifications::RegisterPollingNotification {
uuid: Uuid::new_v4(),
checker: notifications::NixlStatusChecker::new(
self.nixl_agent.raw_agent().clone(),
xfer_req,
),
done: done_tx,
};
// Send to background handler (ignore error if receiver dropped)
let _ = self.tx_nixl_status.try_send(notification);
TransferCompleteNotification { status: done_rx }
}
/// Register a CUDA event for polling completion.
///
/// This method enqueues the CUDA event to be polled for completion.
/// Returns a notification object that can be awaited for completion.
pub(crate) fn register_cuda_event(&self, event: CudaEvent) -> TransferCompleteNotification {
let (done_tx, done_rx) = oneshot::channel();
let notification = notifications::RegisterPollingNotification {
uuid: Uuid::new_v4(),
checker: notifications::CudaEventChecker::new(event),
done: done_tx,
};
// Send to background handler (ignore error if receiver dropped)
let _ = self.tx_cuda_event.try_send(notification);
TransferCompleteNotification { status: done_rx }
}
/// Register a NIXL transfer request for notification-based completion.
///
/// This method enqueues the transfer request to be completed via NIXL
/// notification events. Returns a notification object that can be awaited
/// for completion.
#[allow(dead_code)]
pub(crate) fn register_nixl_event(
&self,
xfer_req: XferRequest,
) -> TransferCompleteNotification {
let (done_tx, done_rx) = oneshot::channel();
let notification = notifications::RegisterNixlNotification {
uuid: Uuid::new_v4(),
xfer_req,
done: done_tx,
};
// Send to background handler (ignore error if receiver dropped)
let _ = self.tx_nixl_events.try_send(notification);
TransferCompleteNotification { status: done_rx }
}
/// Get the worker ID for this context.
pub(crate) fn worker_id(&self) -> u64 {
self.worker_id
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA executor for GPU memory transfers.
use super::TransferContext;
use super::{PhysicalLayout, TransferStrategy};
use crate::block_manager::v2::kernels::OperationalCopyBackend;
use crate::block_manager::v2::physical::transfer::context::TransferCompleteNotification;
use anyhow::{Result, anyhow};
use cudarc::driver::result as cuda_result;
use std::ops::Range;
// #[cfg(test)]
// mod cuda_kernel_tests;
/// Execute a CUDA transfer between host and device memory.
///
/// This executor handles transfers involving GPU memory using CUDA APIs.
/// Supports async and blocking transfers depending on the strategy.
///
/// # Arguments
/// * `src` - Source physical layout
/// * `dst` - Destination physical layout
/// * `src_block_ids` - Source block IDs to transfer
/// * `dst_block_ids` - Destination block IDs to transfer
/// * `layer_range` - Optional range of layers to transfer (None = all layers)
/// * `strategy` - CUDA transfer strategy (H2D, D2H, D2D, async or blocking)
/// * `ctx` - Transfer context with CUDA stream
pub fn execute_cuda_transfer(
src: &PhysicalLayout,
dst: &PhysicalLayout,
src_block_ids: &[usize],
dst_block_ids: &[usize],
layer_range: Option<Range<usize>>,
strategy: TransferStrategy,
ctx: &TransferContext,
) -> Result<TransferCompleteNotification> {
// Validate layouts
let src_layout = src.layout();
let dst_layout = dst.layout();
if src_layout.num_layers() != dst_layout.num_layers() {
return Err(anyhow!(
"Layouts have incompatible layer counts: src={}, dst={}",
src_layout.num_layers(),
dst_layout.num_layers()
));
}
if src_layout.outer_dim() != dst_layout.outer_dim() {
return Err(anyhow!(
"Layouts have incompatible outer dimensions: src={}, dst={}",
src_layout.outer_dim(),
dst_layout.outer_dim()
));
}
// Determine layer range
let layers = layer_range.unwrap_or(0..src_layout.num_layers());
// Get appropriate CUDA stream based on transfer direction
let stream = match strategy {
TransferStrategy::CudaAsyncD2H | TransferStrategy::CudaBlockingD2H => ctx.d2h_stream(),
_ => ctx.h2d_stream(), // H2D and D2D use h2d_stream
};
// Perform CUDA transfers based on strategy
match strategy {
TransferStrategy::CudaAsyncH2D => {
let backend = ctx.operational_backend();
if let Err(e) = try_execute_operational_kernel(
src,
dst,
src_block_ids,
dst_block_ids,
layers.clone(),
stream.as_ref(),
backend,
) {
// Fallback to memcpy-based path
tracing::debug!("Kernel-based H2D failed ({}), falling back to memcpy", e);
execute_h2d(
src,
dst,
src_block_ids,
dst_block_ids,
layers,
stream.as_ref(),
)?;
}
}
TransferStrategy::CudaAsyncD2H => {
let backend = ctx.operational_backend();
if let Err(e) = try_execute_operational_kernel(
src,
dst,
src_block_ids,
dst_block_ids,
layers.clone(),
stream.as_ref(),
backend,
) {
// Fallback to memcpy-based path
tracing::debug!("Kernel-based D2H failed ({}), falling back to memcpy", e);
execute_d2h(
src,
dst,
src_block_ids,
dst_block_ids,
layers,
stream.as_ref(),
)?;
}
}
TransferStrategy::CudaAsyncD2D => {
// Try kernel-based path first, fall back to memcpy on error
let backend = ctx.operational_backend();
if let Err(e) = try_execute_operational_kernel(
src,
dst,
src_block_ids,
dst_block_ids,
layers.clone(),
stream.as_ref(),
backend,
) {
// Fallback to memcpy-based path
tracing::debug!("Kernel-based D2D failed ({}), falling back to memcpy", e);
execute_d2d(
src,
dst,
src_block_ids,
dst_block_ids,
layers,
stream.as_ref(),
)?;
}
}
TransferStrategy::CudaBlockingH2D => {
execute_h2d(
src,
dst,
src_block_ids,
dst_block_ids,
layers,
stream.as_ref(),
)?;
// Synchronize immediately for blocking transfer
stream.synchronize()?;
}
TransferStrategy::CudaBlockingD2H => {
execute_d2h(
src,
dst,
src_block_ids,
dst_block_ids,
layers,
stream.as_ref(),
)?;
// Synchronize immediately for blocking transfer
stream.synchronize()?;
}
_ => {
return Err(anyhow!("Invalid CUDA transfer strategy: {:?}", strategy));
}
}
// For async transfers, record an event and register it for completion tracking
if matches!(
strategy,
TransferStrategy::CudaAsyncH2D
| TransferStrategy::CudaAsyncD2H
| TransferStrategy::CudaAsyncD2D
) {
let event = stream.record_event(None)?;
Ok(ctx.register_cuda_event(event))
} else {
// Blocking transfers are already synchronized
Ok(TransferCompleteNotification::completed())
}
}
/// Execute host-to-device transfer.
fn execute_h2d(
src: &PhysicalLayout,
dst: &PhysicalLayout,
src_block_ids: &[usize],
dst_block_ids: &[usize],
layers: Range<usize>,
stream: &cudarc::driver::CudaStream,
) -> Result<()> {
for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
for layer_id in layers.clone() {
for outer_id in 0..src.layout().outer_dim() {
let src_region = src.memory_region(src_block_id, layer_id, outer_id)?;
let dst_region = dst.memory_region(dst_block_id, layer_id, outer_id)?;
if src_region.size() != dst_region.size() {
return Err(anyhow!(
"Size mismatch at block=({},{}), layer={}, outer={}: src={}, dst={}",
src_block_id,
dst_block_id,
layer_id,
outer_id,
src_region.size(),
dst_region.size()
));
}
unsafe {
let src_ptr = src_region.addr() as *const u8;
let dst_ptr = dst_region.addr() as u64;
let src_slice = std::slice::from_raw_parts(src_ptr, src_region.size());
cuda_result::memcpy_htod_async(dst_ptr, src_slice, stream.cu_stream())?;
}
}
}
}
Ok(())
}
/// Execute device-to-host transfer.
fn execute_d2h(
src: &PhysicalLayout,
dst: &PhysicalLayout,
src_block_ids: &[usize],
dst_block_ids: &[usize],
layers: Range<usize>,
stream: &cudarc::driver::CudaStream,
) -> Result<()> {
for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
for layer_id in layers.clone() {
for outer_id in 0..src.layout().outer_dim() {
let src_region = src.memory_region(src_block_id, layer_id, outer_id)?;
let dst_region = dst.memory_region(dst_block_id, layer_id, outer_id)?;
if src_region.size() != dst_region.size() {
return Err(anyhow!(
"Size mismatch at block=({},{}), layer={}, outer={}: src={}, dst={}",
src_block_id,
dst_block_id,
layer_id,
outer_id,
src_region.size(),
dst_region.size()
));
}
unsafe {
let src_ptr = src_region.addr() as u64;
let dst_ptr = dst_region.addr() as *mut u8;
let dst_slice = std::slice::from_raw_parts_mut(dst_ptr, dst_region.size());
cuda_result::memcpy_dtoh_async(dst_slice, src_ptr, stream.cu_stream())?;
}
}
}
}
Ok(())
}
/// Execute device-to-device transfer.
fn execute_d2d(
src: &PhysicalLayout,
dst: &PhysicalLayout,
src_block_ids: &[usize],
dst_block_ids: &[usize],
layers: Range<usize>,
stream: &cudarc::driver::CudaStream,
) -> Result<()> {
for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
for layer_id in layers.clone() {
for outer_id in 0..src.layout().outer_dim() {
let src_region = src.memory_region(src_block_id, layer_id, outer_id)?;
let dst_region = dst.memory_region(dst_block_id, layer_id, outer_id)?;
if src_region.size() != dst_region.size() {
return Err(anyhow!(
"Size mismatch at block=({},{}), layer={}, outer={}: src={}, dst={}",
src_block_id,
dst_block_id,
layer_id,
outer_id,
src_region.size(),
dst_region.size()
));
}
unsafe {
let src_ptr = src_region.addr() as u64;
let dst_ptr = dst_region.addr() as u64;
cuda_result::memcpy_dtod_async(
dst_ptr,
src_ptr,
src_region.size(),
stream.cu_stream(),
)?;
}
}
}
}
Ok(())
}
/// TODO: For now, we've stubbed this out just so we can merge.
/// For now, we'll always just fall back to memcpy.
#[cfg_attr(test, allow(dead_code))]
pub(crate) fn try_execute_operational_kernel(
_src: &PhysicalLayout,
_dst: &PhysicalLayout,
_src_block_ids: &[usize],
_dst_block_ids: &[usize],
_layers: Range<usize>,
_stream: &cudarc::driver::CudaStream,
_backend: OperationalCopyBackend,
) -> Result<()> {
anyhow::bail!("Not implemented.");
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Memcpy executor for host-to-host transfers.
use crate::block_manager::v2::physical::transfer::PhysicalLayout;
use crate::block_manager::v2::physical::transfer::context::TransferCompleteNotification;
use anyhow::Result;
use std::ops::Range;
/// Execute a memcpy transfer between host memory locations.
///
/// This executor handles transfers between System and Pinned memory using
/// standard CPU memcpy operations. The transfer is synchronous and blocking.
///
/// # Arguments
/// * `src` - Source physical layout
/// * `dst` - Destination physical layout
/// * `src_block_ids` - Source block IDs to transfer
/// * `dst_block_ids` - Destination block IDs, paired by position with `src_block_ids`
/// * `layer_range` - Optional range of layers to transfer (None = all layers)
///
/// # Errors
/// Returns an error if the layouts disagree on layer count or outer dimension,
/// if the two block ID lists differ in length, if a memory region cannot be
/// resolved, or if a source/destination region pair differs in size.
pub fn execute_memcpy_transfer(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[usize],
    dst_block_ids: &[usize],
    layer_range: Option<Range<usize>>,
) -> Result<TransferCompleteNotification> {
    // Validate layouts have compatible structure
    let src_layout = src.layout();
    let dst_layout = dst.layout();
    if src_layout.num_layers() != dst_layout.num_layers() {
        return Err(anyhow::anyhow!(
            "Layouts have incompatible layer counts: src={}, dst={}",
            src_layout.num_layers(),
            dst_layout.num_layers()
        ));
    }
    if src_layout.outer_dim() != dst_layout.outer_dim() {
        return Err(anyhow::anyhow!(
            "Layouts have incompatible outer dimensions: src={}, dst={}",
            src_layout.outer_dim(),
            dst_layout.outer_dim()
        ));
    }
    // Blocks are paired by position; without this check `zip` would silently
    // drop the unmatched tail and still report success.
    if src_block_ids.len() != dst_block_ids.len() {
        return Err(anyhow::anyhow!(
            "Block ID count mismatch: src={}, dst={}",
            src_block_ids.len(),
            dst_block_ids.len()
        ));
    }

    // Determine layer range
    let layers = layer_range.unwrap_or(0..src_layout.num_layers());

    // Perform synchronous copies
    for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
        for layer_id in layers.clone() {
            for outer_id in 0..src_layout.outer_dim() {
                // Get source and destination memory regions
                let src_region = src.memory_region(src_block_id, layer_id, outer_id)?;
                let dst_region = dst.memory_region(dst_block_id, layer_id, outer_id)?;

                // Validate sizes match
                if src_region.size() != dst_region.size() {
                    return Err(anyhow::anyhow!(
                        "Memory region size mismatch at block=({},{}), layer={}, outer={}: src={}, dst={}",
                        src_block_id,
                        dst_block_id,
                        layer_id,
                        outer_id,
                        src_region.size(),
                        dst_region.size()
                    ));
                }

                // SAFETY: NOTE(review) - relies on the layouts reporting valid, live
                // host regions of `size()` bytes. `ptr::copy` (memmove semantics) is
                // used rather than `copy_nonoverlapping` because nothing here rules
                // out the two regions aliasing (e.g. the same layout and block on
                // both sides), which would be undefined behavior otherwise.
                unsafe {
                    let src_ptr = src_region.addr() as *const u8;
                    let dst_ptr = dst_region.addr() as *mut u8;
                    std::ptr::copy(src_ptr, dst_ptr, src_region.size());
                }
            }
        }
    }

    // Memcpy is synchronous, so return already-completed notification
    Ok(TransferCompleteNotification::completed())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer executors for different copy strategies.
pub(super) mod cuda;
mod memcpy;
mod nixl;
use super::strategy::select_strategy;
use super::validation::validate_block_transfer;
use super::{PhysicalLayout, TransferContext, TransferOptions, TransferPlan, TransferStrategy};
use crate::block_manager::v2::physical::transfer::{
StorageKind, context::TransferCompleteNotification,
};
use anyhow::Result;
use std::ops::Range;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
// Re-export the NIXL transfer builder for public use
pub use nixl::NixlTransferBuilder;
/// Execute a transfer between two physical layouts.
///
/// Internal entry point for transfer operations called by TransportManager. The
/// block IDs are validated, a transfer plan is selected for the layout pair, and
/// the plan is dispatched to the direct or the two-hop executor.
///
/// # Arguments
/// * `src` - Source physical layout
/// * `dst` - Destination physical layout
/// * `src_block_ids` - Source block IDs to transfer
/// * `dst_block_ids` - Destination block IDs to transfer
/// * `options` - Transfer options; `layer_range` is used by direct transfers, and
///   two-hop transfers receive the whole value (including the bounce buffer)
/// * `ctx` - Transfer context with CUDA stream and NIXL agent
pub fn execute_transfer(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[usize],
    dst_block_ids: &[usize],
    options: TransferOptions,
    ctx: &TransferContext,
) -> Result<TransferCompleteNotification> {
    // Reject bad block IDs before any work is planned.
    validate_block_transfer(src_block_ids, dst_block_ids, None, src, dst, None)?;

    // Plan selection depends on the layout locations and the context capabilities.
    match select_strategy(src, dst, ctx)? {
        TransferPlan::TwoHop {
            first,
            bounce_location,
            second,
        } => {
            let params = TwoHopTransferParams {
                src,
                dst,
                src_block_ids,
                dst_block_ids,
                first_strategy: first,
                bounce_location,
                second_strategy: second,
                options,
                ctx,
            };
            execute_two_hop_transfer(params)
        }
        TransferPlan::Direct(strategy) => {
            let layer_range = options.layer_range;
            execute_direct_transfer(
                src,
                dst,
                src_block_ids,
                dst_block_ids,
                layer_range,
                strategy,
                ctx,
            )
        }
    }
}
/// Run a single-hop transfer with the executor that matches `strategy`.
///
/// Memcpy and CUDA strategies go to their executors; NIXL strategies are run
/// through `NixlTransferBuilder`; `Invalid` yields an error naming both locations.
fn execute_direct_transfer(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[usize],
    dst_block_ids: &[usize],
    layer_range: Option<Range<usize>>,
    strategy: TransferStrategy,
    ctx: &TransferContext,
) -> Result<TransferCompleteNotification> {
    match strategy {
        TransferStrategy::Invalid => {
            anyhow::bail!(
                "Invalid transfer strategy for src={:?}, dst={:?}",
                src.location(),
                dst.location()
            )
        }
        TransferStrategy::Memcpy => {
            memcpy::execute_memcpy_transfer(src, dst, src_block_ids, dst_block_ids, layer_range)
        }
        TransferStrategy::CudaAsyncH2D
        | TransferStrategy::CudaAsyncD2H
        | TransferStrategy::CudaAsyncD2D
        | TransferStrategy::CudaBlockingH2D
        | TransferStrategy::CudaBlockingD2H => cuda::execute_cuda_transfer(
            src,
            dst,
            src_block_ids,
            dst_block_ids,
            layer_range,
            strategy,
            ctx,
        ),
        TransferStrategy::NixlRead
        | TransferStrategy::NixlWrite
        | TransferStrategy::NixlReadFlipped
        | TransferStrategy::NixlWriteFlipped => {
            let builder = NixlTransferBuilder::new()
                .src(src)
                .dst(dst)
                .src_blocks(src_block_ids)
                .dst_blocks(dst_block_ids)
                .strategy(strategy);
            // The layer range is optional on the builder; only set it when given.
            match layer_range {
                Some(range) => builder.layer_range(range).execute(ctx),
                None => builder.execute(ctx),
            }
        }
    }
}
#[allow(clippy::too_many_arguments)]
async fn execute_two_hop_transfer_chunk(
src: &PhysicalLayout,
bounce_layout: &PhysicalLayout,
dst: &PhysicalLayout,
src_block_ids: &[usize],
bounce_block_ids: &[usize],
dst_block_ids: &[usize],
first_strategy: TransferStrategy,
second_strategy: TransferStrategy,
layer_range: &Option<Range<usize>>,
ctx: &TransferContext,
) -> Result<()> {
let bounce_ids_to_use = &bounce_block_ids[..src_block_ids.len()];
execute_direct_transfer(
src,
bounce_layout,
src_block_ids,
bounce_ids_to_use,
layer_range.clone(),
first_strategy,
ctx,
)?
.await?;
execute_direct_transfer(
bounce_layout,
dst,
bounce_ids_to_use,
dst_block_ids,
layer_range.clone(),
second_strategy,
ctx,
)?
.await?;
Ok(())
}
/// Parameters for two-hop transfer execution
///
/// Bundles the arguments of `execute_two_hop_transfer` into a single value.
struct TwoHopTransferParams<'a> {
// Source layout and final destination layout.
src: &'a PhysicalLayout,
dst: &'a PhysicalLayout,
// Block IDs in the source and destination layouts.
src_block_ids: &'a [usize],
dst_block_ids: &'a [usize],
// Strategy for the source -> bounce hop.
first_strategy: TransferStrategy,
// Location the bounce buffer layout is required to have.
bounce_location: StorageKind,
// Strategy for the bounce -> destination hop.
second_strategy: TransferStrategy,
// Transfer options; supplies the bounce buffer and the optional layer range.
options: TransferOptions,
// Context providing the Tokio handle the transfer task is spawned on.
ctx: &'a TransferContext,
}
fn execute_two_hop_transfer(params: TwoHopTransferParams) -> Result<TransferCompleteNotification> {
let TwoHopTransferParams {
src,
dst,
src_block_ids,
dst_block_ids,
first_strategy,
bounce_location,
second_strategy,
options,
ctx,
} = params;
let (tx, rx) = tokio::sync::oneshot::channel();
// TODO: Cloning all this stuff is not ideal.
let src_clone = src.clone();
let dst_clone = dst.clone();
let src_block_ids = src_block_ids.to_vec();
let dst_block_ids = dst_block_ids.to_vec();
let options_clone = options.clone();
let handle = ctx.tokio();
let ctx_clone = ctx.clone();
handle.spawn(async move {
let Some(ref bounce_buffer_spec) = options_clone.bounce_buffer else {
tx.send(Err(anyhow::anyhow!(
"Two-hop transfers require a bounce buffer."
)))
.unwrap();
return;
};
if bounce_buffer_spec.layout().location() != bounce_location {
tx.send(Err(anyhow::anyhow!(
"Bounce buffer layout does not match bounce location."
)))
.unwrap();
return;
}
let num_bounce_blocks = bounce_buffer_spec.block_ids().len();
if num_bounce_blocks < src_block_ids.len() {
for (src_block_ids, dst_block_ids) in src_block_ids
.chunks(num_bounce_blocks)
.zip(dst_block_ids.chunks(num_bounce_blocks))
{
let bounce_block_ids_to_use =
&bounce_buffer_spec.block_ids()[..src_block_ids.len()];
if let Err(e) = execute_two_hop_transfer_chunk(
&src_clone,
bounce_buffer_spec.layout(),
&dst_clone,
src_block_ids,
bounce_block_ids_to_use,
dst_block_ids,
first_strategy,
second_strategy,
&options_clone.layer_range,
&ctx_clone,
)
.await
{
tx.send(Err(e)).unwrap();
return;
}
}
tx.send(Ok(())).unwrap();
} else {
let bounce_block_ids_to_use = &bounce_buffer_spec.block_ids()[..src_block_ids.len()];
let result = execute_two_hop_transfer_chunk(
&src_clone,
bounce_buffer_spec.layout(),
&dst_clone,
src_block_ids.as_slice(),
bounce_block_ids_to_use,
dst_block_ids.as_slice(),
first_strategy,
second_strategy,
&options_clone.layer_range,
&ctx_clone,
)
.await;
tx.send(result).unwrap();
}
});
Ok(TransferCompleteNotification { status: rx })
}
/// Completion flag for a transfer, backed by a shared atomic boolean.
pub struct TransferNotification {
    status: Arc<AtomicBool>,
}

impl Default for TransferNotification {
    /// Same as [`TransferNotification::new`]: a notification that is not complete.
    fn default() -> Self {
        Self::new()
    }
}

impl TransferNotification {
    /// Build a notification whose flag starts out as `complete`.
    fn with_state(complete: bool) -> Self {
        let status = Arc::new(AtomicBool::new(complete));
        Self { status }
    }

    /// Create a notification that is not yet complete.
    pub fn new() -> Self {
        Self::with_state(false)
    }

    /// Create a notification that is already complete.
    pub fn done() -> Self {
        Self::with_state(true)
    }

    /// Report the current value of the completion flag.
    pub fn is_complete(&self) -> bool {
        self.status.load(Ordering::Relaxed)
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Typestate builder for NIXL transfers.
//!
//! This module provides a compile-time safe builder for NIXL transfers that ensures
//! all required parameters are set before execution.
use super::{PhysicalLayout, TransferContext, TransferStrategy};
use crate::block_manager::v2::physical::transfer::context::TransferCompleteNotification;
use anyhow::{Result, anyhow};
use nixl_sys::{XferDescList, XferOp};
use std::marker::PhantomData;
use std::ops::Range;
/// Marker type for unset builder fields.
///
/// Used only as a type parameter of `NixlTransferBuilder`; it carries no data.
pub struct Unset;
/// Marker type for set builder fields.
///
/// Used only as a type parameter of `NixlTransferBuilder`; it carries no data.
pub struct Set;
/// Typestate builder for NIXL transfers.
///
/// This builder uses the typestate pattern to ensure all required parameters are set
/// at compile time. The type parameters track which fields have been set:
/// - `TSrc`: Source layout state
/// - `TDst`: Destination layout state
/// - `TSrcBlocks`: Source block IDs state
/// - `TDstBlocks`: Destination block IDs state
/// - `TStrategy`: Transfer strategy state
///
/// Each required setter is only available while its parameter is `Unset` and returns
/// a builder with that parameter changed to `Set`.
pub struct NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy> {
// Required: source layout (set by `src`).
src: Option<&'a PhysicalLayout>,
// Required: destination layout (set by `dst`).
dst: Option<&'a PhysicalLayout>,
// Required: source block IDs (set by `src_blocks`).
src_block_ids: Option<&'a [usize]>,
// Required: destination block IDs (set by `dst_blocks`).
dst_block_ids: Option<&'a [usize]>,
// Required: transfer strategy (set by `strategy`).
strategy: Option<TransferStrategy>,
// Optional: layers to transfer; `None` means all layers.
layer_range: Option<Range<usize>>,
// Optional: write notification UUID (set by `write_notif`).
write_notif: Option<uuid::Uuid>,
// Zero-sized carrier for the typestate parameters.
_phantom: PhantomData<(TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy)>,
}
impl<'a> NixlTransferBuilder<'a, Unset, Unset, Unset, Unset, Unset> {
/// Creates a new NIXL transfer builder with all fields unset.
pub fn new() -> Self {
Self {
src: None,
dst: None,
src_block_ids: None,
dst_block_ids: None,
strategy: None,
layer_range: None,
write_notif: None,
_phantom: PhantomData,
}
}
}
impl<'a> Default for NixlTransferBuilder<'a, Unset, Unset, Unset, Unset, Unset> {
fn default() -> Self {
Self::new()
}
}
// Required field setters - these consume self and return a new builder with the field marked as Set
impl<'a, TDst, TSrcBlocks, TDstBlocks, TStrategy>
    NixlTransferBuilder<'a, Unset, TDst, TSrcBlocks, TDstBlocks, TStrategy>
{
    /// Provide the source physical layout and mark it as set.
    pub fn src(
        self,
        src: &'a PhysicalLayout,
    ) -> NixlTransferBuilder<'a, Set, TDst, TSrcBlocks, TDstBlocks, TStrategy> {
        // Take the remaining fields apart so they can move into the re-typed builder.
        let Self {
            dst,
            src_block_ids,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            ..
        } = self;
        NixlTransferBuilder {
            src: Some(src),
            dst,
            src_block_ids,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
impl<'a, TSrc, TSrcBlocks, TDstBlocks, TStrategy>
    NixlTransferBuilder<'a, TSrc, Unset, TSrcBlocks, TDstBlocks, TStrategy>
{
    /// Provide the destination physical layout and mark it as set.
    pub fn dst(
        self,
        dst: &'a PhysicalLayout,
    ) -> NixlTransferBuilder<'a, TSrc, Set, TSrcBlocks, TDstBlocks, TStrategy> {
        // Take the remaining fields apart so they can move into the re-typed builder.
        let Self {
            src,
            src_block_ids,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            ..
        } = self;
        NixlTransferBuilder {
            src,
            dst: Some(dst),
            src_block_ids,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
impl<'a, TSrc, TDst, TDstBlocks, TStrategy>
    NixlTransferBuilder<'a, TSrc, TDst, Unset, TDstBlocks, TStrategy>
{
    /// Provide the source block IDs to transfer and mark them as set.
    pub fn src_blocks(
        self,
        src_block_ids: &'a [usize],
    ) -> NixlTransferBuilder<'a, TSrc, TDst, Set, TDstBlocks, TStrategy> {
        // Take the remaining fields apart so they can move into the re-typed builder.
        let Self {
            src,
            dst,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            ..
        } = self;
        NixlTransferBuilder {
            src,
            dst,
            src_block_ids: Some(src_block_ids),
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
impl<'a, TSrc, TDst, TSrcBlocks, TStrategy>
    NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, Unset, TStrategy>
{
    /// Provide the destination block IDs to transfer and mark them as set.
    pub fn dst_blocks(
        self,
        dst_block_ids: &'a [usize],
    ) -> NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, Set, TStrategy> {
        // Take the remaining fields apart so they can move into the re-typed builder.
        let Self {
            src,
            dst,
            src_block_ids,
            strategy,
            layer_range,
            write_notif,
            ..
        } = self;
        NixlTransferBuilder {
            src,
            dst,
            src_block_ids,
            dst_block_ids: Some(dst_block_ids),
            strategy,
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
impl<'a, TSrc, TDst, TSrcBlocks, TDstBlocks>
    NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, Unset>
{
    /// Provide the NIXL transfer strategy (Read or Write) and mark it as set.
    pub fn strategy(
        self,
        strategy: TransferStrategy,
    ) -> NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, Set> {
        // Take the remaining fields apart so they can move into the re-typed builder.
        let Self {
            src,
            dst,
            src_block_ids,
            dst_block_ids,
            layer_range,
            write_notif,
            ..
        } = self;
        NixlTransferBuilder {
            src,
            dst,
            src_block_ids,
            dst_block_ids,
            strategy: Some(strategy),
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
// Optional field setters - available in every typestate, so they can appear
// anywhere in the builder chain.
impl<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy>
    NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy>
{
    /// Restricts the transfer to the given range of layers.
    ///
    /// When this setter is never called, `execute` transfers all layers.
    pub fn layer_range(self, layer_range: Range<usize>) -> Self {
        Self {
            layer_range: Some(layer_range),
            ..self
        }
    }

    /// Attaches an optional write notification UUID to the builder.
    pub fn write_notif(self, write_notif: uuid::Uuid) -> Self {
        Self {
            write_notif: Some(write_notif),
            ..self
        }
    }
}
// Execute method - only available when all required fields are Set
impl<'a> NixlTransferBuilder<'a, Set, Set, Set, Set, Set> {
    /// Executes the NIXL transfer with the configured parameters.
    ///
    /// This method is only available when all required fields have been set,
    /// enforced at compile time by the typestate pattern.
    ///
    /// One descriptor is added per `(block, layer, outer)` triple, pairing the
    /// i-th source block with the i-th destination block. For the `*Flipped`
    /// strategies the two descriptor lists are swapped before the request is
    /// created.
    ///
    /// # Returns
    /// A notification that is already completed when the request finished
    /// synchronously, or one registered for status polling otherwise.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the source and destination block ID lists differ in length
    /// - the layouts disagree on layer count or outer dimension
    /// - an explicit layer range extends past the layout's layer count
    /// - the strategy is not one of the NIXL read/write strategies
    /// - the source layout is not registered with the local NIXL agent
    /// - a source/destination region pair differs in size
    /// - any NIXL call (descriptor list, request creation, posting) fails
    pub(crate) fn execute(self, ctx: &TransferContext) -> Result<TransferCompleteNotification> {
        // Unwrap all required fields (safe because typestate guarantees they're set)
        let src = self.src.expect("typestate guarantees `src` is set");
        let dst = self.dst.expect("typestate guarantees `dst` is set");
        let src_block_ids = self
            .src_block_ids
            .expect("typestate guarantees `src_block_ids` is set");
        let dst_block_ids = self
            .dst_block_ids
            .expect("typestate guarantees `dst_block_ids` is set");
        let strategy = self.strategy.expect("typestate guarantees `strategy` is set");
        let layer_range = self.layer_range;
        // Accepted by the builder but not forwarded to NIXL yet (opt_args is `None` below).
        let _write_notif = self.write_notif;

        // Blocks are paired positionally; without this check `zip` would silently
        // drop the surplus entries of the longer list.
        if src_block_ids.len() != dst_block_ids.len() {
            return Err(anyhow!(
                "Block ID count mismatch: src={}, dst={}",
                src_block_ids.len(),
                dst_block_ids.len()
            ));
        }

        // Validate layouts
        let src_layout = src.layout();
        let dst_layout = dst.layout();
        if src_layout.num_layers() != dst_layout.num_layers() {
            return Err(anyhow!(
                "Layouts have incompatible layer counts: src={}, dst={}",
                src_layout.num_layers(),
                dst_layout.num_layers()
            ));
        }
        if src_layout.outer_dim() != dst_layout.outer_dim() {
            return Err(anyhow!(
                "Layouts have incompatible outer dimensions: src={}, dst={}",
                src_layout.outer_dim(),
                dst_layout.outer_dim()
            ));
        }

        // Get NIXL agent
        let nixl_agent = ctx.nixl_agent();

        // Determine layer range (defaults to every layer)
        let layers = layer_range.unwrap_or(0..src_layout.num_layers());
        if layers.end > src_layout.num_layers() {
            return Err(anyhow!(
                "Layer range {:?} exceeds num_layers {}",
                layers,
                src_layout.num_layers()
            ));
        }

        // Determine NIXL operation type
        let xfer_op = match strategy {
            TransferStrategy::NixlRead | TransferStrategy::NixlReadFlipped => XferOp::Read,
            TransferStrategy::NixlWrite | TransferStrategy::NixlWriteFlipped => XferOp::Write,
            _ => {
                return Err(anyhow!("Invalid NIXL transfer strategy: {:?}", strategy));
            }
        };

        // Capture NIXL metadata for both layouts
        let src_metadata = src.nixl_metadata();
        let dst_metadata = dst.nixl_metadata();

        // A non-local source is reported as an error rather than a panic: this
        // function already returns `Result`, so callers can handle it.
        if nixl_agent.name() != src_metadata.agent_name() {
            return Err(anyhow!("the source must be local"));
        }

        let src_mem_type = src_metadata.mem_type();
        let dst_mem_type = dst_metadata.mem_type();
        let src_device_id = src_metadata.device_id();
        let dst_device_id = dst_metadata.device_id();

        // Build XferDescLists for source and destination
        let mut src_dl = XferDescList::new(src_mem_type)?;
        let mut dst_dl = XferDescList::new(dst_mem_type)?;

        // Add memory regions to descriptor lists
        for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
            for layer_id in layers.clone() {
                for outer_id in 0..src_layout.outer_dim() {
                    let src_region = src.memory_region(src_block_id, layer_id, outer_id)?;
                    let dst_region = dst.memory_region(dst_block_id, layer_id, outer_id)?;
                    if src_region.size() != dst_region.size() {
                        return Err(anyhow!(
                            "Size mismatch at block=({},{}), layer={}, outer={}: src={}, dst={}",
                            src_block_id,
                            dst_block_id,
                            layer_id,
                            outer_id,
                            src_region.size(),
                            dst_region.size()
                        ));
                    }
                    // Add to source descriptor list
                    src_dl.add_desc(src_region.addr(), src_region.size(), src_device_id)?;
                    // Add to destination descriptor list
                    dst_dl.add_desc(dst_region.addr(), dst_region.size(), dst_device_id)?;
                }
            }
        }

        // Note: Overlap detection was removed from nixl-sys 0.6.1
        // The NIXL library now handles overlap detection internally
        if matches!(
            strategy,
            TransferStrategy::NixlReadFlipped | TransferStrategy::NixlWriteFlipped
        ) {
            std::mem::swap(&mut src_dl, &mut dst_dl);
        }

        // Create transfer request
        let xfer_req = nixl_agent.create_xfer_req(
            xfer_op,
            &src_dl,
            &dst_dl,
            dst_metadata.agent_name(),
            None, // opt_args
        )?;

        // Post transfer request
        // Note: Notification handling via OptArgs can be added later if needed
        let still_pending = nixl_agent.post_xfer_req(&xfer_req, None)?;
        if still_pending {
            // Register for async completion via status polling
            Ok(ctx.register_nixl_status(xfer_req))
        } else {
            // Transfer completed synchronously
            Ok(TransferCompleteNotification::completed())
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Block filling operations for testing.
//!
//! This module provides utilities to populate blocks with specific patterns
//! for verification in round-trip tests.
use super::PhysicalLayout;
use crate::block_manager::v2::memory::StorageKind;
use aligned_vec::{AVec, avec};
use anyhow::{Result, anyhow};
use cudarc::runtime::sys::{cudaMemcpy, cudaMemcpyKind};
use std::{
fs::File,
io::{Seek, Write},
mem::ManuallyDrop,
ops::Range,
os::fd::FromRawFd,
};
/// Fill strategy for block memory.
///
/// Patterns can be compared with `==`, which lets tests assert on the pattern
/// that was used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FillPattern {
    /// Fill every byte with the same constant value.
    Constant(u8),
    /// Fill with a sequential pattern: the byte at `offset` within a region is
    /// `(block_id + layer_id + offset) % 256`.
    Sequential,
}
/// Fill blocks in a physical layout with a specific pattern.
///
/// This operation directly writes to memory and should only be used on
/// local layouts. Remote layouts cannot be filled directly.
///
/// # Arguments
/// * `layout` - The physical layout containing the blocks
/// * `block_ids` - List of block IDs to fill
/// * `pattern` - Fill pattern to use
///
/// # Errors
/// Returns an error if:
/// - Layout is remote (cannot fill remote memory directly)
/// - Block IDs are out of range
/// - Memory access fails
pub fn fill_blocks(
layout: &PhysicalLayout,
block_ids: &[usize],
pattern: FillPattern,
) -> Result<()> {
// Can only fill local layouts
let config = layout.layout().config();
let num_layers = config.num_layers;
let outer_dim = config.outer_dim;
for &block_id in block_ids {
if block_id >= config.num_blocks {
return Err(anyhow!("Block ID {} out of range", block_id));
}
// Fill all layers and outer dimensions for this block
for layer_id in 0..num_layers {
for outer_id in 0..outer_dim {
let region = layout.memory_region(block_id, layer_id, outer_id)?;
match layout.location() {
StorageKind::System | StorageKind::Pinned => {
fill_memory_region(
region.addr(),
region.size(),
block_id,
layer_id,
pattern,
)?;
}
StorageKind::Device(_) => {
let system_region: Vec<u8> = vec![0; region.size()];
fill_memory_region(
system_region.as_ptr() as usize,
system_region.len(),
block_id,
layer_id,
pattern,
)?;
unsafe {
cudaMemcpy(
region.addr() as *mut std::ffi::c_void,
system_region.as_ptr() as *const std::ffi::c_void,
region.size(),
cudaMemcpyKind::cudaMemcpyHostToDevice,
);
}
}
StorageKind::Disk(fd) => {
let system_region: AVec<u8, _> = avec![[4096]| 0; region.size()];
fill_memory_region(
system_region.as_ptr() as usize,
system_region.len(),
block_id,
layer_id,
pattern,
)?;
let mut file = ManuallyDrop::new(unsafe { File::from_raw_fd(fd as i32) });
file.seek(std::io::SeekFrom::Start(region.addr() as u64))?;
file.write_all(&system_region)?;
file.sync_all()?;
file.flush()?;
}
}
}
}
}
Ok(())
}
/// Fill a subset of layers in blocks with a specific pattern.
///
/// # Arguments
/// * `layout` - The physical layout containing the blocks
/// * `block_ids` - List of block IDs to fill
/// * `layer_range` - Range of layers to fill
/// * `pattern` - Fill pattern to use
pub fn fill_layers(
layout: &PhysicalLayout,
block_ids: &[usize],
layer_range: Range<usize>,
pattern: FillPattern,
) -> Result<()> {
let config = layout.layout().config();
let num_layers = config.num_layers;
let outer_dim = config.outer_dim;
if layer_range.end > num_layers {
return Err(anyhow!(
"Layer range {:?} exceeds num_layers {}",
layer_range,
num_layers
));
}
for &block_id in block_ids {
if block_id >= config.num_blocks {
return Err(anyhow!("Block ID {} out of range", block_id));
}
// Fill specified layers and all outer dimensions
for layer_id in layer_range.clone() {
for outer_id in 0..outer_dim {
let region = layout.memory_region(block_id, layer_id, outer_id)?;
fill_memory_region(region.addr(), region.size(), block_id, layer_id, pattern)?;
}
}
}
Ok(())
}
/// Writes `pattern` into the `size` bytes starting at address `addr`.
///
/// `block_id` and `layer_id` only influence [`FillPattern::Sequential`], where
/// the byte at `offset` becomes `(block_id + layer_id + offset) % 256`.
/// Always returns `Ok(())`.
///
/// # Safety
/// This function performs unsafe memory writes. The caller must ensure:
/// - The memory region is valid and accessible
/// - No other references exist to this memory
fn fill_memory_region(
    addr: usize,
    size: usize,
    block_id: usize,
    layer_id: usize,
    pattern: FillPattern,
) -> Result<()> {
    let base = addr as *mut u8;
    match pattern {
        FillPattern::Constant(value) => {
            // One bulk write of the same byte.
            unsafe { base.write_bytes(value, size) };
        }
        FillPattern::Sequential => {
            // Byte-by-byte: each value depends on its own offset.
            for offset in 0..size {
                let byte = ((block_id + layer_id + offset) % 256) as u8;
                unsafe { base.add(offset).write(byte) };
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::super::tests::*;
    use super::*;
    use crate::block_manager::v2::memory::actions::Slice;

    /// A constant fill leaves every byte of the inspected region equal to the constant.
    #[test]
    fn test_fill_blocks_constant() {
        let layout = builder(2)
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap();
        fill_blocks(&layout, &[0, 1], FillPattern::Constant(42)).unwrap();

        // Only region (block 0, layer 0, outer 0) is inspected here.
        let region = layout.memory_region(0, 0, 0).unwrap();
        let bytes = region.as_slice().unwrap();
        let has_other_byte = bytes.iter().any(|&b| b != 42);
        assert!(!has_other_byte);
    }

    /// A sequential fill starts at `block_id + layer_id` and counts up by one per byte.
    #[test]
    fn test_fill_blocks_sequential() {
        let layout = builder(2)
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap();
        fill_blocks(&layout, &[0, 1], FillPattern::Sequential).unwrap();

        // Spot check the first two bytes of block 0 / layer 0 ...
        let region = layout.memory_region(0, 0, 0).unwrap();
        let bytes = region.as_slice().unwrap();
        let (head, next) = (bytes[0], bytes[1]);
        assert_eq!(head, 0);
        assert_eq!(next, head.wrapping_add(1));

        // ... and of block 1 / layer 1, whose pattern starts at 1 + 1.
        let region = layout.memory_region(1, 1, 0).unwrap();
        let bytes = region.as_slice().unwrap();
        let (head, next) = (bytes[0], bytes[1]);
        assert_eq!(head, 2);
        assert_eq!(next, head.wrapping_add(1));
    }

    /// Filling layer by layer gives each (block, layer) pair its own constant.
    #[test]
    fn test_fill_layers() {
        let layout = builder(2)
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap();

        // Each (block, layer) pair gets a distinct value.
        fill_layers(&layout, &[0], 0..1, FillPattern::Constant(0)).unwrap();
        fill_layers(&layout, &[0], 1..2, FillPattern::Constant(1)).unwrap();
        fill_layers(&layout, &[1], 0..1, FillPattern::Constant(100)).unwrap();
        fill_layers(&layout, &[1], 1..2, FillPattern::Constant(101)).unwrap();

        // First byte of the (block, layer, outer = 0) region.
        let first_byte = |block: usize, layer: usize| -> u8 {
            let region = layout.memory_region(block, layer, 0).unwrap();
            let byte = region.as_slice().unwrap()[0];
            byte
        };

        assert_eq!(first_byte(0, 0), 0);
        assert_eq!(first_byte(0, 1), 1);
        assert_eq!(first_byte(1, 0), 100);
        assert_eq!(first_byte(1, 1), 101);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer module for copying blocks between layouts with different storage locations.
//!
//! This module provides functionality for transferring KV cache blocks between layouts
//! that may be backed by different storage types (GPU memory, pinned host memory, disk, etc.)
//! and potentially across NIXL-connected remote nodes.
//!
//! # Core Concepts
//!
//! - [`PhysicalLayout`]: Wraps a layout with its physical storage location and NIXL metadata
//! - [`LayoutDescriptor`]: Serializable representation for cross-node communication
//! - Transfer strategies: memcpy, CUDA, NIXL based on source/destination locations
//! - Block-wise and layer-wise transfer operations
//!
//! # Usage
//!
//! ```rust,ignore
//! use dynamo_kvbm::v2::transfer::{PhysicalLayout, transfer_blocks};
//!
//! // Create local physical layout with NIXL registration
//! let src = PhysicalLayout::new_local(src_layout, StorageKind::Device(0))
//! .with_nixl_registration("local_agent".to_string())?;
//!
//! // Create remote physical layout
//! let dst = PhysicalLayout::new_remote(
//! dst_layout,
//! StorageKind::Pinned,
//! "remote_agent".to_string()
//! );
//!
//! // Transfer blocks from local to remote
//! let src_block_ids = [0, 1, 2];
//! let dst_block_ids = [0, 1, 2];
//! let future = transfer_blocks(&src, &dst, &src_block_ids, &dst_block_ids, &ctx)?;
//! future.await?;
//! ```
pub mod capabilities;
pub mod checksum;
pub mod context;
pub mod executor;
pub mod fill;
pub mod nixl_agent;
pub mod notifications;
pub mod options;
pub mod preferences;
pub mod strategy;
pub mod validation;
#[cfg(test)]
mod tests;
// Re-export StorageKind
pub use crate::block_manager::v2::memory::StorageKind;
pub use capabilities::TransferCapabilities;
pub use checksum::{BlockChecksum, compute_block_checksums, compute_layer_checksums};
pub use fill::{FillPattern, fill_blocks, fill_layers};
pub use nixl_agent::{NixlAgent, NixlBackendConfig};
pub use options::{TransferOptions, TransferOptionsBuilder};
pub use preferences::{NativeVsNixlPolicy, TransferPreferences};
pub use strategy::{TransferPlan, TransferStrategy};
pub use validation::BlockValidationError;
// Internal - TransferContext is now managed by TransportManager
pub(crate) use context::TransferContext;
pub use super::layout::PhysicalLayout;
// Re-export manager types - TransportManager is the primary public API
pub use super::manager::{LayoutHandle, SerializedLayout, TransportManager, WorkerAddress};
// #[cfg(test)]
// pub use testing::{RoundTripTest, RoundTripTestResult};
use anyhow::Result;
/// Future representing an in-progress transfer operation.
///
/// A pinned, boxed, `Send` future whose output is `Result<()>`: the transfer
/// completes when this future resolves, and the output carries its outcome.
pub type TransferFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>;
/// Specification for bounce buffer in multi-hop transfers.
///
/// Implementors of this trait provide the layout and block IDs to use as an
/// intermediate staging area when direct transfers are not allowed.
/// Implementations must be `Send + Sync`.
pub trait BounceBufferSpec: Send + Sync {
/// Returns the physical layout that backs the bounce buffer.
fn layout(&self) -> &PhysicalLayout;
/// Returns the block IDs offered for staging.
// NOTE(review): presumably these index into `layout()` - confirm with implementors.
fn block_ids(&self) -> &[usize];
}
// #[cfg(all(test, feature = "testing-cuda"))]
// mod cuda_integration_tests {
// use super::*;
// use crate::block_manager::v2::layout::{
// FullyContiguousLayout, Layout, LayoutConfig, MemoryRegion, OwnedMemoryRegion,
// };
// use cudarc::driver::CudaContext;
// use std::sync::Arc;
// // TODO: Add CUDA-specific integration tests
// // These would test:
// // - H2D transfers
// // - D2H transfers
// // - D2D transfers
// // - Async completion via event synchronization
// }
// #[cfg(all(test, feature = "testing-nixl"))]
// mod nixl_integration_tests {
// use super::*;
// // TODO: Add NIXL-specific integration tests
// // These would test:
// // - Remote memory access via NIXL Read
// // - Disk-backed transfers via NIXL Write
// // - Cross-node serialization with LayoutDescriptor
// }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NIXL backend configuration with Figment support.
//!
//! This module provides configuration extraction for NIXL backends from
//! environment variables with the pattern: `DYN_KVBM_NIXL_BACKEND_<backend>_<key>=<value>`
use anyhow::{Result, bail};
use dynamo_runtime::config::parse_bool;
use std::collections::HashSet;
/// Configuration for NIXL backends.
///
/// Supports extracting backend configurations from environment variables:
/// - `DYN_KVBM_NIXL_BACKEND_UCX=true` - Enable UCX backend with default params
/// - `DYN_KVBM_NIXL_BACKEND_GDS=false` - Explicitly disable GDS backend
/// - `DYN_KVBM_NIXL_BACKEND_GDS_MT=true` - Enable the GDS_MT backend (a backend
///   name that itself contains an underscore)
/// - Valid values are whatever `parse_bool` accepts as truthy or falsey
/// - Invalid values (e.g., "maybe", "random") will cause an error
/// - Custom params (e.g., `DYN_KVBM_NIXL_BACKEND_UCX_PARAM1=value`) will cause an error
///
/// Backend names are stored uppercased, so lookups are case-insensitive.
///
/// # Examples
///
/// ```rust,ignore
/// // Extract from environment
/// let config = NixlBackendConfig::from_env()?;
///
/// // Or combine with builder overrides
/// let config = NixlBackendConfig::from_env()?
///     .with_backend("ucx")
///     .with_backend("gds");
/// ```
#[derive(Debug, Clone, Default)]
pub struct NixlBackendConfig {
    /// Set of enabled backends (just backend names, no custom params yet)
    backends: HashSet<String>,
}

impl NixlBackendConfig {
    /// Create a new empty configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create configuration from environment variables.
    ///
    /// Extracts backends from `DYN_KVBM_NIXL_BACKEND_<backend>=<value>` variables.
    /// If no backend ends up enabled, UCX is enabled by default.
    ///
    /// Environment variables whose *name* is not valid Unicode are ignored;
    /// they cannot match the prefix.
    ///
    /// # Errors
    /// Returns an error if:
    /// - Custom parameters are detected (not yet supported)
    /// - The backend name is empty
    /// - Invalid boolean values are provided (must be truthy or falsey)
    pub fn from_env() -> Result<Self> {
        const ENV_PREFIX: &str = "DYN_KVBM_NIXL_BACKEND_";
        // Backend names that legitimately contain an underscore. Without this
        // list they would be mistaken for `<backend>_<param>` custom parameters.
        const UNDERSCORE_BACKENDS: &[&str] = &["GDS_MT"];

        let mut backends = HashSet::new();

        // `vars_os` rather than `vars`: the latter panics as soon as *any*
        // variable in the environment is not valid Unicode.
        for (key_os, value_os) in std::env::vars_os() {
            let key = match key_os.to_str() {
                Some(key) => key,
                None => continue,
            };
            let remainder = match key.strip_prefix(ENV_PREFIX) {
                Some(remainder) => remainder,
                None => continue,
            };

            // Simple backend enablement (e.g., DYN_KVBM_NIXL_BACKEND_UCX=true)
            let backend_name = remainder.to_uppercase();
            if backend_name.is_empty() {
                bail!("Missing backend name in environment variable {}", key);
            }

            // Any other underscore indicates custom params, which are unsupported
            if backend_name.contains('_') && !UNDERSCORE_BACKENDS.contains(&backend_name.as_str())
            {
                bail!(
                    "Custom NIXL backend parameters are not yet supported. \
                     Found: {}. Please use only DYN_KVBM_NIXL_BACKEND_<backend>=true \
                     to enable backends with default parameters.",
                    key
                );
            }

            let value = match value_os.to_str() {
                Some(value) => value,
                None => bail!("Invalid value for {}: not valid Unicode", key),
            };
            match parse_bool(value) {
                Ok(true) => {
                    backends.insert(backend_name);
                }
                Ok(false) => {
                    // Explicitly disabled, don't add to backends
                }
                Err(e) => bail!("Invalid value for {}: {}", key, e),
            }
        }

        // Default to UCX if no backends specified
        if backends.is_empty() {
            backends.insert("UCX".to_string());
        }
        Ok(Self { backends })
    }

    /// Add a backend to the configuration.
    ///
    /// Backend names will be converted to uppercase for consistency.
    pub fn with_backend(mut self, backend: impl Into<String>) -> Self {
        self.backends.insert(backend.into().to_uppercase());
        self
    }

    /// Get the set of enabled backends (uppercased names).
    pub fn backends(&self) -> &HashSet<String> {
        &self.backends
    }

    /// Check if a specific backend is enabled (case-insensitive).
    pub fn has_backend(&self, backend: &str) -> bool {
        self.backends.contains(&backend.to_uppercase())
    }

    /// Merge another configuration into this one.
    ///
    /// Backends from the other configuration will be added to this one.
    pub fn merge(mut self, other: NixlBackendConfig) -> Self {
        self.backends.extend(other.backends);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created configuration enables nothing.
    #[test]
    fn test_new_config_is_empty() {
        let enabled_count = NixlBackendConfig::new().backends().len();
        assert!(enabled_count == 0);
    }

    /// Backends added through the builder are found in either case; others are not.
    #[test]
    fn test_with_backend() {
        let config = NixlBackendConfig::new()
            .with_backend("ucx")
            .with_backend("gds_mt");

        for name in ["ucx", "UCX", "gds_mt", "GDS_MT"] {
            assert!(config.has_backend(name));
        }
        assert!(!config.has_backend("other"));
    }

    /// Merging yields the union of both configurations.
    #[test]
    fn test_merge_configs() {
        let ucx_only = NixlBackendConfig::new().with_backend("ucx");
        let gds_only = NixlBackendConfig::new().with_backend("gds");

        let merged = ucx_only.merge(gds_only);
        for name in ["ucx", "gds"] {
            assert!(merged.has_backend(name));
        }
    }

    /// Lookups ignore case no matter how the name was spelled on insertion.
    #[test]
    fn test_backend_name_case_insensitive() {
        let config = NixlBackendConfig::new()
            .with_backend("ucx")
            .with_backend("Gds_mt")
            .with_backend("OTHER");

        for name in ["UCX", "ucx", "GDS_MT", "gds_mt", "OTHER", "other"] {
            assert!(config.has_backend(name));
        }
    }

    // Note: Testing from_env() would require setting environment variables,
    // which is challenging in unit tests. This is better tested with integration tests.
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NIXL agent wrapper and configuration.
//!
//! This module provides:
//! - `NixlAgent`: Wrapper around nixl_sys::Agent that tracks initialized backends
//! - `NixlBackendConfig`: Configuration for NIXL backends from environment variables
mod config;
pub use config::NixlBackendConfig;
use anyhow::Result;
use nixl_sys::Agent as RawNixlAgent;
use std::collections::HashSet;
/// A NIXL agent wrapper that tracks which backends were successfully initialized.
///
/// This wrapper provides:
/// - Runtime validation of backend availability
/// - Clear error messages when operations need unavailable backends
/// - Single source of truth for backend state in tests and production
///
/// # Backend Tracking
///
/// Since `nixl_sys::Agent` doesn't provide a method to query active backends,
/// we track them during initialization. The `available_backends` set is populated
/// based on successful `create_backend()` calls and holds uppercased names.
#[derive(Clone, Debug)]
pub struct NixlAgent {
    agent: RawNixlAgent,
    available_backends: HashSet<String>,
}

/// Why a single backend could not be initialized; carries the rendered underlying error.
enum BackendInitFailure {
    /// `get_plugin_params` failed for the backend.
    PluginNotFound(String),
    /// The plugin was found but `create_backend` failed.
    CreateFailed(String),
}

impl NixlAgent {
    /// Looks up the plugin parameters for `backend_upper` and creates that backend on `agent`.
    fn init_backend(
        agent: &RawNixlAgent,
        backend_upper: &str,
    ) -> std::result::Result<(), BackendInitFailure> {
        let (_, params) = agent
            .get_plugin_params(backend_upper)
            .map_err(|e| BackendInitFailure::PluginNotFound(e.to_string()))?;
        agent
            .create_backend(backend_upper, &params)
            .map(|_| ())
            .map_err(|e| BackendInitFailure::CreateFailed(e.to_string()))
    }

    /// Create a new NIXL agent with the specified backends.
    ///
    /// Attempts to initialize all requested backends. If a backend fails, it prints
    /// a warning to stderr but continues with remaining backends. At least one
    /// backend must succeed or this returns an error.
    ///
    /// # Arguments
    /// * `name` - Agent name
    /// * `backends` - List of backend names to try (e.g., `&["UCX", "GDS_MT", "POSIX"]`)
    ///
    /// # Returns
    /// A `NixlAgent` that tracks which backends were successfully initialized.
    ///
    /// # Errors
    /// Returns an error if:
    /// - Agent creation fails
    /// - All backend initialization attempts fail (this includes an empty `backends` list)
    pub fn new_with_backends(name: &str, backends: &[&str]) -> Result<Self> {
        let agent = RawNixlAgent::new(name)?;
        let mut available_backends = HashSet::new();
        for backend in backends {
            let backend_upper = backend.to_uppercase();
            match Self::init_backend(&agent, &backend_upper) {
                Ok(()) => {
                    available_backends.insert(backend_upper);
                }
                Err(BackendInitFailure::CreateFailed(e)) => {
                    eprintln!(
                        "✗ Failed to create {} backend: {}. Operations requiring this backend will fail.",
                        backend_upper, e
                    );
                }
                Err(BackendInitFailure::PluginNotFound(_)) => {
                    eprintln!(
                        "✗ No {} plugin found. Operations requiring this backend will fail.",
                        backend_upper
                    );
                }
            }
        }
        if available_backends.is_empty() {
            anyhow::bail!("Failed to initialize any NIXL backends from {:?}", backends);
        }
        Ok(Self {
            agent,
            available_backends,
        })
    }

    /// Create a NIXL agent requiring ALL specified backends to be available.
    ///
    /// Unlike `new_with_backends()` which continues if some backends fail, this method
    /// will return an error if ANY backend fails to initialize. Use this in production
    /// when specific backends are mandatory. An empty `backends` list succeeds and
    /// yields an agent with no tracked backends.
    ///
    /// # Arguments
    /// * `name` - Agent name
    /// * `backends` - List of backend names that MUST be available
    ///
    /// # Returns
    /// A `NixlAgent` with all requested backends initialized.
    ///
    /// # Errors
    /// Returns an error if:
    /// - Agent creation fails
    /// - Any backend fails to initialize (all failures are listed in the message)
    ///
    /// # Example
    /// ```ignore
    /// // In production: require both UCX and GDS_MT, fail if either is missing
    /// let agent = NixlAgent::require_backends("worker-0", &["UCX", "GDS_MT"])?;
    /// ```
    pub fn require_backends(name: &str, backends: &[&str]) -> Result<Self> {
        let agent = RawNixlAgent::new(name)?;
        let mut available_backends = HashSet::new();
        // One "<BACKEND>: <reason>" entry per backend that failed.
        let mut failures: Vec<String> = Vec::new();
        for backend in backends {
            let backend_upper = backend.to_uppercase();
            match Self::init_backend(&agent, &backend_upper) {
                Ok(()) => {
                    available_backends.insert(backend_upper);
                }
                Err(BackendInitFailure::CreateFailed(e)) => {
                    eprintln!("✗ Failed to create {} backend: {}", backend_upper, e);
                    failures.push(format!("{}: create failed: {}", backend_upper, e));
                }
                Err(BackendInitFailure::PluginNotFound(e)) => {
                    eprintln!("✗ No {} plugin found", backend_upper);
                    failures.push(format!("{}: plugin not found: {}", backend_upper, e));
                }
            }
        }
        if !failures.is_empty() {
            anyhow::bail!(
                "Failed to initialize required backends: [{}]",
                failures.join(", ")
            );
        }
        Ok(Self {
            agent,
            available_backends,
        })
    }

    /// Create a NIXL agent with default backends for testing/development.
    ///
    /// Attempts to initialize UCX, GDS_MT, and POSIX backends. If some are unavailable,
    /// continues with whatever succeeds. This ensures code works in various environments.
    pub fn new_default(name: &str) -> Result<Self> {
        Self::new_with_backends(name, &["UCX", "GDS_MT", "POSIX"])
    }

    /// Get a reference to the underlying raw NIXL agent.
    pub fn raw_agent(&self) -> &RawNixlAgent {
        &self.agent
    }

    /// Consume and return the underlying raw NIXL agent.
    ///
    /// **Warning**: Once consumed, backend tracking is lost. Use this only when
    /// interfacing with code that requires `nixl_sys::Agent` directly.
    pub fn into_raw_agent(self) -> RawNixlAgent {
        self.agent
    }

    /// Check if a specific backend is available (case-insensitive).
    pub fn has_backend(&self, backend: &str) -> bool {
        self.available_backends.contains(&backend.to_uppercase())
    }

    /// Get all available backends (uppercased names).
    pub fn backends(&self) -> &HashSet<String> {
        &self.available_backends
    }

    /// Require a specific backend, returning an error if unavailable.
    ///
    /// Use this at the start of operations that need specific backends.
    ///
    /// # Example
    /// ```ignore
    /// agent.require_backend("GDS_MT")?;
    /// // Proceed with GDS-specific operations
    /// ```
    pub fn require_backend(&self, backend: &str) -> Result<()> {
        let backend_upper = backend.to_uppercase();
        // Names are stored uppercased, so the set can be queried directly.
        if self.available_backends.contains(&backend_upper) {
            Ok(())
        } else {
            anyhow::bail!(
                "Operation requires {} backend, but it was not initialized. Available backends: {:?}",
                backend_upper,
                self.available_backends
            )
        }
    }
}

// Delegate common methods to the underlying agent
impl std::ops::Deref for NixlAgent {
    type Target = RawNixlAgent;

    fn deref(&self) -> &Self::Target {
        &self.agent
    }
}
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::*;

    /// A backend that initializes is tracked and can be queried in either case.
    #[test]
    fn test_agent_backend_tracking() {
        // UCX may be missing in this environment; the checks apply only when it comes up.
        if let Ok(agent) = NixlAgent::new_with_backends("test", &["UCX"]) {
            for name in ["UCX", "ucx"] {
                assert!(agent.has_backend(name));
            }
        }
    }

    /// `require_backend` accepts an initialized backend and rejects one that was never requested.
    #[test]
    fn test_require_backend() {
        let agent = NixlAgent::new_with_backends("test", &["UCX"]).expect("Need UCX for test");

        let available = agent.require_backend("UCX");
        assert!(available.is_ok());

        let unavailable = agent.require_backend("GDS_MT");
        assert!(unavailable.is_err());
    }

    /// `require_backends` is all-or-nothing.
    #[test]
    fn test_require_backends_strict() {
        // Succeeds when the only required backend (UCX) is available.
        let agent = NixlAgent::require_backends("test_strict", &["UCX"])
            .expect("Failed to require backends");
        assert!(agent.has_backend("UCX"));

        // Fails when any required backend cannot be initialized ("DUDE" is used
        // as a backend name that is not expected to have a plugin).
        let outcome = NixlAgent::require_backends("test_strict_fail", &["UCX", "DUDE"]);
        assert!(outcome.is_err());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA event polling-based completion checker.
use anyhow::Result;
use cudarc::driver::{CudaEvent, DriverError, result as cuda_result, sys::CUresult};
use super::CompletionChecker;
/// Completion checker backed by a CUDA event: each poll queries the event's status.
pub struct CudaEventChecker {
    event: CudaEvent,
}

impl CudaEventChecker {
    /// Wraps `event` so it can be polled through [`CompletionChecker`].
    pub fn new(event: CudaEvent) -> Self {
        CudaEventChecker { event }
    }
}

impl CompletionChecker for CudaEventChecker {
    /// Queries the event once without blocking.
    ///
    /// Returns `Ok(true)` when the query succeeds, `Ok(false)` when the driver
    /// reports `CUDA_ERROR_NOT_READY`, and an error for any other driver result.
    fn is_complete(&self) -> Result<bool> {
        // The raw handle is taken from the `CudaEvent` owned by `self`.
        // NOTE(review): assumes the event stays valid for as long as the checker lives - confirm.
        let status = unsafe { cuda_result::event::query(self.event.cu_event()) };
        match status {
            // Event is complete
            Ok(()) => Ok(true),
            // Work recorded before the event is still pending
            Err(DriverError(CUresult::CUDA_ERROR_NOT_READY)) => Ok(false),
            Err(e) => Err(anyhow::anyhow!("CUDA event query failed: {:?}", e)),
        }
    }
}
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
use crate::block_manager::v2::physical::manager::TransportManager;
use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;
use crate::block_manager::v2::physical::transfer::tests::cuda::CudaSleep;
use std::time::{Duration, Instant};
/// Checks that a CUDA-event notification resolves only after the GPU work
/// recorded before the event has run: launching must return quickly, while
/// awaiting the notification must take roughly as long as the sleep kernel.
#[tokio::test]
async fn test_cuda_event_delayed_notification() {
// An empty backend list: this test needs an agent but no NIXL backend.
let agent = NixlAgent::require_backends("test_agent", &[]).unwrap();
let manager = TransportManager::builder()
.worker_id(0)
.cuda_device_id(0)
.nixl_agent(agent)
.build()
.unwrap();
let stream = manager.h2d_stream();
let cuda_ctx = manager.cuda_context();
// Get or create the CudaSleep utility (compiles kernel and calibrates on first use)
let cuda_sleep = CudaSleep::for_context(cuda_ctx).unwrap();
// Launch a 600ms sleep kernel and wait for it via async notification
let t0_queue_start = Instant::now();
cuda_sleep
.launch(Duration::from_millis(600), stream)
.unwrap();
// Time spent enqueueing the kernel (must not include the kernel's run time)
let queue_time = t0_queue_start.elapsed();
// The event is recorded after the kernel on the same stream
let event = stream.record_event(None).unwrap();
let notification = manager.register_cuda_event(event);
notification.await.unwrap();
// Total elapsed time minus the enqueue time = time spent waiting
let wait_time = t0_queue_start.elapsed() - queue_time;
println!(
"GPU sleep test: queue {:?}, wait {:?}",
queue_time, wait_time
);
assert!(
queue_time < Duration::from_millis(10),
"launching the sleep kernel should be fast: {:?}",
queue_time
);
// 500ms threshold leaves slack below the 600ms requested sleep
assert!(
wait_time >= Duration::from_millis(500),
"wait time should reflect >=500ms of GPU work: {:?}",
wait_time
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer completion notification system.
//!
//! This module provides abstractions for waiting on transfer completions using different
//! mechanisms: polling-based (NIXL status, CUDA events) and event-based (NIXL notifications).
use std::collections::HashMap;
use std::time::{Duration, Instant};
use anyhow::Result;
use tokio::sync::{mpsc, oneshot};
use tokio::time::interval;
use tracing::warn;
use uuid::Uuid;
pub mod cuda_event;
pub mod nixl_events;
pub mod nixl_status;
pub mod notification;
pub use cuda_event::CudaEventChecker;
pub use nixl_events::{RegisterNixlNotification, process_nixl_notification_events};
pub use nixl_status::NixlStatusChecker;
pub use notification::TransferCompleteNotification;
/// Trait for checking if a transfer operation has completed.
/// Supports polling-based completion checks (NIXL status, CUDA events).
///
/// Implementations must be `Send` so a checker can be handed to the polling task.
pub trait CompletionChecker: Send {
/// Returns true if the transfer is complete, false if still pending.
///
/// An `Err` means the status check itself failed; the polling loop in this
/// module reports it to the waiter and stops tracking the transfer.
fn is_complete(&self) -> Result<bool>;
}
/// Registration message for polling-based transfer completion.
///
/// Sent over the channel consumed by `process_polling_notifications`.
pub struct RegisterPollingNotification<C: CompletionChecker> {
/// Identifier under which the transfer is tracked and logged.
pub uuid: Uuid,
/// Checker polled to find out whether the transfer has finished.
pub checker: C,
/// Receives the final result once the checker reports completion or fails.
pub done: oneshot::Sender<Result<()>>,
}
/// Tracking struct for outstanding polling-based transfers.
struct OutstandingPollingTransfer<C: CompletionChecker> {
/// Checker polled on every tick.
checker: C,
/// Channel used to report the final result to the waiter.
done: oneshot::Sender<Result<()>>,
/// When the registration was received; basis for slow-transfer warnings.
arrived_at: Instant,
/// When the last slow-transfer warning was logged, if any.
last_warned_at: Option<Instant>,
}
/// Logs a warning for a transfer that has been pending for more than 60 seconds,
/// at most once every 30 seconds.
///
/// Returns the timestamp to store as the transfer's `last_warned_at`: the current
/// time if a warning was just issued, otherwise `last_warned_at` unchanged.
fn check_and_warn_slow_transfer(
    uuid: &Uuid,
    arrived_at: Instant,
    last_warned_at: Option<Instant>,
) -> Option<Instant> {
    let elapsed = arrived_at.elapsed();

    // Not slow yet: nothing to report.
    if elapsed <= Duration::from_secs(60) {
        return last_warned_at;
    }

    // Already warned within the last 30 seconds: stay quiet.
    let warned_recently = matches!(
        last_warned_at,
        Some(last) if last.elapsed() <= Duration::from_secs(30)
    );
    if warned_recently {
        return last_warned_at;
    }

    warn!(
        uuid = %uuid,
        elapsed_secs = elapsed.as_secs(),
        "Transfer has been pending for over 1 minute"
    );
    Some(Instant::now())
}
/// Generic polling-based transfer completion handler.
/// Works with any CompletionChecker implementation (NIXL status, CUDA events, etc.)
pub async fn process_polling_notifications<C: CompletionChecker>(
mut rx: mpsc::Receiver<RegisterPollingNotification<C>>,
) {
let mut outstanding: HashMap<Uuid, OutstandingPollingTransfer<C>> = HashMap::new();
let mut check_interval = interval(Duration::from_millis(1));
loop {
tokio::select! {
// Handle new transfer requests
notification = rx.recv() => {
match notification {
Some(notif) => {
outstanding.insert(notif.uuid, OutstandingPollingTransfer {
checker: notif.checker,
done: notif.done,
arrived_at: Instant::now(),
last_warned_at: None,
});
}
None => {
// Channel closed, finish processing outstanding transfers then exit
break;
}
}
}
// Periodically check status of outstanding transfers
_ = check_interval.tick(), if !outstanding.is_empty() => {
let mut completed = Vec::new();
for (uuid, transfer) in outstanding.iter_mut() {
// Check transfer status
match transfer.checker.is_complete() {
Ok(true) => {
// Transfer complete - mark for removal
completed.push((*uuid, Ok(())));
}
Ok(false) => {
// Transfer still in progress - check if we should warn
transfer.last_warned_at = check_and_warn_slow_transfer(
uuid,
transfer.arrived_at,
transfer.last_warned_at,
);
}
Err(e) => {
warn!(
uuid = %uuid,
error = %e,
"Transfer status check failed"
);
completed.push((*uuid, Err(e)));
}
}
}
// Remove completed transfers and signal completion
for (uuid, result) in completed {
if let Some(transfer) = outstanding.remove(&uuid) {
// Signal completion (ignore if receiver dropped)
let _ = transfer.done.send(result);
}
}
}
}
}
// Channel closed, but we may still have outstanding transfers
// Continue processing them until all are complete
while !outstanding.is_empty() {
check_interval.tick().await;
let mut completed = Vec::new();
for (uuid, transfer) in outstanding.iter() {
match transfer.checker.is_complete() {
Ok(true) => {
completed.push((*uuid, Ok(())));
}
Ok(false) => {
// Still pending
}
Err(e) => {
warn!(
uuid = %uuid,
error = %e,
"Transfer status check failed during shutdown"
);
completed.push((*uuid, Err(e)));
}
}
}
for (uuid, result) in completed {
if let Some(transfer) = outstanding.remove(&uuid) {
let _ = transfer.done.send(result);
}
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NIXL notification-based completion handler.
use std::collections::HashMap;
use std::time::{Duration, Instant};
use anyhow::Result;
use nixl_sys::{Agent as NixlAgent, NotificationMap, XferRequest};
use tokio::sync::{mpsc, oneshot};
use tokio::time::interval;
use tracing::warn;
use uuid::Uuid;
/// Registration message for NIXL notification-based transfer completion.
///
/// Sent over the channel consumed by `process_nixl_notification_events` to ask
/// the handler to start tracking one transfer.
pub struct RegisterNixlNotification {
/// Identifier of the transfer. The handler completes the transfer when it
/// receives a notification whose text parses to this UUID.
pub uuid: Uuid,
/// NIXL transfer request associated with the transfer; stored by the handler
/// for as long as the transfer is tracked.
pub xfer_req: XferRequest,
/// Channel on which the outcome of the transfer is delivered.
pub done: oneshot::Sender<Result<()>>,
}
/// Tracking struct for outstanding NIXL notification transfers.
///
/// One entry exists per registered transfer until a matching notification arrives.
struct OutstandingTransfer {
/// Transfer request handed over at registration; never read by the handler.
#[allow(dead_code)] // Kept for potential future cleanup or debugging
xfer_req: XferRequest,
/// Channel used to report completion to the waiter.
done: oneshot::Sender<Result<()>>,
/// Time at which the registration was received by the handler.
arrived_at: Instant,
/// Time of the most recent slow-transfer warning, if one was ever issued.
last_warned_at: Option<Instant>,
}
/// Logs a rate-limited warning for a transfer that has been pending too long.
///
/// The first warning is issued after the transfer has been pending for more
/// than 60 seconds; further warnings are issued only when more than 30 seconds
/// have elapsed since the last one.
///
/// Returns the value to store as `last_warned_at`: `Some(now)` if a warning was
/// logged by this call, otherwise `last_warned_at` unchanged.
fn check_and_warn_slow_transfer(
    uuid: &Uuid,
    arrived_at: Instant,
    last_warned_at: Option<Instant>,
) -> Option<Instant> {
    let pending_for = arrived_at.elapsed();
    let is_slow = pending_for > Duration::from_secs(60);

    // Only consult the previous warning time when the transfer is actually slow.
    let warning_due = is_slow
        && match last_warned_at {
            Some(previous) => previous.elapsed() > Duration::from_secs(30),
            None => true,
        };

    if !warning_due {
        return last_warned_at;
    }

    warn!(
        uuid = %uuid,
        elapsed_secs = pending_for.as_secs(),
        "Transfer has been pending for over 1 minute"
    );
    Some(Instant::now())
}
/// NIXL notification-based transfer completion handler.
/// Fetches notifications in batches and matches them against outstanding transfers.
pub async fn process_nixl_notification_events(
agent: NixlAgent,
mut rx: mpsc::Receiver<RegisterNixlNotification>,
) {
let mut outstanding: HashMap<Uuid, OutstandingTransfer> = HashMap::new();
let mut check_interval = interval(Duration::from_millis(1));
loop {
tokio::select! {
// Handle new transfer requests
notification = rx.recv() => {
match notification {
Some(notif) => {
outstanding.insert(notif.uuid, OutstandingTransfer {
xfer_req: notif.xfer_req,
done: notif.done,
arrived_at: Instant::now(),
last_warned_at: None,
});
}
None => {
// Channel closed, finish processing outstanding transfers then exit
break;
}
}
}
// Periodically fetch and process notifications
_ = check_interval.tick(), if !outstanding.is_empty() => {
// Create notification map inside this branch to avoid Send issues
let mut notif_map = match NotificationMap::new() {
Ok(map) => map,
Err(e) => {
warn!(error = %e, "Failed to create notification map");
continue;
}
};
// Fetch all pending notifications
if let Err(e) = agent.get_notifications(&mut notif_map, None) {
warn!(error = %e, "Failed to fetch NIXL notifications");
continue;
}
// Process notifications and match against outstanding transfers
let notifications = match notif_map.take_notifs() {
Ok(notifs) => notifs,
Err(e) => {
warn!(error = %e, "Failed to extract notifications from map");
continue;
}
};
let mut completed = Vec::new();
// Iterate through all notifications
for (_agent_name, notif_strings) in notifications {
for notif_str in notif_strings {
// Try to parse notification as UUID
// NOTE: This assumes notifications contain UUIDs.
// The actual format may be different and may need adjustment.
if let Ok(notif_uuid) = Uuid::parse_str(&notif_str) {
if outstanding.contains_key(&notif_uuid) {
completed.push(notif_uuid);
} else {
// Notification arrived before we started waiting for it
// This is the race condition we need to handle
warn!(
uuid = %notif_uuid,
"Received notification for transfer not in outstanding map (early arrival)"
);
}
}
}
}
// Check for slow transfers and update warnings
for (uuid, transfer) in outstanding.iter_mut() {
if !completed.contains(uuid) {
transfer.last_warned_at = check_and_warn_slow_transfer(
uuid,
transfer.arrived_at,
transfer.last_warned_at,
);
}
}
// Remove completed transfers and signal completion
for uuid in completed {
if let Some(transfer) = outstanding.remove(&uuid) {
let _ = transfer.done.send(Ok(()));
}
}
}
}
}
// Channel closed, but we may still have outstanding transfers
// Continue processing them until all are complete
while !outstanding.is_empty() {
check_interval.tick().await;
let mut notif_map = match NotificationMap::new() {
Ok(map) => map,
Err(_) => continue,
};
if let Ok(()) = agent.get_notifications(&mut notif_map, None)
&& let Ok(notifications) = notif_map.take_notifs()
{
let mut completed = Vec::new();
for (_agent_name, notif_strings) in notifications {
for notif_str in notif_strings {
if let Ok(notif_uuid) = Uuid::parse_str(&notif_str)
&& outstanding.contains_key(&notif_uuid)
{
completed.push(notif_uuid);
}
}
}
for uuid in completed {
if let Some(transfer) = outstanding.remove(&uuid) {
let _ = transfer.done.send(Ok(()));
}
}
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NIXL status polling-based completion checker.
use anyhow::{Result, anyhow};
use nixl_sys::{Agent as NixlAgent, XferRequest};
use super::CompletionChecker;
/// Completion checker that polls NIXL transfer status.
///
/// Pairs a NIXL agent with one of its transfer requests so the request's status
/// can be queried through the `CompletionChecker` trait.
pub struct NixlStatusChecker {
/// Agent used to query the transfer status.
agent: NixlAgent,
/// Transfer request whose status is queried.
xfer_req: XferRequest,
}
impl NixlStatusChecker {
pub fn new(agent: NixlAgent, xfer_req: XferRequest) -> Self {
Self { agent, xfer_req }
}
}
impl CompletionChecker for NixlStatusChecker {
    /// Queries the agent for the status of the wrapped transfer request.
    ///
    /// Returns `Ok(true)` when the reported status is a success, `Ok(false)`
    /// for any other reported status, and an error when the status query
    /// itself fails.
    fn is_complete(&self) -> Result<bool> {
        // get_xfer_status returns an XferStatus; `is_success()` is true for
        // XferStatus::Success (complete) and false while still in progress.
        self.agent
            .get_xfer_status(&self.xfer_req)
            .map(|status| status.is_success())
            .map_err(|e| anyhow!("NIXL transfer status check failed: {}", e))
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer completion notification handle.
use anyhow::Result;
use tokio::sync::oneshot;
/// Notification handle for an in-progress transfer.
///
/// The handle can be consumed in two ways: call `wait()` to block the current
/// thread, or `.await` it (it implements `Future`) in async code. Both yield
/// the transfer's outcome.
///
/// The transfer is tracked by a background handler that polls for completion
/// or processes notification events.
pub struct TransferCompleteNotification {
/// Receiving end of the channel on which the transfer outcome is delivered.
pub(crate) status: oneshot::Receiver<Result<()>>,
}
impl TransferCompleteNotification {
    /// Create a notification that is already completed (for synchronous transfers).
    ///
    /// The returned handle resolves to `Ok(())` immediately. This is useful for
    /// transfers that finish before returning, such as memcpy operations, and
    /// therefore need no background tracking.
    pub fn completed() -> Self {
        let (sender, status) = oneshot::channel();
        // The receiver is still alive here, so the outcome of `send` is irrelevant.
        let _ = sender.send(Ok(()));
        Self { status }
    }

    /// Wait for the transfer to complete (blocking).
    ///
    /// Blocks the current thread until the outcome is available. In async
    /// contexts, `.await` the notification instead.
    ///
    /// # Errors
    ///
    /// Returns the transfer's own error if it failed, or an error if the
    /// background handler was dropped before reporting an outcome.
    ///
    /// # Panics
    ///
    /// NOTE(review): this relies on tokio's `oneshot::Receiver::blocking_recv`,
    /// which is documented to panic when called from within an asynchronous
    /// execution context — confirm callers only use this from blocking code.
    pub fn wait(self) -> Result<()> {
        match self.status.blocking_recv() {
            Ok(outcome) => outcome,
            Err(_) => Err(anyhow::anyhow!(
                "Transfer handler dropped before completion"
            )),
        }
    }
}
impl std::future::Future for TransferCompleteNotification {
    type Output = Result<()>;

    /// Polls the underlying channel for the transfer outcome.
    ///
    /// Resolves to the outcome sent by the background handler, or to an error
    /// if the handler dropped its sender without reporting anything.
    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        use std::pin::Pin;
        use std::task::Poll;

        match std::future::Future::poll(Pin::new(&mut self.status), cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(outcome)) => Poll::Ready(outcome),
            Poll::Ready(Err(_)) => Poll::Ready(Err(anyhow::anyhow!(
                "Transfer handler dropped before completion"
            ))),
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer options for configuring block and layer transfers.
use super::BounceBufferSpec;
use derive_builder::Builder;
use std::{ops::Range, sync::Arc};
/// Options for configuring transfer operations.
///
/// This structure provides configuration for block and layer transfers,
/// including layer ranges, NIXL write notifications, and bounce buffers.
/// Every field is optional; a default-constructed value leaves all of them unset.
///
/// # Examples
///
/// The generated builder's `build()` returns a `Result`, so it must be handled:
///
/// ```rust,ignore
/// let options = TransferOptions::builder()
///     .nixl_write_notification(42)
///     .layer_range(0..10)
///     .build()
///     .unwrap();
/// ```
#[derive(Clone, Default, Builder)]
#[builder(pattern = "owned", default)]
pub struct TransferOptions {
/// Range of layers to transfer (None = all layers).
///
/// When specified, only the layers in this range will be transferred.
/// This is useful for partial block transfers or layer-specific operations.
#[builder(default, setter(strip_option))]
pub layer_range: Option<Range<usize>>,
/// NIXL write notification value delivered after RDMA write completes.
///
/// When specified, NIXL will deliver this notification value to the remote
/// node after the RDMA write operation completes. This enables efficient
/// notification of transfer completion without requiring polling.
#[builder(default, setter(strip_option))]
pub nixl_write_notification: Option<u64>,
/// Bounce buffer specification for multi-hop transfers.
///
/// When direct transfers are not allowed or efficient, this specifies
/// an intermediate staging area. The transfer will be split into two hops:
/// source → bounce buffer → destination.
#[builder(default, setter(strip_option, into))]
pub bounce_buffer: Option<Arc<dyn BounceBufferSpec>>,
}
impl TransferOptions {
/// Create a new builder for transfer options.
pub fn builder() -> TransferOptionsBuilder {
TransferOptionsBuilder::default()
}
/// Create transfer options from an optional layer range.
pub fn from_layer_range(layer_range: Option<Range<usize>>) -> Self {
Self {
layer_range,
..Self::default()
}
}
/// Create default transfer options.
///
/// This transfers all layers with no special configuration.
pub fn new() -> Self {
Self::default()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options leave every optional field unset.
    #[test]
    fn test_default_options() {
        let defaults = TransferOptions::default();
        assert_eq!(defaults.layer_range, None);
        assert_eq!(defaults.nixl_write_notification, None);
        assert!(defaults.bounce_buffer.is_none());
    }

    /// Setting only the notification leaves the layer range unset.
    #[test]
    fn test_builder_with_notification() {
        let built = TransferOptions::builder()
            .nixl_write_notification(42)
            .build()
            .unwrap();
        assert_eq!(built.nixl_write_notification, Some(42));
        assert_eq!(built.layer_range, None);
    }

    /// Setting only the layer range leaves the notification unset.
    #[test]
    fn test_builder_with_layer_range() {
        let built = TransferOptions::builder()
            .layer_range(0..10)
            .build()
            .unwrap();
        assert_eq!(built.layer_range, Some(0..10));
        assert_eq!(built.nixl_write_notification, None);
    }

    /// Both setters can be combined on the same builder.
    #[test]
    fn test_builder_with_all_options() {
        let built = TransferOptions::builder()
            .nixl_write_notification(100)
            .layer_range(5..15)
            .build()
            .unwrap();
        assert_eq!(built.nixl_write_notification, Some(100));
        assert_eq!(built.layer_range, Some(5..15));
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer preferences for resolving redundant strategy choices.
//!
//! Some source/destination combinations can use multiple transfer strategies.
//! For example:
//! - System ↔ Pinned: memcpy or NIXL
//! - Pinned ↔ Device: CUDA or NIXL
//!
//! This module provides preferences to control which strategy to prefer.
use serde::{Deserialize, Serialize};
/// Policy for choosing between native transports (memcpy/CUDA) and NIXL.
///
/// The default policy is [`NativeVsNixlPolicy::Automatic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum NativeVsNixlPolicy {
/// Always prefer native transports (memcpy/CUDA) when available
PreferNative,
/// Always prefer NIXL when available
PreferNixl,
/// Use native for local-to-local, NIXL for remote/disk (the default)
#[default]
Automatic,
}
/// Transfer preferences for strategy selection.
///
/// These preferences allow fine-grained control over transfer strategy selection
/// when multiple valid strategies exist for a source/destination pair.
///
/// The `Default` implementation uses the default policy and prefers async CUDA.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferPreferences {
/// Policy for native vs NIXL transport selection
pub native_vs_nixl: NativeVsNixlPolicy,
/// Whether to prefer async CUDA operations over blocking ones
pub prefer_async_cuda: bool,
}
impl Default for TransferPreferences {
fn default() -> Self {
Self {
native_vs_nixl: NativeVsNixlPolicy::default(),
prefer_async_cuda: true,
}
}
}
impl TransferPreferences {
/// Create preferences with all defaults.
pub fn new() -> Self {
Self::default()
}
/// Create preferences that always prefer native transports.
pub fn prefer_native() -> Self {
Self {
native_vs_nixl: NativeVsNixlPolicy::PreferNative,
prefer_async_cuda: true,
}
}
/// Create preferences that always prefer NIXL.
pub fn prefer_nixl() -> Self {
Self {
native_vs_nixl: NativeVsNixlPolicy::PreferNixl,
prefer_async_cuda: true,
}
}
/// Set the native vs NIXL policy.
pub fn with_native_vs_nixl(mut self, policy: NativeVsNixlPolicy) -> Self {
self.native_vs_nixl = policy;
self
}
/// Set whether to prefer async CUDA operations.
pub fn with_async_cuda(mut self, prefer_async: bool) -> Self {
self.prefer_async_cuda = prefer_async;
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults select the automatic policy and prefer async CUDA.
    #[test]
    fn test_default_preferences() {
        let defaults = TransferPreferences::default();
        assert_eq!(NativeVsNixlPolicy::Automatic, defaults.native_vs_nixl);
        assert!(defaults.prefer_async_cuda);
    }

    /// The native preset selects the native policy and keeps async CUDA on.
    #[test]
    fn test_prefer_native() {
        let native = TransferPreferences::prefer_native();
        assert_eq!(NativeVsNixlPolicy::PreferNative, native.native_vs_nixl);
        assert!(native.prefer_async_cuda);
    }

    /// The NIXL preset selects the NIXL policy and keeps async CUDA on.
    #[test]
    fn test_prefer_nixl() {
        let nixl = TransferPreferences::prefer_nixl();
        assert_eq!(NativeVsNixlPolicy::PreferNixl, nixl.native_vs_nixl);
        assert!(nixl.prefer_async_cuda);
    }

    /// The `with_*` methods can be chained and each overrides one field.
    #[test]
    fn test_builder_pattern() {
        let chained = TransferPreferences::new()
            .with_native_vs_nixl(NativeVsNixlPolicy::PreferNixl)
            .with_async_cuda(false);
        assert_eq!(NativeVsNixlPolicy::PreferNixl, chained.native_vs_nixl);
        assert!(!chained.prefer_async_cuda);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer strategy selection based on source and destination storage locations.
use crate::block_manager::v2::memory::StorageKind;
use super::TransferCapabilities;
use crate::block_manager::v2::physical::{layout::PhysicalLayout, transfer::TransferContext};
/// Transfer strategy to use for copying memory between locations.
///
/// The strategy is determined by the source and destination storage locations.
/// A single strategy describes one hop; multi-hop transfers combine two of them
/// in a `TransferPlan::TwoHop`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStrategy {
/// CPU memcpy (for host-to-host transfers)
Memcpy,
/// CUDA async host-to-device transfer
CudaAsyncH2D,
/// CUDA async device-to-host transfer
CudaAsyncD2H,
/// CUDA async device-to-device transfer
CudaAsyncD2D,
/// CUDA blocking host-to-device transfer
CudaBlockingH2D,
/// CUDA blocking device-to-host transfer
CudaBlockingD2H,
/// NIXL read operation (pull from remote)
NixlRead,
/// NIXL write operation (push to remote)
NixlWrite,
/// NIXL write (flipped local and remote order)
/// This is needed for some NIXL backends.
/// For example, the POSIX backend requires that host memory
/// always be the "local" descriptor list, regardless of whether
/// it's a read or write.
NixlWriteFlipped,
/// NIXL read (flipped local and remote order)
///
/// Counterpart of `NixlWriteFlipped` for reads. Strategy selection in this
/// module picks it for disk → host transfers and for remote → local transfers.
NixlReadFlipped,
/// Invalid/unsupported transfer
Invalid,
}
/// Plan for executing a transfer, either direct or via bounce buffer.
///
/// Some transfers require staging through host memory when direct paths
/// are not enabled via capabilities.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransferPlan {
/// Direct single-hop transfer using the specified strategy.
Direct(TransferStrategy),
/// Two-hop transfer requiring a bounce buffer in host memory.
///
/// This is used when:
/// - Device → Remote (without GPU RDMA)
/// - Disk → Remote
/// - Device ↔ Disk (without GDS)
/// - Disk → Disk
TwoHop {
/// First hop strategy (src → bounce)
first: TransferStrategy,
/// Bounce buffer location (always Pinned for best performance)
bounce_location: StorageKind,
/// Second hop strategy (bounce → dst)
second: TransferStrategy,
},
}
/// Chooses the transfer plan for copying from `src` to `dst`.
///
/// A layout counts as local when the agent name in its NIXL metadata equals the
/// name of the NIXL agent held by `ctx`.
///
/// - both local: the plan comes from the local strategy table;
/// - exactly one local: the plan comes from the remote strategy selection;
/// - both remote: unsupported.
///
/// # Errors
///
/// Returns an error when both layouts are remote, or when the remote strategy
/// selection rejects the combination.
pub(crate) fn select_strategy(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    ctx: &TransferContext,
) -> anyhow::Result<TransferPlan> {
    let is_local =
        |layout: &PhysicalLayout| layout.nixl_metadata().agent_name() == ctx.nixl_agent().name();

    match (is_local(src), is_local(dst)) {
        (false, false) => Err(anyhow::anyhow!(
            "Both src and dst are remote - this is not supported."
        )),
        (true, true) => Ok(select_direct_strategy(
            src.location(),
            dst.location(),
            false,
            ctx.capabilities(),
        )),
        (is_src_local, is_dst_local) => select_remote_strategy_v2(
            src.location(),
            is_src_local,
            dst.location(),
            is_dst_local,
            ctx.capabilities(),
        ),
    }
}
/// Select the appropriate transfer plan based on source and destination locations.
///
/// # Arguments
/// * `src` - Source storage location (always local)
/// * `dst` - Destination storage location (can be local or remote)
/// * `dst_is_remote` - Whether destination is on a remote node
/// * `capabilities` - Transfer capability flags
///
/// # Returns
/// A transfer plan (direct or two-hop)
///
/// # Conservative Default Policy
///
/// With default capabilities (all disabled):
/// - Device can only transfer to/from Host
/// - Disk can only transfer to/from Host
/// - Host can transfer to Device, Disk, or Remote
/// - Device ↔ Device is allowed (native CUDA)
///
/// Transfers that would violate this policy are staged through host:
/// - Device → Remote: Device → Host → Remote (2 hops)
/// - Disk → Remote: Disk → Host → Remote (2 hops)
/// - Device ↔ Disk: Device → Host → Disk (2 hops)
///
/// # Optional Direct Paths
///
/// - `allow_gds`: Enables Disk ↔ Device direct transfers
/// - `allow_gpu_rdma`: Enables Device → Remote direct transfers
fn select_direct_strategy(
src: StorageKind,
dst: StorageKind,
dst_is_remote: bool,
capabilities: &TransferCapabilities,
) -> TransferPlan {
use StorageKind::*;
use TransferStrategy::*;
// Handle remote destination
if dst_is_remote {
return select_remote_strategy(src, capabilities);
}
// Local-to-local transfers
match (src, dst) {
// Host ↔ Host - direct memcpy
(System, System) | (System, Pinned) | (Pinned, System) | (Pinned, Pinned) => {
TransferPlan::Direct(Memcpy)
}
// Host → Device - direct CUDA
(System, Device(_)) => TransferPlan::Direct(CudaBlockingH2D),
(Pinned, Device(_)) => TransferPlan::Direct(CudaAsyncH2D),
// Device → Host - direct CUDA
(Device(_), System) => TransferPlan::Direct(CudaBlockingD2H),
(Device(_), Pinned) => TransferPlan::Direct(CudaAsyncD2H),
// Device ↔ Device - direct CUDA
(Device(_), Device(_)) => TransferPlan::Direct(CudaAsyncD2D),
// Host ↔ Disk - direct NIXL
(System, Disk(_)) | (Pinned, Disk(_)) => TransferPlan::Direct(NixlWrite),
(Disk(_), System) | (Disk(_), Pinned) => TransferPlan::Direct(NixlReadFlipped),
// Disk ↔ Disk - NIXL doesn't seem to support direct transfers here.
// Leaving this as two-hop for now.
(Disk(_), Disk(_)) => TransferPlan::TwoHop {
first: NixlReadFlipped,
bounce_location: Pinned,
second: NixlWrite,
},
// Device ↔ Disk - check GDS capability
(Device(_), Disk(_)) => {
if capabilities.allows_device_disk_direct() {
// Direct GDS transfer
TransferPlan::Direct(NixlWrite)
} else {
// Stage through host: Device → Pinned → Disk
TransferPlan::TwoHop {
first: CudaAsyncD2H,
bounce_location: Pinned,
second: NixlWrite,
}
}
}
(Disk(_), Device(_)) => {
if capabilities.allows_device_disk_direct() {
// Direct GDS transfer
TransferPlan::Direct(NixlRead)
} else {
// Stage through host: Disk → Pinned → Device
TransferPlan::TwoHop {
first: NixlReadFlipped,
bounce_location: Pinned,
second: CudaAsyncH2D,
}
}
}
}
}
/// Select transfer strategy for remote destination.
///
/// The plan depends only on where the (local) source lives:
/// - host memory is written directly with NIXL;
/// - device memory is written directly when device → remote transfers are
///   allowed by `capabilities`, otherwise it is staged through pinned memory;
/// - disk is always staged through pinned memory.
///
/// NOTE(review): the first hop for Disk → Remote is `NixlWrite`, whereas every
/// other disk → host hop in this module uses `NixlReadFlipped` — confirm this
/// is intentional.
fn select_remote_strategy(src: StorageKind, capabilities: &TransferCapabilities) -> TransferPlan {
    use StorageKind::*;
    use TransferStrategy::*;

    // Two-hop plan: `first` into pinned host memory, then a NIXL write to the remote.
    let via_pinned = |first| TransferPlan::TwoHop {
        first,
        bounce_location: Pinned,
        second: NixlWrite,
    };

    match src {
        // Host → Remote - direct NIXL
        System | Pinned => TransferPlan::Direct(NixlWrite),
        // Device → Remote - direct GPU RDMA transfer when the capability allows it
        Device(_) if capabilities.allows_device_remote_direct() => TransferPlan::Direct(NixlWrite),
        // Device → Remote - otherwise stage through host: Device → Pinned → Remote
        Device(_) => via_pinned(CudaAsyncD2H),
        // Disk → Remote - always stage through host
        Disk(_) => via_pinned(NixlWrite),
    }
}
/// Selects the plan for a transfer in which exactly one side is remote.
///
/// # Arguments
/// * `src` / `dst` - Storage kinds of the source and destination
/// * `is_src_local` / `is_dst_local` - Whether each side belongs to the local agent
/// * `capabilities` - Transfer capability flags
///
/// # Returns
/// * local → remote: a direct `NixlWrite`
/// * remote → local: a direct `NixlReadFlipped`
///
/// # Errors
/// * either side is disk storage (not supported over NIXL for remote transfers);
/// * either side is device memory while `allow_gpu_rdma` is disabled;
/// * both sides are remote, or both sides are local. These are caller errors
///   and are reported as `Err` rather than a panic.
fn select_remote_strategy_v2(
    src: StorageKind,
    is_src_local: bool,
    dst: StorageKind,
    is_dst_local: bool,
    capabilities: &TransferCapabilities,
) -> anyhow::Result<TransferPlan> {
    // We only support System, Pinned and Device for remote transfers.
    // Later we might support staged/bounce buffer transfers.
    if matches!(src, StorageKind::Disk(_)) || matches!(dst, StorageKind::Disk(_)) {
        anyhow::bail!(
            "Neither local nor remote disk transfers are supported over NIXL at this time."
        );
    }

    let touches_device =
        matches!(src, StorageKind::Device(_)) || matches!(dst, StorageKind::Device(_));
    if touches_device && !capabilities.allow_gpu_rdma {
        anyhow::bail!("GPU RDMA is disabled - this transfer requires GPU RDMA.");
    }

    match (is_src_local, is_dst_local) {
        // Push local data to the remote side.
        (true, false) => Ok(TransferPlan::Direct(TransferStrategy::NixlWrite)),
        // Pull remote data into local memory.
        (false, true) => Ok(TransferPlan::Direct(TransferStrategy::NixlReadFlipped)),
        (false, false) => Err(anyhow::anyhow!(
            "Both src and dst are remote - this is not supported."
        )),
        (true, true) => Err(anyhow::anyhow!(
            "Both src and dst are local - a remote transfer strategy does not apply."
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plan for a transfer whose destination is local.
    fn local_plan(src: StorageKind, dst: StorageKind, caps: &TransferCapabilities) -> TransferPlan {
        select_direct_strategy(src, dst, false, caps)
    }

    /// Plan for a transfer whose destination is remote.
    fn remote_plan(
        src: StorageKind,
        dst: StorageKind,
        caps: &TransferCapabilities,
    ) -> TransferPlan {
        select_direct_strategy(src, dst, true, caps)
    }

    /// Shorthand for a single-hop plan.
    fn direct(strategy: TransferStrategy) -> TransferPlan {
        TransferPlan::Direct(strategy)
    }

    /// Unpacks a two-hop plan into `(first, bounce_location, second)`, panicking otherwise.
    fn expect_two_hop(plan: TransferPlan) -> (TransferStrategy, StorageKind, TransferStrategy) {
        let TransferPlan::TwoHop {
            first,
            bounce_location,
            second,
        } = plan
        else {
            panic!("Expected TwoHop plan");
        };
        (first, bounce_location, second)
    }

    #[test]
    fn test_host_to_host_transfers() {
        let caps = TransferCapabilities::default();
        let host_pairs = [
            (StorageKind::System, StorageKind::System),
            (StorageKind::System, StorageKind::Pinned),
            (StorageKind::Pinned, StorageKind::System),
            (StorageKind::Pinned, StorageKind::Pinned),
        ];
        for (src, dst) in host_pairs {
            assert_eq!(
                local_plan(src, dst, &caps),
                direct(TransferStrategy::Memcpy)
            );
        }
    }

    #[test]
    fn test_host_to_device_transfers() {
        let caps = TransferCapabilities::default();
        // System (unpinned) to device should be blocking
        assert_eq!(
            local_plan(StorageKind::System, StorageKind::Device(0), &caps),
            direct(TransferStrategy::CudaBlockingH2D)
        );
        // Pinned to device should be async
        assert_eq!(
            local_plan(StorageKind::Pinned, StorageKind::Device(0), &caps),
            direct(TransferStrategy::CudaAsyncH2D)
        );
    }

    #[test]
    fn test_device_to_host_transfers() {
        let caps = TransferCapabilities::default();
        // Device to system should be blocking
        assert_eq!(
            local_plan(StorageKind::Device(0), StorageKind::System, &caps),
            direct(TransferStrategy::CudaBlockingD2H)
        );
        // Device to pinned should be async
        assert_eq!(
            local_plan(StorageKind::Device(0), StorageKind::Pinned, &caps),
            direct(TransferStrategy::CudaAsyncD2H)
        );
    }

    #[test]
    fn test_device_to_device_transfers() {
        let caps = TransferCapabilities::default();
        // Different devices and the same device both use async D2D
        assert_eq!(
            local_plan(StorageKind::Device(0), StorageKind::Device(1), &caps),
            direct(TransferStrategy::CudaAsyncD2D)
        );
        assert_eq!(
            local_plan(StorageKind::Device(3), StorageKind::Device(3), &caps),
            direct(TransferStrategy::CudaAsyncD2D)
        );
    }

    #[test]
    fn test_disk_to_host_transfers() {
        let caps = TransferCapabilities::default();
        // Disk to host - direct NIXL
        for host in [StorageKind::System, StorageKind::Pinned] {
            assert_eq!(
                local_plan(StorageKind::Disk(42), host, &caps),
                direct(TransferStrategy::NixlReadFlipped)
            );
        }
    }

    #[test]
    fn test_host_to_disk_transfers() {
        let caps = TransferCapabilities::default();
        // Host to disk - direct NIXL
        for host in [StorageKind::System, StorageKind::Pinned] {
            assert_eq!(
                local_plan(host, StorageKind::Disk(42), &caps),
                direct(TransferStrategy::NixlWrite)
            );
        }
    }

    #[test]
    fn test_device_to_disk_without_gds() {
        let caps = TransferCapabilities::default(); // GDS disabled
        // Device → Disk should use bounce buffer
        let (first, bounce_location, second) = expect_two_hop(local_plan(
            StorageKind::Device(0),
            StorageKind::Disk(42),
            &caps,
        ));
        assert_eq!(first, TransferStrategy::CudaAsyncD2H);
        assert_eq!(bounce_location, StorageKind::Pinned);
        assert_eq!(second, TransferStrategy::NixlWrite);
    }

    #[test]
    fn test_disk_to_device_without_gds() {
        let caps = TransferCapabilities::default(); // GDS disabled
        // Disk → Device should use bounce buffer
        let (first, bounce_location, second) = expect_two_hop(local_plan(
            StorageKind::Disk(42),
            StorageKind::Device(0),
            &caps,
        ));
        assert_eq!(first, TransferStrategy::NixlReadFlipped);
        assert_eq!(bounce_location, StorageKind::Pinned);
        assert_eq!(second, TransferStrategy::CudaAsyncH2D);
    }

    #[test]
    fn test_device_to_disk_with_gds() {
        let caps = TransferCapabilities::default().with_gds(true);
        // Device → Disk should be direct with GDS
        assert_eq!(
            local_plan(StorageKind::Device(0), StorageKind::Disk(42), &caps),
            direct(TransferStrategy::NixlWrite)
        );
    }

    #[test]
    fn test_disk_to_device_with_gds() {
        let caps = TransferCapabilities::default().with_gds(true);
        // Disk → Device should be direct with GDS
        assert_eq!(
            local_plan(StorageKind::Disk(42), StorageKind::Device(0), &caps),
            direct(TransferStrategy::NixlRead)
        );
    }

    #[test]
    fn test_host_to_remote() {
        let caps = TransferCapabilities::default();
        // Host → Remote - always direct
        assert_eq!(
            remote_plan(StorageKind::System, StorageKind::System, &caps),
            direct(TransferStrategy::NixlWrite)
        );
        assert_eq!(
            remote_plan(StorageKind::Pinned, StorageKind::Pinned, &caps),
            direct(TransferStrategy::NixlWrite)
        );
    }

    #[test]
    fn test_device_to_remote_without_rdma() {
        let caps = TransferCapabilities::default(); // GPU RDMA disabled
        // Device → Remote should use bounce buffer
        let (first, bounce_location, second) = expect_two_hop(remote_plan(
            StorageKind::Device(0),
            StorageKind::System,
            &caps,
        ));
        assert_eq!(first, TransferStrategy::CudaAsyncD2H);
        assert_eq!(bounce_location, StorageKind::Pinned);
        assert_eq!(second, TransferStrategy::NixlWrite);
    }

    #[test]
    fn test_device_to_remote_with_rdma() {
        let caps = TransferCapabilities::default().with_gpu_rdma(true);
        // Device → Remote should be direct with GPU RDMA
        assert_eq!(
            remote_plan(StorageKind::Device(0), StorageKind::Device(0), &caps),
            direct(TransferStrategy::NixlWrite)
        );
    }

    #[test]
    fn test_disk_to_remote() {
        let caps = TransferCapabilities::default();
        // Disk → Remote always uses bounce buffer
        let (first, bounce_location, second) = expect_two_hop(remote_plan(
            StorageKind::Disk(42),
            StorageKind::System,
            &caps,
        ));
        assert_eq!(first, TransferStrategy::NixlWrite);
        assert_eq!(bounce_location, StorageKind::Pinned);
        assert_eq!(second, TransferStrategy::NixlWrite);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Round-trip testing infrastructure for transfer verification.
//!
//! This module provides utilities for testing data integrity across transfers
//! by comparing checksums after round-trip operations:
//! 1. Source blocks (host) → Intermediate (device/disk/remote)
//! 2. Intermediate → Destination blocks (host, different IDs)
//! 3. Verify checksums match between source and destination
use super::{
BlockChecksum, FillPattern, PhysicalLayout, StorageKind, compute_block_checksums,
fill_blocks, transfer_blocks,
};
use super::context::TransferContext;
use anyhow::{Result, anyhow};
use std::collections::HashMap;
/// Result of a round-trip test.
///
/// Holds the checksums gathered on both ends of the round trip together with
/// the overall verdict and the list of block pairs that did not match.
#[derive(Debug)]
pub struct RoundTripTestResult {
/// Source block checksums (keyed by source block ID)
pub source_checksums: HashMap<usize, BlockChecksum>,
/// Destination block checksums (keyed by destination block ID)
pub dest_checksums: HashMap<usize, BlockChecksum>,
/// Block ID mapping used (src_id, dst_id)
pub block_mapping: Vec<(usize, usize)>,
/// Whether all checksums matched
pub success: bool,
/// Mismatched blocks (if any), as (src_id, dst_id) pairs
pub mismatches: Vec<(usize, usize)>, // (src_id, dst_id) pairs that didn't match
}
impl RoundTripTestResult {
    /// Returns `true` when the round-trip test passed.
    pub fn is_success(&self) -> bool {
        self.success
    }

    /// Returns how many blocks were tested, i.e. the number of mapped block pairs.
    pub fn num_blocks(&self) -> usize {
        self.block_mapping.len()
    }

    /// Builds a human-readable summary of the outcome.
    ///
    /// On success the summary states how many blocks were verified; on failure
    /// it states how many blocks mismatched and lists the mismatching
    /// `(src_id, dst_id)` pairs.
    pub fn report(&self) -> String {
        let total = self.num_blocks();

        if !self.success {
            return format!(
                "Round-trip test FAILED: {}/{} blocks mismatched\nMismatches: {:?}",
                self.mismatches.len(),
                total,
                self.mismatches
            );
        }

        format!(
            "Round-trip test PASSED: {}/{} blocks verified successfully",
            total, total
        )
    }
}
/// Builder for round-trip tests.
///
/// This allows configuring a test that transfers data from source blocks
/// to intermediate storage and back to different destination blocks,
/// verifying data integrity via checksums.
///
/// A new test starts with no block mappings and the `Sequential` fill pattern;
/// both are adjusted through the builder-style methods.
pub struct RoundTripTest {
/// Source physical layout (must be local)
source: PhysicalLayout,
/// Intermediate physical layout (can be remote/device/disk)
intermediate: PhysicalLayout,
/// Destination physical layout (must be local)
destination: PhysicalLayout,
/// Block mapping: (src_id, intermediate_id, dst_id)
block_mapping: Vec<(usize, usize, usize)>,
/// Fill pattern for source blocks
fill_pattern: FillPattern,
}
impl RoundTripTest {
    /// Create a new round-trip test with no block mappings and the
    /// `FillPattern::Sequential` fill pattern.
    ///
    /// # Arguments
    /// * `source` - Source physical layout (must be local)
    /// * `intermediate` - Intermediate physical layout
    /// * `destination` - Destination physical layout (must be local)
    ///
    /// # Errors
    /// Returns an error if `source` or `destination` reports `is_remote()`.
    pub fn new(
        source: PhysicalLayout,
        intermediate: PhysicalLayout,
        destination: PhysicalLayout,
    ) -> Result<Self> {
        if source.is_remote() {
            return Err(anyhow!("Source layout must be local"));
        }
        if destination.is_remote() {
            return Err(anyhow!("Destination layout must be local"));
        }
        Ok(Self {
            source,
            intermediate,
            destination,
            block_mapping: Vec::new(),
            fill_pattern: FillPattern::Sequential,
        })
    }

    /// Set the fill pattern for source blocks.
    pub fn with_fill_pattern(mut self, pattern: FillPattern) -> Self {
        self.fill_pattern = pattern;
        self
    }

    /// Add a block mapping for the round-trip test.
    ///
    /// # Arguments
    /// * `src_id` - Source block ID
    /// * `intermediate_id` - Intermediate block ID
    /// * `dst_id` - Destination block ID
    pub fn add_block_mapping(
        mut self,
        src_id: usize,
        intermediate_id: usize,
        dst_id: usize,
    ) -> Self {
        self.block_mapping.push((src_id, intermediate_id, dst_id));
        self
    }

    /// Add multiple block mappings at once.
    ///
    /// Each entry is a `(src_id, intermediate_id, dst_id)` triple; entries are
    /// appended after any mappings that were already added.
    pub fn with_block_mappings(mut self, mappings: &[(usize, usize, usize)]) -> Self {
        self.block_mapping.extend_from_slice(mappings);
        self
    }

    /// Run the round-trip test.
    ///
    /// # Workflow
    /// 1. Fill source blocks with the specified pattern
    /// 2. Compute source checksums
    /// 3. Transfer source → intermediate
    /// 4. Transfer intermediate → destination
    /// 5. Compute destination checksums
    /// 6. Compare checksums
    ///
    /// # Arguments
    /// * `ctx` - Transfer context passed through to `transfer_blocks`
    ///
    /// # Errors
    /// Returns an error if no block mappings were configured, if filling,
    /// checksumming or either transfer fails, or if a checksum is missing for
    /// a mapped block. A checksum *mismatch* is not an error: it is reported
    /// through `RoundTripTestResult::success` and `mismatches`.
    pub async fn run(self, ctx: &TransferContext) -> Result<RoundTripTestResult> {
        if self.block_mapping.is_empty() {
            return Err(anyhow!("No block mappings specified"));
        }

        // Project the mapping triples into per-layout ID lists once; every
        // later step reuses these instead of rebuilding identical vectors.
        let src_ids: Vec<usize> = self.block_mapping.iter().map(|&(src, _, _)| src).collect();
        let inter_ids: Vec<usize> = self
            .block_mapping
            .iter()
            .map(|&(_, inter, _)| inter)
            .collect();
        let dst_ids: Vec<usize> = self.block_mapping.iter().map(|&(_, _, dst)| dst).collect();

        // Steps 1-2: fill the source blocks and record their checksums.
        fill_blocks(&self.source, &src_ids, self.fill_pattern)?;
        let source_checksums = compute_block_checksums(&self.source, &src_ids)?;

        // Step 3: source → intermediate.
        let to_intermediate =
            transfer_blocks(&self.source, &self.intermediate, &src_ids, &inter_ids, ctx)?;
        to_intermediate.await?;

        // Step 4: intermediate → destination.
        let to_destination = transfer_blocks(
            &self.intermediate,
            &self.destination,
            &inter_ids,
            &dst_ids,
            ctx,
        )?;
        to_destination.await?;

        // Step 5: checksum what arrived at the destination.
        let dest_checksums = compute_block_checksums(&self.destination, &dst_ids)?;

        // Step 6: compare pairwise. A missing checksum is surfaced as an error
        // rather than a panic, since this function already returns `Result`.
        let mut mismatches = Vec::new();
        for &(src_id, _, dst_id) in &self.block_mapping {
            let src_checksum = source_checksums
                .get(&src_id)
                .ok_or_else(|| anyhow!("Missing source checksum for block {}", src_id))?;
            let dst_checksum = dest_checksums
                .get(&dst_id)
                .ok_or_else(|| anyhow!("Missing destination checksum for block {}", dst_id))?;
            if src_checksum != dst_checksum {
                mismatches.push((src_id, dst_id));
            }
        }

        let success = mismatches.is_empty();
        let block_mapping: Vec<(usize, usize)> = self
            .block_mapping
            .iter()
            .map(|&(src, _, dst)| (src, dst))
            .collect();

        Ok(RoundTripTestResult {
            source_checksums,
            dest_checksums,
            block_mapping,
            success,
            mismatches,
        })
    }
}
// `cfg` accepts exactly one predicate, so the two conditions are combined with
// `all(...)`; the Cargo feature key is `feature` (singular).
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
    use super::*;
    use crate::block_manager::v2::layout::{
        FullyContiguousLayout, Layout, LayoutConfig, MemoryRegion, OwnedMemoryRegion,
    };
    use std::sync::Arc;

    // NOTE(review): `create_test_layout` is not defined in the visible part of
    // this module; confirm it is reachable through `super::*`.

    // Helper to create a minimal transfer context for testing
    // In real tests with CUDA/NIXL, this would be properly constructed
    fn create_test_context() -> TransferContext {
        // Not implemented yet: a real context requires CUDA/NIXL setup.
        // Tests that call this helper are marked `#[ignore]` below.
        todo!("Create test context - requires CUDA/NIXL setup")
    }

    /// Round trip where source, intermediate and destination use the same block IDs.
    #[tokio::test]
    #[ignore = "create_test_context() is still a todo!() stub"]
    async fn test_round_trip_host_to_host() {
        // Create three layouts: source, intermediate, destination
        let (src_layout, _src_mem) = create_test_layout(4);
        let (inter_layout, _inter_mem) = create_test_layout(4);
        let (dst_layout, _dst_mem) = create_test_layout(4);
        let source = PhysicalLayout::new_local(src_layout, StorageKind::System);
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination = PhysicalLayout::new_local(dst_layout, StorageKind::System);

        // Build round-trip test with identical block IDs in each layout
        // Source: blocks [0, 1, 2, 3]
        // Intermediate: blocks [0, 1, 2, 3]
        // Destination: blocks [0, 1, 2, 3] (different memory than source)
        let test = RoundTripTest::new(source, intermediate, destination)
            .unwrap()
            .with_fill_pattern(FillPattern::Sequential)
            .add_block_mapping(0, 0, 0)
            .add_block_mapping(1, 1, 1)
            .add_block_mapping(2, 2, 2)
            .add_block_mapping(3, 3, 3);

        // Create a transfer context (requires actual CUDA/NIXL setup)
        let ctx = create_test_context();

        // Run the test
        let result = test.run(&ctx).await.unwrap();
        assert!(result.is_success(), "{}", result.report());
        assert_eq!(result.num_blocks(), 4);
    }

    /// Round trip where each hop lands on different block IDs.
    #[tokio::test]
    #[ignore = "create_test_context() is still a todo!() stub"]
    async fn test_round_trip_different_block_ids() {
        // Create layouts with enough blocks
        let (src_layout, _src_mem) = create_test_layout(8);
        let (inter_layout, _inter_mem) = create_test_layout(8);
        let (dst_layout, _dst_mem) = create_test_layout(8);
        let source = PhysicalLayout::new_local(src_layout, StorageKind::System);
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination = PhysicalLayout::new_local(dst_layout, StorageKind::System);

        // Test with shifted block IDs per layout
        // Source: blocks [0, 1, 2, 3]
        // Intermediate: blocks [2, 3, 4, 5]
        // Destination: blocks [4, 5, 6, 7]
        let test = RoundTripTest::new(source, intermediate, destination)
            .unwrap()
            .with_fill_pattern(FillPattern::BlockBased)
            .with_block_mappings(&[(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]);
        let ctx = create_test_context();
        let result = test.run(&ctx).await.unwrap();
        assert!(result.is_success(), "{}", result.report());
        assert_eq!(result.num_blocks(), 4);
    }

    /// The builder accumulates mappings added one at a time.
    #[test]
    fn test_round_trip_builder() {
        let (src_layout, _) = create_test_layout(4);
        let (inter_layout, _) = create_test_layout(4);
        let (dst_layout, _) = create_test_layout(4);
        let source = PhysicalLayout::new_local(src_layout, StorageKind::System);
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination = PhysicalLayout::new_local(dst_layout, StorageKind::System);
        let test = RoundTripTest::new(source, intermediate, destination)
            .unwrap()
            .with_fill_pattern(FillPattern::Constant(42))
            .add_block_mapping(0, 0, 1)
            .add_block_mapping(1, 1, 2);
        assert_eq!(test.block_mapping.len(), 2);
    }

    /// `RoundTripTest::new` rejects a remote source layout.
    #[test]
    fn test_round_trip_requires_local_source() {
        let (src_layout, _) = create_test_layout(1);
        let (inter_layout, _) = create_test_layout(1);
        let (dst_layout, _) = create_test_layout(1);
        let source =
            PhysicalLayout::new_remote(src_layout, StorageKind::System, "remote".to_string());
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination = PhysicalLayout::new_local(dst_layout, StorageKind::System);
        let result = RoundTripTest::new(source, intermediate, destination);
        assert!(result.is_err());
    }

    /// `RoundTripTest::new` rejects a remote destination layout.
    #[test]
    fn test_round_trip_requires_local_destination() {
        let (src_layout, _) = create_test_layout(1);
        let (inter_layout, _) = create_test_layout(1);
        let (dst_layout, _) = create_test_layout(1);
        let source = PhysicalLayout::new_local(src_layout, StorageKind::System);
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination =
            PhysicalLayout::new_remote(dst_layout, StorageKind::System, "remote".to_string());
        let result = RoundTripTest::new(source, intermediate, destination);
        assert!(result.is_err());
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Local transfer tests where source and destination use the same NIXL agent.
//!
//! These tests verify data integrity across:
//! - Different storage types (System, Pinned, Device)
//! - Different layout types (Fully Contiguous, Layer-wise)
//! - Different transfer strategies (Memcpy, CUDA H2D/D2H)
use super::*;
use crate::block_manager::v2::physical::layout::BlockDimension;
use crate::block_manager::v2::physical::transfer::executor::execute_transfer;
use crate::block_manager::v2::physical::transfer::{
BlockChecksum, BounceBufferSpec, FillPattern, StorageKind, TransferCapabilities,
TransferOptions, compute_block_checksums, compute_layer_checksums, fill_blocks, fill_layers,
};
use anyhow::Result;
use rstest::rstest;
use std::collections::HashMap;
use std::ops::Range;
use std::sync::Arc;
// ============================================================================
// System <=> System Tests (Memcpy)
// ============================================================================
/// Layout selector used by the `#[values(...)]`-parameterized transfer tests.
///
/// Derives the same trait set as `LayoutKind` so the two selectors can be
/// copied, compared and printed in the same way.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LayoutType {
    /// Fully contiguous layout (built via `create_fc_layout`)
    FC,
    /// Layer-wise (layer-separate) layout (built via `create_lw_layout`)
    LW,
}
fn build_layout(
agent: NixlAgent,
layout_type: LayoutType,
storage_kind: StorageKind,
num_blocks: usize,
) -> PhysicalLayout {
match layout_type {
LayoutType::FC => create_fc_layout(agent, storage_kind, num_blocks),
LayoutType::LW => create_lw_layout(agent, storage_kind, num_blocks),
}
}
/// Which layout implementation a parameterized test should construct.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LayoutKind {
    /// Fully contiguous layout.
    FC,
    /// Layer-wise (layer-separate) layout.
    LW,
}
/// Pairs a layout kind with a storage kind to describe one test layout.
///
/// Consumed by `create_layout`.
#[derive(Clone, Copy, Debug)]
pub struct LayoutSpec {
    /// Layout implementation to build.
    pub kind: LayoutKind,
    /// Storage the layout is allocated on.
    pub storage: StorageKind,
}
impl LayoutSpec {
pub fn new(kind: LayoutKind, storage: StorageKind) -> Self {
Self { kind, storage }
}
}
/// Portion of each block that a parameterized test transfers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransferMode {
    /// Transfer entire blocks (all layers).
    FullBlocks,
    /// Transfer only the first layer (layer range `0..1`).
    FirstLayerOnly,
    /// Transfer only the second layer (layer range `1..2`).
    SecondLayerOnly,
}
impl TransferMode {
    /// Layer range to hand to `execute_transfer`, or `None` for whole blocks.
    pub fn layer_range(&self) -> Option<Range<usize>> {
        // Each single-layer mode maps to a one-element range starting at its layer index.
        let layer = match self {
            TransferMode::FullBlocks => return None,
            TransferMode::FirstLayerOnly => 0,
            TransferMode::SecondLayerOnly => 1,
        };
        Some(layer..layer + 1)
    }

    /// Short label for this mode, used when composing test names.
    pub fn suffix(&self) -> &'static str {
        match *self {
            TransferMode::SecondLayerOnly => "layer1",
            TransferMode::FirstLayerOnly => "layer0",
            TransferMode::FullBlocks => "full",
        }
    }
}
/// Build a fully contiguous physical layout on the requested storage.
///
/// Uses `standard_config(num_blocks)` for the layout configuration. Pinned
/// storage is requested with `allocate_pinned(false)` and disk storage with
/// `allocate_disk(None)`.
///
/// # Panics
/// Panics if the layout builder fails.
pub fn create_fc_layout(
    agent: NixlAgent,
    storage_kind: StorageKind,
    num_blocks: usize,
) -> PhysicalLayout {
    let layout_builder = PhysicalLayout::builder(agent)
        .with_config(standard_config(num_blocks))
        .fully_contiguous();

    match storage_kind {
        StorageKind::Device(device_id) => {
            layout_builder.allocate_device(device_id).build().unwrap()
        }
        StorageKind::Disk(_) => layout_builder.allocate_disk(None).build().unwrap(),
        StorageKind::Pinned => layout_builder.allocate_pinned(false).build().unwrap(),
        StorageKind::System => layout_builder.allocate_system().build().unwrap(),
    }
}
/// Build a layer-separate physical layout on the requested storage.
///
/// Uses `standard_config(num_blocks)` and `BlockDimension::BlockIsFirstDim`.
/// Pinned storage is requested with `allocate_pinned(false)` and disk storage
/// with `allocate_disk(None)`.
///
/// # Panics
/// Panics if the layout builder fails.
pub fn create_lw_layout(
    agent: NixlAgent,
    storage_kind: StorageKind,
    num_blocks: usize,
) -> PhysicalLayout {
    let layout_builder = PhysicalLayout::builder(agent)
        .with_config(standard_config(num_blocks))
        .layer_separate(BlockDimension::BlockIsFirstDim);

    match storage_kind {
        StorageKind::Device(device_id) => {
            layout_builder.allocate_device(device_id).build().unwrap()
        }
        StorageKind::Disk(_) => layout_builder.allocate_disk(None).build().unwrap(),
        StorageKind::Pinned => layout_builder.allocate_pinned(false).build().unwrap(),
        StorageKind::System => layout_builder.allocate_system().build().unwrap(),
    }
}
/// Create a physical layout based on the specification.
///
/// This is a DRY helper that dispatches to create_fc_layout or create_lw_layout
/// based on the layout kind in the spec.
pub fn create_layout(agent: NixlAgent, spec: LayoutSpec, num_blocks: usize) -> PhysicalLayout {
match spec.kind {
LayoutKind::FC => create_fc_layout(agent, spec.storage, num_blocks),
LayoutKind::LW => create_lw_layout(agent, spec.storage, num_blocks),
}
}
/// Fill the region selected by `mode` and return its checksums.
///
/// For `TransferMode::FullBlocks` whole blocks are filled and checksummed; for
/// the single-layer modes only that layer range (`0..1` or `1..2`) is filled
/// and checksummed.
///
/// # Errors
/// Propagates any error from the fill or checksum helpers.
pub fn fill_and_checksum_with_mode(
    layout: &PhysicalLayout,
    block_ids: &[usize],
    pattern: FillPattern,
    mode: TransferMode,
) -> Result<HashMap<usize, BlockChecksum>> {
    // `layer_range()` is `None` for full blocks and `Some(0..1)` / `Some(1..2)`
    // for the first / second layer.
    match mode.layer_range() {
        Some(layers) => {
            fill_layers(layout, block_ids, layers.clone(), pattern)?;
            compute_layer_checksums(layout, block_ids, layers)
        }
        None => {
            fill_blocks(layout, block_ids, pattern)?;
            compute_block_checksums(layout, block_ids)
        }
    }
}
/// Check that destination checksums match the source checksums, position by position.
///
/// `src_block_ids[i]` is compared against `dst_block_ids[i]`. Destination
/// checksums are computed over whole blocks or over a single layer, depending
/// on `mode`.
///
/// # Panics
/// Panics if the two ID slices differ in length, if a checksum is missing for
/// any listed block, or if any pair of checksums differs.
///
/// # Errors
/// Propagates any error from computing the destination checksums.
pub fn verify_checksums_by_position_with_mode(
    src_checksums: &HashMap<usize, BlockChecksum>,
    src_block_ids: &[usize],
    dst_layout: &PhysicalLayout,
    dst_block_ids: &[usize],
    mode: TransferMode,
) -> Result<()> {
    assert_eq!(
        src_block_ids.len(),
        dst_block_ids.len(),
        "Source and destination block arrays must have same length"
    );

    // `None` means whole blocks; `Some(range)` selects a single layer.
    let dst_checksums = match mode.layer_range() {
        None => compute_block_checksums(dst_layout, dst_block_ids)?,
        Some(layers) => compute_layer_checksums(dst_layout, dst_block_ids, layers)?,
    };

    for (src_id, dst_id) in src_block_ids.iter().zip(dst_block_ids) {
        let Some(src_checksum) = src_checksums.get(src_id) else {
            panic!("Missing source checksum for block {}", src_id)
        };
        let Some(dst_checksum) = dst_checksums.get(dst_id) else {
            panic!("Missing destination checksum for block {}", dst_id)
        };
        assert_eq!(
            src_checksum, dst_checksum,
            "Checksum mismatch (mode={:?}): src[{}] != dst[{}]: {} != {}",
            mode, src_id, dst_id, src_checksum, dst_checksum
        );
    }

    Ok(())
}
/// Create a NIXL agent named `name` with the given backends enabled.
///
/// Thin wrapper around `NixlAgent::new_with_backends`; its error is returned unchanged.
pub fn create_test_agent_with_backends(name: &str, backends: &[&str]) -> Result<NixlAgent> {
    let agent = NixlAgent::new_with_backends(name, backends)?;
    Ok(agent)
}
/// Create a `TransportManager` for tests around the given agent.
///
/// The manager is built with worker ID `0` and CUDA device `0`. When
/// `capabilities` is `None`, `TransferCapabilities::default()` is used.
///
/// Note: the agent should already have its backends configured, e.g. via
/// `create_test_agent_with_backends` or `build_agent_for_kinds`.
///
/// # Errors
/// Returns whatever error the manager builder reports.
pub fn create_transfer_context(
    agent: NixlAgent,
    capabilities: Option<TransferCapabilities>,
) -> Result<crate::block_manager::v2::physical::manager::TransportManager> {
    use crate::block_manager::v2::physical::manager::TransportManager;

    let capabilities = capabilities.unwrap_or_default();
    TransportManager::builder()
        .capabilities(capabilities)
        .worker_id(0) // Default worker ID for local tests
        .nixl_agent(agent)
        .cuda_device_id(0)
        .build()
}
/// Fill whole blocks with `pattern`, then return their checksums keyed by block ID.
///
/// The original note on this helper says it can only be called on System or
/// Pinned layouts.
/// NOTE(review): some tests in this file pass Device/Disk layouts — confirm
/// whether `fill_blocks` supports them.
///
/// # Errors
/// Propagates any error from filling or checksumming.
pub fn fill_and_checksum(
    layout: &PhysicalLayout,
    block_ids: &[usize],
    pattern: FillPattern,
) -> Result<HashMap<usize, BlockChecksum>> {
    fill_blocks(layout, block_ids, pattern)?;
    let checksums = compute_block_checksums(layout, block_ids)?;
    Ok(checksums)
}
/// Check that destination block checksums match the expected source checksums.
///
/// Blocks are compared positionally: `src_block_ids[i]` is assumed to have
/// been transferred to `dst_block_ids[i]`.
///
/// # Panics
/// Panics if the two ID slices differ in length, if a checksum is missing for
/// any listed block, or if any pair of checksums differs.
///
/// # Errors
/// Propagates any error from computing the destination checksums.
pub fn verify_checksums_by_position(
    src_checksums: &HashMap<usize, BlockChecksum>,
    src_block_ids: &[usize],
    dst_layout: &PhysicalLayout,
    dst_block_ids: &[usize],
) -> Result<()> {
    assert_eq!(
        src_block_ids.len(),
        dst_block_ids.len(),
        "Source and destination block arrays must have same length"
    );

    let dst_checksums = compute_block_checksums(dst_layout, dst_block_ids)?;

    for (src_id, dst_id) in src_block_ids.iter().zip(dst_block_ids) {
        let Some(src_checksum) = src_checksums.get(src_id) else {
            panic!("Missing source checksum for block {}", src_id)
        };
        let Some(dst_checksum) = dst_checksums.get(dst_id) else {
            panic!("Missing destination checksum for block {}", dst_id)
        };
        assert_eq!(
            src_checksum, dst_checksum,
            "Checksum mismatch: src[{}] != dst[{}]: {} != {}",
            src_id, dst_id, src_checksum, dst_checksum
        );
    }

    Ok(())
}
/// Fill guard blocks and capture their checksums for later verification.
///
/// Guard blocks sit next to transfer destinations and are expected to stay
/// untouched by the transfer. Filling them with a distinctive pattern and
/// recording the checksums lets `verify_guard_blocks_unchanged` detect
/// unintended overwrites afterwards.
///
/// # Arguments
/// * `layout` - The physical layout containing the guard blocks
/// * `guard_block_ids` - Block IDs to use as guards
/// * `pattern` - Fill pattern for guard blocks (typically a constant like 0xFF)
///
/// # Returns
/// A map of block ID to checksum for all guard blocks
///
/// # Errors
/// Propagates any error from filling or checksumming.
pub fn create_guard_blocks(
    layout: &PhysicalLayout,
    guard_block_ids: &[usize],
    pattern: FillPattern,
) -> Result<HashMap<usize, BlockChecksum>> {
    fill_blocks(layout, guard_block_ids, pattern)?;
    let guard_checksums = compute_block_checksums(layout, guard_block_ids)?;
    Ok(guard_checksums)
}
/// Confirm that guard blocks still carry the checksums recorded before the transfer.
///
/// Recomputes the checksum of every guard block and compares it with the
/// expected value; a difference means the block was written during the transfer.
///
/// # Arguments
/// * `layout` - The physical layout containing the guard blocks
/// * `guard_block_ids` - Block IDs to verify
/// * `expected_checksums` - Expected checksums from `create_guard_blocks`
///
/// # Errors
/// Returns an error for the first guard block whose checksum has changed, or
/// if recomputing the checksums fails.
///
/// # Panics
/// Panics if a guard block has no entry in either checksum map.
pub fn verify_guard_blocks_unchanged(
    layout: &PhysicalLayout,
    guard_block_ids: &[usize],
    expected_checksums: &HashMap<usize, BlockChecksum>,
) -> Result<()> {
    let current_checksums = compute_block_checksums(layout, guard_block_ids)?;

    for block_id in guard_block_ids.iter().copied() {
        let Some(expected) = expected_checksums.get(&block_id) else {
            panic!("Missing expected checksum for guard block {}", block_id)
        };
        let Some(current) = current_checksums.get(&block_id) else {
            panic!("Missing current checksum for guard block {}", block_id)
        };
        if expected == current {
            continue;
        }
        anyhow::bail!(
            "Guard block {} was modified during transfer! Expected: {}, Got: {}",
            block_id,
            expected,
            current
        );
    }

    Ok(())
}
/// Minimal `BounceBufferSpec` implementation for tests.
///
/// Holds a layout and the block IDs within it; the tests in this file wrap it
/// in `Arc<dyn BounceBufferSpec>` and pass it through `TransferOptions`.
struct DummyBounceBufferSpec {
/// Layout exposed through `BounceBufferSpec::layout`.
pub layout: PhysicalLayout,
/// Block IDs exposed through `BounceBufferSpec::block_ids`.
pub block_ids: Vec<usize>,
}
impl BounceBufferSpec for DummyBounceBufferSpec {
    /// Layout stored in this spec.
    fn layout(&self) -> &PhysicalLayout {
        let Self { layout, .. } = self;
        layout
    }

    /// Block IDs stored in this spec, as a slice.
    fn block_ids(&self) -> &[usize] {
        self.block_ids.as_slice()
    }
}
/// Create a test agent (named `"agent"`) with the backends needed for a
/// transfer between the two storage kinds.
///
/// Backend selection:
/// * System, Pinned and Disk storage require `"POSIX"`.
/// * Device storage requires `"UCX"`.
/// * A Device <-> Disk pairing additionally enables `"GDS_MT"`.
///
/// The backends are gathered in a `HashSet`, so the order in which they are
/// passed on is unspecified.
fn build_agent_for_kinds(src_kind: StorageKind, dst_kind: StorageKind) -> Result<NixlAgent> {
    use std::collections::HashSet;

    // One required backend per endpoint; the set removes duplicates.
    let mut required: HashSet<&str> = [src_kind, dst_kind]
        .iter()
        .map(|kind| match kind {
            // Lightweight backend for DRAM, also required for disk I/O
            StorageKind::System | StorageKind::Pinned | StorageKind::Disk(_) => "POSIX",
            // Required for VRAM (expensive)
            StorageKind::Device(_) => "UCX",
        })
        .collect();

    // Optional: add GDS for Device <-> Disk optimization
    let is_device_disk_pair = matches!(
        (src_kind, dst_kind),
        (StorageKind::Device(_), StorageKind::Disk(_))
            | (StorageKind::Disk(_), StorageKind::Device(_))
    );
    if is_device_disk_pair {
        required.insert("GDS_MT");
    }

    let backend_list: Vec<&str> = required.into_iter().collect();
    create_test_agent_with_backends("agent", &backend_list)
}
#[rstest]
#[tokio::test]
async fn test_p2p(
#[values(LayoutType::FC, LayoutType::LW)] src_layout: LayoutType,
#[values(
StorageKind::System,
StorageKind::Pinned,
StorageKind::Device(0),
StorageKind::Disk(0)
)]
src_kind: StorageKind,
#[values(LayoutType::FC, LayoutType::LW)] dst_layout: LayoutType,
#[values(
StorageKind::System,
StorageKind::Pinned,
StorageKind::Device(0),
StorageKind::Disk(0)
)]
dst_kind: StorageKind,
) -> Result<()> {
use crate::block_manager::v2::physical::transfer::TransferOptions;
let agent = build_agent_for_kinds(src_kind, dst_kind)?;
let src = build_layout(agent.clone(), src_layout, src_kind, 4);
let dst = build_layout(agent.clone(), dst_layout, dst_kind, 4);
let bounce_layout = build_layout(agent.clone(), LayoutType::FC, StorageKind::Pinned, 4);
let bounce_buffer_spec: Arc<dyn BounceBufferSpec> = Arc::new(DummyBounceBufferSpec {
layout: bounce_layout,
block_ids: vec![0, 1],
});
let src_blocks = vec![0, 1];
let dst_blocks = vec![2, 3];
let checksums = fill_and_checksum(&src, &src_blocks, FillPattern::Sequential)?;
let ctx = create_transfer_context(agent, None).unwrap();
let options = TransferOptions::builder()
.bounce_buffer(bounce_buffer_spec)
.build()?;
let notification =
execute_transfer(&src, &dst, &src_blocks, &dst_blocks, options, ctx.context())?;
notification.await?;
verify_checksums_by_position(&checksums, &src_blocks, &dst, &dst_blocks)?;
Ok(())
}
/// Two-hop round trip: source → intermediate → destination.
///
/// Parameterized over layout type and storage kind (System, Pinned, Device)
/// for all three layouts. Source blocks `[0, 1]` go to intermediate blocks
/// `[0, 1]` and then to destination blocks `[2, 3]`, which must end up with
/// the source checksums.
#[rstest]
#[tokio::test]
async fn test_roundtrip(
    #[values(LayoutType::FC, LayoutType::LW)] src_layout: LayoutType,
    #[values(StorageKind::System, StorageKind::Pinned, StorageKind::Device(0))]
    src_kind: StorageKind,
    #[values(LayoutType::FC, LayoutType::LW)] inter_layout: LayoutType,
    #[values(StorageKind::System, StorageKind::Pinned, StorageKind::Device(0))]
    inter_kind: StorageKind,
    #[values(LayoutType::FC, LayoutType::LW)] dst_layout: LayoutType,
    #[values(StorageKind::System, StorageKind::Pinned, StorageKind::Device(0))]
    dst_kind: StorageKind,
) -> Result<()> {
    // NOTE(review): backends are chosen from src/dst kinds only; `inter_kind`
    // is not considered — confirm that is intended when only the intermediate
    // layout lives on a device.
    let agent = build_agent_for_kinds(src_kind, dst_kind)?;

    // Create the source, intermediate and destination layouts.
    let src = build_layout(agent.clone(), src_layout, src_kind, 4);
    let intermediate = build_layout(agent.clone(), inter_layout, inter_kind, 4);
    let dst = build_layout(agent.clone(), dst_layout, dst_kind, 4);

    let src_blocks = vec![0, 1];
    let intermediate_blocks = vec![0, 1];
    let dst_blocks = vec![2, 3];

    // Fill source and compute checksums
    let checksums = fill_and_checksum(&src, &src_blocks, FillPattern::Sequential)?;
    let ctx = create_transfer_context(agent, None).unwrap();

    // Hop 1: source[0,1] -> intermediate[0,1]
    execute_transfer(
        &src,
        &intermediate,
        &src_blocks,
        &intermediate_blocks,
        TransferOptions::default(),
        ctx.context(),
    )?
    .await?;

    // Hop 2: intermediate[0,1] -> destination[2,3]
    execute_transfer(
        &intermediate,
        &dst,
        &intermediate_blocks,
        &dst_blocks,
        TransferOptions::default(),
        ctx.context(),
    )?
    .await?;

    // Verify checksums match
    verify_checksums_by_position(&checksums, &src_blocks, &dst, &dst_blocks)?;
    Ok(())
}
/// Device <-> Disk transfer with GDS enabled in the transfer capabilities.
///
/// The test returns early (passing) when `with_gds_if_supported()` leaves
/// `allow_gds` unset. Otherwise it transfers source blocks `[0, 1]` to
/// destination blocks `[2, 3]` and verifies the checksums.
#[rstest]
#[case(StorageKind::Device(0), StorageKind::Disk(0))]
#[case(StorageKind::Disk(0), StorageKind::Device(0))]
#[tokio::test]
async fn test_gds(
    #[case] src_kind: StorageKind,
    #[values(LayoutType::FC, LayoutType::LW)] src_layout: LayoutType,
    #[case] dst_kind: StorageKind,
    #[values(LayoutType::FC, LayoutType::LW)] dst_layout: LayoutType,
) -> Result<()> {
    let capabilities = TransferCapabilities::default().with_gds_if_supported();
    if !capabilities.allow_gds {
        println!("System does not support GDS. Skipping test.");
        return Ok(());
    }

    let agent = build_agent_for_kinds(src_kind, dst_kind)?;
    let src = build_layout(agent.clone(), src_layout, src_kind, 4);
    let dst = build_layout(agent.clone(), dst_layout, dst_kind, 4);

    let src_blocks = vec![0, 1];
    let dst_blocks = vec![2, 3];
    let checksums = fill_and_checksum(&src, &src_blocks, FillPattern::Sequential)?;

    let ctx = create_transfer_context(agent, Some(capabilities)).unwrap();
    execute_transfer(
        &src,
        &dst,
        &src_blocks,
        &dst_blocks,
        TransferOptions::default(),
        ctx.context(),
    )?
    .await?;

    verify_checksums_by_position(&checksums, &src_blocks, &dst, &dst_blocks)?;
    Ok(())
}
/// Device <-> Disk transfer of five blocks with a three-block bounce buffer.
///
/// Source blocks `[0..=4]` are mapped to destination blocks in reverse order
/// (`[4, 3, 2, 1, 0]`). The bounce buffer is a pinned, fully contiguous layout
/// with fewer blocks than the transfer.
#[rstest]
#[case(StorageKind::Device(0), StorageKind::Disk(0))]
#[case(StorageKind::Disk(0), StorageKind::Device(0))]
#[tokio::test]
async fn test_buffered_transfer(
    #[case] src_kind: StorageKind,
    #[values(LayoutType::FC, LayoutType::LW)] src_layout: LayoutType,
    #[case] dst_kind: StorageKind,
    #[values(LayoutType::FC, LayoutType::LW)] dst_layout: LayoutType,
) -> Result<()> {
    let agent = build_agent_for_kinds(src_kind, dst_kind)?;
    let src = build_layout(agent.clone(), src_layout, src_kind, 5);
    let dst = build_layout(agent.clone(), dst_layout, dst_kind, 5);

    let src_blocks = vec![0, 1, 2, 3, 4];
    let dst_blocks = vec![4, 3, 2, 1, 0];

    let bounce_buffer_spec: Arc<dyn BounceBufferSpec> = Arc::new(DummyBounceBufferSpec {
        layout: build_layout(agent.clone(), LayoutType::FC, StorageKind::Pinned, 3),
        block_ids: vec![0, 1, 2],
    });

    let checksums = fill_and_checksum(&src, &src_blocks, FillPattern::Sequential)?;
    let ctx = create_transfer_context(agent, None).unwrap();

    let options = TransferOptions::builder()
        .bounce_buffer(bounce_buffer_spec)
        .build()?;
    execute_transfer(&src, &dst, &src_blocks, &dst_blocks, options, ctx.context())?.await?;

    verify_checksums_by_position(&checksums, &src_blocks, &dst, &dst_blocks)?;
    Ok(())
}
/// Pinned → Device transfer of `block_count` blocks in a single request.
///
/// Only checks that the transfer can be issued and completes; no checksums
/// are compared.
#[rstest]
#[case(1024)]
#[case(2048)]
#[case(4096)]
#[case(8192)]
#[case(16384)]
#[tokio::test]
async fn test_large_block_counts(#[case] block_count: usize) {
    let agent_name = format!("test_large_block_counts_{}", block_count);
    let agent = create_test_agent(&agent_name);

    let src = create_fc_layout(agent.clone(), StorageKind::Pinned, block_count);
    let device = create_fc_layout(agent.clone(), StorageKind::Device(0), block_count);

    // Identity mapping: block i in the source goes to block i on the device.
    let src_blocks: Vec<usize> = (0..block_count).collect();
    let device_blocks: Vec<usize> = (0..block_count).collect();

    let ctx = create_transfer_context(agent, None).unwrap();
    let transfer_result = execute_transfer(
        &src,
        &device,
        &src_blocks,
        &device_blocks,
        TransferOptions::default(),
        ctx.context(),
    );
    let notification = transfer_result.unwrap();
    notification.await.unwrap();
}
// ============================================================================
// Parameterized Bounce Tests with Guard Block Validation
// ============================================================================
/// Bounce transfer with guard-block validation: FC host layout, FC bounce
/// layout, full blocks.
///
/// The shared implementation validates that:
/// 1. Data survives host[src_blocks] → bounce[src_blocks] → host[dst_blocks]
/// 2. Guard blocks adjacent to dst_blocks remain unchanged (no memory corruption)
///
/// Test pattern (6 blocks total):
/// - Source blocks: [0, 1]
/// - Destination blocks: [3, 4]
/// - Guard blocks: [2, 5] (adjacent to destination, should remain unchanged)
///
/// Each case supplies `(host storage, bounce storage, name suffix)`.
#[rstest]
// Storage combinations (host, bounce)
#[case(StorageKind::System, StorageKind::Pinned, "sys_pin")]
#[case(StorageKind::Pinned, StorageKind::System, "pin_sys")]
#[case(StorageKind::Pinned, StorageKind::Device(0), "pin_dev")]
#[tokio::test]
async fn test_bounce_with_guards_fc_fc_full(
    #[case] host_storage: StorageKind,
    #[case] bounce_storage: StorageKind,
    #[case] name_suffix: &str,
) {
    let outcome = test_bounce_with_guards_impl(
        host_storage,
        bounce_storage,
        LayoutKind::FC,
        LayoutKind::FC,
        TransferMode::FullBlocks,
        name_suffix,
    )
    .await;
    outcome.unwrap();
}
/// Bounce transfer with guard-block validation: FC host layout, LW bounce
/// layout, full blocks.
///
/// Each case supplies `(host storage, bounce storage, name suffix)`; the work
/// is delegated to `test_bounce_with_guards_impl`.
#[rstest]
#[case(StorageKind::System, StorageKind::Pinned, "sys_pin")]
#[case(StorageKind::Pinned, StorageKind::System, "pin_sys")]
#[case(StorageKind::Pinned, StorageKind::Device(0), "pin_dev")]
#[tokio::test]
async fn test_bounce_with_guards_fc_lw_full(
    #[case] host_storage: StorageKind,
    #[case] bounce_storage: StorageKind,
    #[case] name_suffix: &str,
) {
    let outcome = test_bounce_with_guards_impl(
        host_storage,
        bounce_storage,
        LayoutKind::FC,
        LayoutKind::LW,
        TransferMode::FullBlocks,
        name_suffix,
    )
    .await;
    outcome.unwrap();
}
/// Bounce transfer with guard-block validation: LW host layout, FC bounce
/// layout, full blocks.
///
/// Each case supplies `(host storage, bounce storage, name suffix)`; the work
/// is delegated to `test_bounce_with_guards_impl`.
#[rstest]
#[case(StorageKind::System, StorageKind::Pinned, "sys_pin")]
#[case(StorageKind::Pinned, StorageKind::System, "pin_sys")]
#[case(StorageKind::Pinned, StorageKind::Device(0), "pin_dev")]
#[tokio::test]
async fn test_bounce_with_guards_lw_fc_full(
    #[case] host_storage: StorageKind,
    #[case] bounce_storage: StorageKind,
    #[case] name_suffix: &str,
) {
    let outcome = test_bounce_with_guards_impl(
        host_storage,
        bounce_storage,
        LayoutKind::LW,
        LayoutKind::FC,
        TransferMode::FullBlocks,
        name_suffix,
    )
    .await;
    outcome.unwrap();
}
/// Bounce transfer with guard-block validation: LW host layout, LW bounce
/// layout, full blocks.
///
/// Each case supplies `(host storage, bounce storage, name suffix)`; the work
/// is delegated to `test_bounce_with_guards_impl`.
#[rstest]
#[case(StorageKind::System, StorageKind::Pinned, "sys_pin")]
#[case(StorageKind::Pinned, StorageKind::System, "pin_sys")]
#[case(StorageKind::Pinned, StorageKind::Device(0), "pin_dev")]
#[tokio::test]
async fn test_bounce_with_guards_lw_lw_full(
    #[case] host_storage: StorageKind,
    #[case] bounce_storage: StorageKind,
    #[case] name_suffix: &str,
) {
    let outcome = test_bounce_with_guards_impl(
        host_storage,
        bounce_storage,
        LayoutKind::LW,
        LayoutKind::LW,
        TransferMode::FullBlocks,
        name_suffix,
    )
    .await;
    outcome.unwrap();
}
/// Bounce transfer with guard-block validation: FC host layout, FC bounce
/// layout, first layer only.
///
/// Runs the Pinned host / Device bounce case through
/// `test_bounce_with_guards_impl`.
#[rstest]
#[case(StorageKind::Pinned, StorageKind::Device(0), "pin_dev")]
#[tokio::test]
async fn test_bounce_with_guards_fc_fc_layer0(
    #[case] host_storage: StorageKind,
    #[case] bounce_storage: StorageKind,
    #[case] name_suffix: &str,
) {
    let outcome = test_bounce_with_guards_impl(
        host_storage,
        bounce_storage,
        LayoutKind::FC,
        LayoutKind::FC,
        TransferMode::FirstLayerOnly,
        name_suffix,
    )
    .await;
    outcome.unwrap();
}
/// Bounce transfer with guard-block validation: LW host layout, LW bounce
/// layout, first layer only.
///
/// Runs the Pinned host / Device bounce case through
/// `test_bounce_with_guards_impl`.
#[rstest]
#[case(StorageKind::Pinned, StorageKind::Device(0), "pin_dev")]
#[tokio::test]
async fn test_bounce_with_guards_lw_lw_layer0(
    #[case] host_storage: StorageKind,
    #[case] bounce_storage: StorageKind,
    #[case] name_suffix: &str,
) {
    let outcome = test_bounce_with_guards_impl(
        host_storage,
        bounce_storage,
        LayoutKind::LW,
        LayoutKind::LW,
        TransferMode::FirstLayerOnly,
        name_suffix,
    )
    .await;
    outcome.unwrap();
}
/// Shared body of the bounce tests with guard blocks.
///
/// Builds a 6-block host layout and a 6-block bounce layout, then:
/// 1. fills host source blocks `[0, 1]` (per `mode`) and guard blocks `[2, 5]`
///    (constant `0xFF`),
/// 2. transfers host[0,1] → bounce[0,1] → host[3,4],
/// 3. verifies host[3,4] against the source checksums and checks that the
///    guard blocks are unchanged.
///
/// The agent name embeds the current time in milliseconds in addition to the
/// test parameters.
///
/// # Errors
/// Returns any error from setup, either transfer, or guard verification.
async fn test_bounce_with_guards_impl(
    host_storage: StorageKind,
    bounce_storage: StorageKind,
    host_layout: LayoutKind,
    bounce_layout: LayoutKind,
    mode: TransferMode,
    name_suffix: &str,
) -> Result<()> {
    const NUM_BLOCKS: usize = 6;

    let timestamp_ms = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_millis();
    let test_name = format!(
        "bounce_{}_{:?}_{:?}_{}_{}",
        name_suffix,
        host_layout,
        bounce_layout,
        mode.suffix(),
        timestamp_ms
    );
    let agent = create_test_agent(&test_name);

    // Create layouts
    let host_spec = LayoutSpec::new(host_layout, host_storage);
    let host = create_layout(agent.clone(), host_spec, NUM_BLOCKS);
    let bounce_spec = LayoutSpec::new(bounce_layout, bounce_storage);
    let bounce = create_layout(agent.clone(), bounce_spec, NUM_BLOCKS);

    // Block assignments:
    // - Transfer: host[0,1] → bounce[0,1] → host[3,4]
    // - Guards: host[2,5] (should remain unchanged)
    let src_blocks = vec![0, 1];
    let dst_blocks = vec![3, 4];
    let guard_blocks = vec![2, 5];

    // Setup: fill source blocks and guard blocks
    let src_checksums =
        fill_and_checksum_with_mode(&host, &src_blocks, FillPattern::Sequential, mode)?;
    let guard_checksums = create_guard_blocks(&host, &guard_blocks, FillPattern::Constant(0xFF))?;
    let ctx = create_transfer_context(agent, None)?;

    // First hop: host[0,1] → bounce[0,1] (same block IDs on both sides)
    execute_transfer(
        &host,
        &bounce,
        &src_blocks,
        &src_blocks,
        TransferOptions::from_layer_range(mode.layer_range()),
        ctx.context(),
    )?
    .await?;

    // Second hop: bounce[0,1] → host[3,4]
    execute_transfer(
        &bounce,
        &host,
        &src_blocks,
        &dst_blocks,
        TransferOptions::from_layer_range(mode.layer_range()),
        ctx.context(),
    )?
    .await?;

    // Verify: data integrity + guards unchanged
    verify_checksums_by_position_with_mode(&src_checksums, &src_blocks, &host, &dst_blocks, mode)?;
    verify_guard_blocks_unchanged(&host, &guard_blocks, &guard_checksums)?;
    Ok(())
}
// ============================================================================
// Parameterized Direct Transfer Tests
// ============================================================================
/// Direct-transfer test parameterized over the source and destination storage kinds.
///
/// Every case runs `test_direct_transfer_impl` with `LayoutKind::FC` on both sides and
/// `TransferMode::FullBlocks`, and panics if the helper returns an error.
///
/// This is the DRY, parameterized form that can replace the 18 individual tests above
/// (System<=>System, Pinned<=>Pinned, cross-type, etc).
///
/// Note: only System<=>System, Pinned<=>Pinned, and System<=>Pinned are exercised, since
/// System and Pinned are the only storage kinds we can fill and checksum. For Device
/// storage, use the bounce tests instead.
#[rstest]
// Fillable storage combinations only
#[case(StorageKind::System, StorageKind::System, "sys_sys")]
#[case(StorageKind::Pinned, StorageKind::Pinned, "pin_pin")]
#[case(StorageKind::System, StorageKind::Pinned, "sys_pin")]
#[case(StorageKind::Pinned, StorageKind::System, "pin_sys")]
#[tokio::test]
async fn test_direct_transfer_fc_fc_full(
    #[case] src_storage: StorageKind,
    #[case] dst_storage: StorageKind,
    #[case] name_suffix: &str,
) {
    let outcome = test_direct_transfer_impl(
        src_storage,
        dst_storage,
        LayoutKind::FC,
        LayoutKind::FC,
        TransferMode::FullBlocks,
        name_suffix,
    )
    .await;
    outcome.unwrap();
}
/// Direct-transfer test from a `LayoutKind::FC` source to a `LayoutKind::LW`
/// destination using `TransferMode::FirstLayerOnly`.
///
/// Covers the two mixed System/Pinned storage pairings and panics if the shared
/// helper returns an error.
#[rstest]
#[case(StorageKind::System, StorageKind::Pinned, "sys_pin")]
#[case(StorageKind::Pinned, StorageKind::System, "pin_sys")]
#[tokio::test]
async fn test_direct_transfer_fc_lw_layer0(
    #[case] src_storage: StorageKind,
    #[case] dst_storage: StorageKind,
    #[case] name_suffix: &str,
) {
    let outcome = test_direct_transfer_impl(
        src_storage,
        dst_storage,
        LayoutKind::FC,
        LayoutKind::LW,
        TransferMode::FirstLayerOnly,
        name_suffix,
    )
    .await;
    outcome.unwrap();
}
/// Direct-transfer test between two `LayoutKind::LW` layouts using
/// `TransferMode::SecondLayerOnly`.
///
/// Only the Pinned-to-Pinned pairing is exercised; the test panics if the shared
/// helper returns an error.
#[rstest]
#[case(StorageKind::Pinned, StorageKind::Pinned, "pin_pin")]
#[tokio::test]
async fn test_direct_transfer_lw_lw_layer1(
    #[case] src_storage: StorageKind,
    #[case] dst_storage: StorageKind,
    #[case] name_suffix: &str,
) {
    let outcome = test_direct_transfer_impl(
        src_storage,
        dst_storage,
        LayoutKind::LW,
        LayoutKind::LW,
        TransferMode::SecondLayerOnly,
        name_suffix,
    )
    .await;
    outcome.unwrap();
}
/// Shared body for the direct transfer tests.
///
/// Builds a four-block source and destination layout on one agent, fills
/// `src[0,1]`, transfers those blocks into `dst[2,3]`, and checks that the
/// destination blocks match the checksums taken from the source.
///
/// # Errors
/// Propagates any error returned while filling, transferring, or verifying.
///
/// # Panics
/// Panics if the system clock reports a time earlier than the Unix epoch.
async fn test_direct_transfer_impl(
    src_storage: StorageKind,
    dst_storage: StorageKind,
    src_layout: LayoutKind,
    dst_layout: LayoutKind,
    mode: TransferMode,
    name_suffix: &str,
) -> Result<()> {
    let num_blocks = 4;

    // The agent name embeds a millisecond timestamp in addition to the test parameters.
    let mode_tag = mode.suffix();
    let millis_since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_millis();
    let test_name = format!(
        "direct_{}_{:?}_{:?}_{}_{}",
        name_suffix, src_layout, dst_layout, mode_tag, millis_since_epoch
    );
    let agent = create_test_agent(&test_name);

    // Both layouts share the same agent and block count.
    let src_spec = LayoutSpec::new(src_layout, src_storage);
    let src = create_layout(agent.clone(), src_spec, num_blocks);
    let dst_spec = LayoutSpec::new(dst_layout, dst_storage);
    let dst = create_layout(agent.clone(), dst_spec, num_blocks);

    // Data path under test: src[0,1] -> dst[2,3]
    let src_ids = vec![0, 1];
    let dst_ids = vec![2, 3];

    // Populate the source blocks and remember their checksums.
    let expected = fill_and_checksum_with_mode(&src, &src_ids, FillPattern::Sequential, mode)?;

    let transfer_ctx = create_transfer_context(agent, None)?;
    let options = TransferOptions::from_layer_range(mode.layer_range());

    // Queue the transfer and wait for its completion notification.
    execute_transfer(
        &src,
        &dst,
        &src_ids,
        &dst_ids,
        options,
        transfer_ctx.context(),
    )?
    .await?;

    // The destination must now hold the source data.
    verify_checksums_by_position_with_mode(&expected, &src_ids, &dst, &dst_ids, mode)?;
    Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Comprehensive transfer tests for verifying data integrity across storage types and layout configurations.
#[cfg(all(feature = "testing-cuda", feature = "testing-nixl"))]
mod local_transfers;
use super::{NixlAgent, PhysicalLayout};
use crate::block_manager::v2::physical::layout::{
LayoutConfig,
builder::{HasConfig, NoLayout, NoMemory, PhysicalLayoutBuilder},
};
/// Build the layout configuration shared by all tests.
///
/// Only the block count varies; the remaining dimensions are fixed at 2 layers,
/// an outer dimension of 2, a page size of 16, an inner dimension of 128, and a
/// dtype width of 2 bytes.
///
/// # Panics
/// Panics if the builder rejects the configuration.
pub fn standard_config(num_blocks: usize) -> LayoutConfig {
    let built = LayoutConfig::builder()
        .num_blocks(num_blocks)
        .num_layers(2)
        .outer_dim(2)
        .page_size(16)
        .inner_dim(128)
        .dtype_width_bytes(2)
        .build();
    built.unwrap()
}
/// Start a `PhysicalLayout` builder that already carries the standard test config.
///
/// The builder is bound to a fresh test agent named `"test_agent"` and configured
/// with `standard_config(num_blocks)`.
///
/// Kept for backwards compatibility with the other test modules (fill, checksum,
/// validation) that call it.
pub fn builder(num_blocks: usize) -> PhysicalLayoutBuilder<HasConfig, NoLayout, NoMemory> {
    PhysicalLayout::builder(create_test_agent("test_agent"))
        .with_config(standard_config(num_blocks))
}
/// Create a NIXL agent for use in tests.
///
/// Calls `NixlAgent::require_backends` with `name` and an empty backend list.
///
/// NOTE(review): this comment previously said the agent attempts to initialize the
/// UCX, GDS, and POSIX backends and falls back gracefully when some are unavailable.
/// The code visible here names no backends and panics on failure, so that
/// description could not be confirmed from this function.
/// TODO: confirm what `require_backends` does when given an empty list.
///
/// # Panics
/// Panics with "Failed to require backends" if `require_backends` returns an error.
pub fn create_test_agent(name: &str) -> NixlAgent {
NixlAgent::require_backends(name, &[]).expect("Failed to require backends")
}
#[cfg(feature = "testing-cuda")]
pub(crate) mod cuda {
use anyhow::Result;
use cudarc::driver::sys::CUdevice_attribute_enum;
use cudarc::driver::{CudaContext, CudaStream, LaunchConfig, PushKernelArg};
use cudarc::nvrtc::{CompileOptions, compile_ptx_with_opts};
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};
/// CUDA C source for the sleep kernel.
///
/// Defines `sleep_kernel(unsigned long long min_cycles)`, which records `clock64()`
/// on entry and spins until at least `min_cycles` clock ticks have elapsed.
/// The empty `asm volatile("")` in the loop body is presumably there to keep the
/// compiler from optimizing the busy-wait away (TODO confirm).
pub const SLEEP_KERNEL_SRC: &str = r#"
extern "C" __global__ void sleep_kernel(unsigned long long min_cycles) {
const unsigned long long start = clock64();
while ((clock64() - start) < min_cycles) {
asm volatile("");
}
}
"#;
/// A reusable CUDA sleep utility for tests.
///
/// This struct provides a simple interface to queue GPU sleep operations
/// with calibrated timing. Construction compiles the sleep kernel and runs a
/// calibration loop; `for_context` caches the resulting instance per device
/// ordinal so that work is done once and reused.
///
/// The calibration is conservative (prefers longer sleep durations over shorter)
/// to ensure minimum sleep times are met.
pub struct CudaSleep {
// Handle to the compiled `sleep_kernel` function loaded into the CUDA context.
function: cudarc::driver::CudaFunction,
// Calibrated number of kernel clock cycles per millisecond of wall-clock time
// (includes the safety margin applied during calibration, or the device clock
// rate in kHz when calibration produced no usable measurement).
cycles_per_ms: f64,
}
impl CudaSleep {
    /// Get or create the shared `CudaSleep` instance for the device behind `cuda_ctx`.
    ///
    /// Instances are cached in a process-wide map keyed by the context's device
    /// ordinal. The first call for a device compiles the kernel and runs the
    /// calibration; later calls return the cached instance.
    ///
    /// # Arguments
    /// * `cuda_ctx` - The CUDA context to use
    ///
    /// # Returns
    /// A shared handle to the cached `CudaSleep` instance for this context's device.
    ///
    /// # Errors
    /// Returns an error if the kernel cannot be compiled or loaded, or if the
    /// calibration launches fail.
    pub fn for_context(cuda_ctx: &Arc<CudaContext>) -> Result<Arc<Self>> {
        static INSTANCES: OnceLock<parking_lot::Mutex<HashMap<usize, Arc<CudaSleep>>>> =
            OnceLock::new();
        let instances = INSTANCES.get_or_init(|| parking_lot::Mutex::new(HashMap::new()));
        let device_ordinal = cuda_ctx.ordinal();

        // Fast path: return the cached instance if one already exists. The lock is
        // scoped so it is released before the (slow) calibration below.
        {
            let instances_guard = instances.lock();
            if let Some(instance) = instances_guard.get(&device_ordinal) {
                return Ok(Arc::clone(instance));
            }
        }

        // Slow path: build and calibrate a new instance without holding the lock.
        let instance = Arc::new(Self::new(cuda_ctx)?);

        // Another thread may have inserted an instance while we were calibrating.
        // Always hand back whatever ended up in the cache so that every caller for
        // a given device shares the same instance.
        let mut instances_guard = instances.lock();
        let cached = instances_guard.entry(device_ordinal).or_insert(instance);
        Ok(Arc::clone(cached))
    }

    /// Create a new `CudaSleep` instance with calibration.
    ///
    /// This compiles the sleep kernel for the device's compute capability, then
    /// runs a calibration loop to determine the relationship between kernel clock
    /// cycles and wall-clock time.
    ///
    /// # Errors
    /// Returns an error if a device attribute query, PTX compilation, module or
    /// function loading, stream creation, kernel launch, or stream synchronization
    /// fails.
    fn new(cuda_ctx: &Arc<CudaContext>) -> Result<Self> {
        // Get device compute capability
        let major = cuda_ctx
            .attribute(CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
        let minor = cuda_ctx
            .attribute(CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)?;

        // Compile PTX for this device
        let mut compile_opts = CompileOptions {
            name: Some("sleep_kernel.cu".into()),
            ..Default::default()
        };
        compile_opts
            .options
            .push(format!("--gpu-architecture=compute_{}{}", major, minor));
        let ptx = compile_ptx_with_opts(SLEEP_KERNEL_SRC, compile_opts)?;
        let module = cuda_ctx.load_module(ptx)?;
        let function = module.load_function("sleep_kernel")?;

        // Get device clock rate (kHz, i.e. cycles per millisecond at the reported rate)
        let clock_rate_khz =
            cuda_ctx.attribute(CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_CLOCK_RATE)? as u64;

        // Create a temporary stream for calibration
        let stream = cuda_ctx.new_stream()?;

        // Warm up to absorb JIT overhead (~10ms worth of cycles at the reported rate)
        let warm_cycles = clock_rate_khz.saturating_mul(10).max(1);
        Self::launch_kernel(&function, &stream, warm_cycles)?;
        stream.synchronize()?;

        // Run calibration loop: double the cycle count until one launch takes at
        // least `desired_delay`, or until the attempt budget is exhausted.
        let desired_delay = Duration::from_millis(600);
        let mut target_cycles = clock_rate_khz.saturating_mul(50).max(1); // ~50ms starting point
        // Cycle count of the launch that produced `actual_duration`. Tracked
        // separately because `target_cycles` is doubled after every short run,
        // including the last one if the loop ends without reaching the target.
        let mut measured_cycles = target_cycles;
        let mut actual_duration = Duration::ZERO;
        for _ in 0..8 {
            let start = Instant::now();
            Self::launch_kernel(&function, &stream, target_cycles)?;
            stream.synchronize()?;
            actual_duration = start.elapsed();
            measured_cycles = target_cycles;
            if actual_duration >= desired_delay {
                break;
            }
            target_cycles = target_cycles.saturating_mul(2);
        }

        // Calculate cycles per millisecond with conservative 20% margin
        // (prefer longer sleeps over shorter)
        let elapsed_ms = actual_duration.as_millis();
        let cycles_per_ms = if elapsed_ms > 0 {
            (measured_cycles as f64 / elapsed_ms as f64) * 1.2
        } else {
            clock_rate_khz as f64 // Fallback to clock rate
        };

        Ok(Self {
            function,
            cycles_per_ms,
        })
    }

    /// Launch the sleep kernel with the specified number of cycles.
    ///
    /// The kernel is queued on `stream` with a single thread (1x1x1 grid and
    /// block); this function does not wait for it to finish.
    ///
    /// # Errors
    /// Returns an error if the launch call fails.
    fn launch_kernel(
        function: &cudarc::driver::CudaFunction,
        stream: &Arc<CudaStream>,
        cycles: u64,
    ) -> Result<()> {
        let launch_cfg = LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (1, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut launch = stream.launch_builder(function);
        // SAFETY: every caller in this impl passes the `sleep_kernel` function
        // compiled from `SLEEP_KERNEL_SRC`, whose only parameter is an
        // `unsigned long long`; `cycles` is a `u64` that outlives the launch call,
        // and the kernel touches no memory.
        unsafe {
            launch.arg(&cycles);
            launch.launch(launch_cfg)?;
        }
        Ok(())
    }

    /// Launch a sleep operation on the given stream.
    ///
    /// This queues a GPU kernel that will sleep for approximately the specified
    /// duration. The sleep is conservative and may take longer than requested.
    /// Sub-millisecond durations are honored: the requested time is converted to
    /// fractional milliseconds and the resulting cycle count is rounded up.
    ///
    /// # Arguments
    /// * `duration` - The minimum duration to sleep
    /// * `stream` - The CUDA stream to launch the kernel on
    ///
    /// # Returns
    /// Ok(()) if the kernel was successfully queued
    ///
    /// # Errors
    /// Returns an error if the kernel launch fails.
    pub fn launch(&self, duration: Duration, stream: &Arc<CudaStream>) -> Result<()> {
        // Use fractional milliseconds so that requests shorter than 1ms do not
        // collapse to zero cycles.
        let requested_ms = duration.as_secs_f64() * 1_000.0;
        // Float-to-int `as` saturates at the u64 bounds; rounding up keeps the
        // sleep at or above the requested minimum.
        let target_cycles = (requested_ms * self.cycles_per_ms).ceil() as u64;
        Self::launch_kernel(&self.function, stream, target_cycles)
    }
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment