// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. pub mod context; mod cuda; mod memcpy; mod nixl; mod strategy; use super::*; use crate::block_manager::storage::{ DeviceStorage, DiskStorage, PinnedStorage, SystemStorage, nixl::{NixlRegisterableStorage, NixlStorage}, }; use cudarc::driver::CudaStream; use nixl_sys::NixlDescriptor; use nixl_sys::XferOp::{Read, Write}; use std::ops::Range; use tokio::sync::oneshot; pub use crate::block_manager::storage::{CudaAccessible, Local, Remote}; pub use async_trait::async_trait; pub use context::TransferContext; /// A block that can be the target of a write pub trait Writable {} /// A block that can be the source of a read pub trait Readable {} pub trait Mutable: Readable + Writable {} pub trait Immutable: Readable {} #[derive(Debug)] pub enum BlockTarget { Source, Destination, } #[derive(Debug, thiserror::Error)] pub enum TransferError { #[error("Builder configuration error: {0}")] BuilderError(String), #[error("Transfer execution failed: {0}")] ExecutionError(String), #[error("Incompatible block types provided: {0}")] IncompatibleTypes(String), #[error("Mismatched source/destination counts: {0} sources, {1} destinations")] CountMismatch(usize, usize), #[error("Block operation failed: {0}")] BlockError(#[from] BlockError), // TODO: Add NIXL specific errors #[error("No blocks provided")] NoBlocksProvided, #[error("Mismatched {0:?} block set index: {1} != {2}")] MismatchedBlockSetIndex(BlockTarget, usize, usize), #[error("Mismatched {0:?} worker ID: {1} != {2}")] MismatchedWorkerID(BlockTarget, usize, usize), #[error(transparent)] Other(#[from] anyhow::Error), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum NixlTransfer { Read, Write, } impl NixlTransfer { pub fn as_xfer_op(&self) -> nixl_sys::XferOp { match self { NixlTransfer::Read => Read, NixlTransfer::Write => Write, } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TransferStrategy { Memcpy, CudaAsyncH2D, CudaAsyncD2H, CudaAsyncD2D, CudaBlockingH2D, CudaBlockingD2H, Nixl(NixlTransfer), Invalid, } /// Trait for determining the transfer strategy for writing from a local /// source to a target destination which could be local or remote pub trait WriteToStrategy { fn write_to_strategy() -> TransferStrategy { TransferStrategy::Invalid } } /// Trait for determining the transfer strategy for reading from a /// `Source` which could be local or remote into `Self` which must /// be both local and writable. pub trait ReadFromStrategy { fn read_from_strategy() -> TransferStrategy { TransferStrategy::Invalid } } impl WriteToStrategy for RB where ::StorageType: Local + WriteToStrategy<::StorageType>, { #[inline(always)] fn write_to_strategy() -> TransferStrategy { <::StorageType as WriteToStrategy< ::StorageType, >>::write_to_strategy() } } impl ReadFromStrategy for WB where ::StorageType: Remote, ::StorageType: NixlRegisterableStorage, { #[inline(always)] fn read_from_strategy() -> TransferStrategy { TransferStrategy::Nixl(NixlTransfer::Read) } } pub fn handle_local_transfer( sources: &[RB], targets: &mut [WB], ctx: Arc, ) -> Result, TransferError> where RB: ReadableBlock + WriteToStrategy + Local, WB: WritableBlock, ::StorageType: NixlDescriptor, ::StorageType: NixlDescriptor, { let (tx, rx) = oneshot::channel(); match RB::write_to_strategy() { TransferStrategy::Memcpy => { for (src, dst) in sources.iter().zip(targets.iter_mut()) { // TODO: Unlike all other transfer strategies, this is fully blocking. // We probably want some sort of thread pool to handle these. memcpy::copy_block(src, dst)?; } tx.send(()).unwrap(); Ok(rx) } TransferStrategy::CudaAsyncH2D | TransferStrategy::CudaAsyncD2H | TransferStrategy::CudaAsyncD2D => { for (src, dst) in sources.iter().zip(targets.iter_mut()) { cuda::copy_block(src, dst, ctx.stream().as_ref(), RB::write_to_strategy())?; } ctx.cuda_event(tx)?; Ok(rx) } TransferStrategy::Nixl(transfer_type) => { let transfer_fut = nixl::write_blocks_to(sources, targets, &ctx, transfer_type)?; ctx.async_rt_handle().spawn(async move { transfer_fut.await; tx.send(()).unwrap(); }); Ok(rx) } _ => Err(TransferError::IncompatibleTypes(format!( "Unsupported copy strategy: {:?}", RB::write_to_strategy() ))), } } pub trait WriteTo { fn write_to( &self, dst: &mut Vec, ctx: Arc, ) -> Result, TransferError>; } impl WriteTo for Vec where RB: ReadableBlock + WriteToStrategy + Local, ::StorageType: NixlDescriptor, ::StorageType: NixlDescriptor, RB: BlockDataProvider, WB: WritableBlock + BlockDataProviderMut, { fn write_to( &self, dst: &mut Vec, ctx: Arc, ) -> Result, TransferError> { L::handle_transfer(self, dst, ctx) } } #[cfg(test)] mod tests { use super::*; #[test] fn write_to_strategy() { // System to ... assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaBlockingH2D ); assert_eq!( >::write_to_strategy(), TransferStrategy::Nixl(NixlTransfer::Write) ); // Pinned to ... assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaAsyncH2D ); assert_eq!( >::write_to_strategy(), TransferStrategy::Nixl(NixlTransfer::Write) ); // Device to ... assert_eq!( >::write_to_strategy(), TransferStrategy::CudaBlockingD2H ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaAsyncD2H ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaAsyncD2D ); assert_eq!( >::write_to_strategy(), TransferStrategy::Nixl(NixlTransfer::Write) ); // Nixl to ... should fail to compile // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); } }