// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! This module implements the `WriteToStrategy` and `ReadFromStrategy` traits //! for the common storage types. use super::*; impl WriteToStrategy for SystemStorage { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::Memcpy } } impl WriteToStrategy for SystemStorage { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::Memcpy } } impl WriteToStrategy for SystemStorage { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::CudaBlockingH2D } } impl WriteToStrategy for PinnedStorage { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::Memcpy } } impl WriteToStrategy for PinnedStorage { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::Memcpy } } impl WriteToStrategy for PinnedStorage { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::CudaAsyncH2D } } impl WriteToStrategy for DeviceStorage { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::CudaBlockingD2H } } impl WriteToStrategy for DeviceStorage { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::CudaAsyncD2H } } impl WriteToStrategy for DeviceStorage { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::CudaAsyncD2D } } impl WriteToStrategy for S { #[inline(always)] fn write_to_strategy() -> TransferStrategy { TransferStrategy::NixlWrite } } impl ReadFromStrategy for SystemStorage where S: WriteToStrategy + Storage + Local, { #[inline(always)] fn read_from_strategy() -> TransferStrategy { S::write_to_strategy() } } impl ReadFromStrategy for PinnedStorage where S: WriteToStrategy + Storage + Local, { #[inline(always)] fn read_from_strategy() -> TransferStrategy { S::write_to_strategy() } } impl ReadFromStrategy for DeviceStorage where S: WriteToStrategy + Storage + Local, { #[inline(always)] fn read_from_strategy() -> TransferStrategy { S::write_to_strategy() } } impl ReadFromStrategy for S { #[inline(always)] fn read_from_strategy() -> TransferStrategy { TransferStrategy::NixlRead } } #[cfg(test)] mod tests { use super::*; #[test] fn write_to_strategy() { // System to ... assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaBlockingH2D ); assert_eq!( >::write_to_strategy(), TransferStrategy::NixlWrite ); // Pinned to ... assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaAsyncH2D ); assert_eq!( >::write_to_strategy(), TransferStrategy::NixlWrite ); // Device to ... assert_eq!( >::write_to_strategy(), TransferStrategy::CudaBlockingD2H ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaAsyncD2H ); assert_eq!( >::write_to_strategy(), TransferStrategy::CudaAsyncD2D ); assert_eq!( >::write_to_strategy(), TransferStrategy::NixlWrite ); // Nixl to ... should fail to compile // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); // assert_eq!( // >::write_to_strategy(), // TransferStrategy::Invalid // ); } #[test] fn read_from_strategy() { // System to ... assert_eq!( >::read_from_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::read_from_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::read_from_strategy(), TransferStrategy::CudaBlockingD2H ); assert_eq!( >::read_from_strategy(), TransferStrategy::NixlRead ); // Pinned to ... assert_eq!( >::read_from_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::read_from_strategy(), TransferStrategy::Memcpy ); assert_eq!( >::read_from_strategy(), TransferStrategy::CudaAsyncD2H ); assert_eq!( >::read_from_strategy(), TransferStrategy::NixlRead ); // Device to ... assert_eq!( >::read_from_strategy(), TransferStrategy::CudaBlockingH2D ); assert_eq!( >::read_from_strategy(), TransferStrategy::CudaAsyncH2D ); assert_eq!( >::read_from_strategy(), TransferStrategy::CudaAsyncD2D ); assert_eq!( >::read_from_strategy(), TransferStrategy::NixlRead ); // Nixl to ... should fail to compile // assert_eq!( // >::read_from_strategy(), // TransferStrategy::Invalid // ); // // assert_eq!( // >::read_from_strategy(), // TransferStrategy::Invalid // ); // // assert_eq!( // >::read_from_strategy(), // TransferStrategy::Invalid // ); // // assert_eq!( // >::read_from_strategy(), // TransferStrategy::Invalid // ); } }