// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 pub mod context; mod cuda; mod memcpy; mod nixl; mod strategy; use super::*; use crate::block_manager::storage::{ DeviceStorage, DiskStorage, PinnedStorage, SystemStorage, nixl::{NixlRegisterableStorage, NixlStorage}, }; use nixl_sys::NixlDescriptor; use nixl_sys::XferOp::{Read, Write}; use std::ops::Range; use tokio::sync::oneshot; pub use crate::block_manager::storage::{CudaAccessible, Local, Remote}; pub use async_trait::async_trait; pub use context::{PoolConfig, TransferContext}; /// A block that can be the target of a write pub trait Writable {} /// A block that can be the source of a read pub trait Readable {} pub trait Mutable: Readable + Writable {} pub trait Immutable: Readable {} #[derive(Debug)] pub enum BlockTarget { Source, Destination, } #[derive(Debug, thiserror::Error)] pub enum TransferError { #[error("Builder configuration error: {0}")] BuilderError(String), #[error("Transfer execution failed: {0}")] ExecutionError(String), #[error("Incompatible block types provided: {0}")] IncompatibleTypes(String), #[error("Mismatched source/destination counts: {0} sources, {1} destinations")] CountMismatch(usize, usize), #[error("Block operation failed: {0}")] BlockError(#[from] BlockError), // TODO: Add NIXL specific errors #[error("No blocks provided")] NoBlocksProvided, #[error("Mismatched {0:?} block set index: {1} != {2}")] MismatchedBlockSetIndex(BlockTarget, usize, usize), #[error("Mismatched {0:?} worker ID: {1} != {2}")] MismatchedWorkerID(BlockTarget, usize, usize), #[error(transparent)] Other(#[from] anyhow::Error), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum NixlTransfer { Read, Write, } impl NixlTransfer { pub fn as_xfer_op(&self) -> nixl_sys::XferOp { match self { NixlTransfer::Read => Read, NixlTransfer::Write => Write, } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CudaTransferMode { /// Use the custom CUDA kernel for G1 <-> G2 transfers Custom, /// Use the default CUDA async memcpy for G1 <-> G2 transfers Default, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TransferStrategy { Memcpy, CudaAsyncH2D, CudaAsyncD2H, CudaAsyncD2D, CudaBlockingH2D, CudaBlockingD2H, Nixl(NixlTransfer), Invalid, } /// Trait for determining the transfer strategy for writing from a local /// source to a target destination which could be local or remote pub trait WriteToStrategy { fn write_to_strategy() -> TransferStrategy { TransferStrategy::Invalid } } /// Trait for determining the transfer strategy for reading from a /// `Source` which could be local or remote into `Self` which must /// be both local and writable. pub trait ReadFromStrategy { fn read_from_strategy() -> TransferStrategy { TransferStrategy::Invalid } } impl WriteToStrategy for RB where ::StorageType: Local + WriteToStrategy<::StorageType>, { #[inline(always)] fn write_to_strategy() -> TransferStrategy { <::StorageType as WriteToStrategy< ::StorageType, >>::write_to_strategy() } } impl ReadFromStrategy for WB where ::StorageType: Remote, ::StorageType: NixlRegisterableStorage, { #[inline(always)] fn read_from_strategy() -> TransferStrategy { TransferStrategy::Nixl(NixlTransfer::Read) } } #[inline] fn resolve_cuda_transfer_mode( base_strategy: TransferStrategy, is_contiguous: bool, ) -> CudaTransferMode { match base_strategy { TransferStrategy::CudaAsyncH2D => { if is_contiguous { CudaTransferMode::Default } else { CudaTransferMode::Custom } } TransferStrategy::CudaAsyncD2H => { if is_contiguous { CudaTransferMode::Default } else { CudaTransferMode::Custom } } other => panic!( "resolve_cuda_strategy called with non-CUDA strategy: {:?}", other ), } } pub fn handle_local_transfer( sources: &[RB], targets: &mut [WB], ctx: Arc, ) -> Result, TransferError> where RB: ReadableBlock + WriteToStrategy + Local, WB: WritableBlock, ::StorageType: NixlDescriptor, ::StorageType: NixlDescriptor, { // Check for empty slices and length mismatch early if sources.is_empty() && targets.is_empty() { tracing::warn!( "handle_local_transfer called with both sources and targets empty, skipping transfer" ); let (tx, rx) = oneshot::channel(); tx.send(()).unwrap(); return Ok(rx); } if sources.len() != targets.len() { return Err(TransferError::CountMismatch(sources.len(), targets.len())); } let (tx, rx) = oneshot::channel(); match RB::write_to_strategy() { TransferStrategy::Memcpy => { for (src, dst) in sources.iter().zip(targets.iter_mut()) { // TODO: Unlike all other transfer strategies, this is fully blocking. // We probably want some sort of thread pool to handle these. memcpy::copy_block(src, dst)?; } tx.send(()).unwrap(); Ok(rx) } TransferStrategy::CudaAsyncH2D | TransferStrategy::CudaAsyncD2H | TransferStrategy::CudaAsyncD2D => { tracing::debug!( "Transfer: Using CUDA strategy: {:?}", RB::write_to_strategy() ); if RB::write_to_strategy() == TransferStrategy::CudaAsyncH2D || RB::write_to_strategy() == TransferStrategy::CudaAsyncD2H { let is_contiguous = sources[0].block_data().is_fully_contiguous() && targets[0].block_data().is_fully_contiguous(); let transfer_mode = resolve_cuda_transfer_mode(RB::write_to_strategy(), is_contiguous); match transfer_mode { CudaTransferMode::Custom => { let selected_stream = ctx.stream(); cuda::copy_blocks_with_customized_kernel( sources, targets, selected_stream.as_ref(), &ctx, )?; } CudaTransferMode::Default => { for (src, dst) in sources.iter().zip(targets.iter_mut()) { cuda::copy_block( src, dst, ctx.stream().as_ref(), RB::write_to_strategy(), )?; } } }; ctx.cuda_event(tx)?; Ok(rx) } else { // Fall back to individual copy for D2Dblocks for (src, dst) in sources.iter().zip(targets.iter_mut()) { cuda::copy_block(src, dst, ctx.stream().as_ref(), RB::write_to_strategy())?; } ctx.cuda_event(tx)?; Ok(rx) } } TransferStrategy::Nixl(transfer_type) => { let transfer_fut = nixl::write_blocks_to(sources, targets, &ctx, transfer_type)?; ctx.async_rt_handle().spawn(async move { transfer_fut.await; tx.send(()).unwrap(); }); Ok(rx) } _ => Err(TransferError::IncompatibleTypes(format!( "Unsupported copy strategy: {:?}", RB::write_to_strategy() ))), } } pub trait WriteTo { fn write_to( &self, dst: &mut Vec, ctx: Arc, ) -> Result, TransferError>; } impl WriteTo for Vec where RB: ReadableBlock + WriteToStrategy + Local, ::StorageType: NixlDescriptor, ::StorageType: NixlDescriptor, RB: BlockDataProvider, WB: WritableBlock + BlockDataProviderMut, { fn write_to( &self, dst: &mut Vec, ctx: Arc, ) -> Result, TransferError> { L::handle_transfer(self, dst, ctx) } } #[cfg(test)] mod tests { use super::*; #[test] fn write_to_strategy() { // System to ... assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaBlockingH2D ); assert_eq!( >::write_to_strategy(), TransferStrategy::Nixl(NixlTransfer::Write) ); // Pinned to ... assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaAsyncH2D ); assert_eq!( >::write_to_strategy(), TransferStrategy::Nixl(NixlTransfer::Write) ); // Device to ... assert_eq!( >::write_to_strategy(), TransferStrategy::CudaBlockingD2H ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaAsyncD2H ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaAsyncD2D ); assert_eq!( >::write_to_strategy(), TransferStrategy::Nixl(NixlTransfer::Write) ); // Nixl to ... should fail to compile // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); } }