/************************************************************************* * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ #include "devcomm.h" #include "collectives.h" #include "common.h" __shared__ ncclShmemData ncclShmem; #if __CUDA_ARCH__ < 700 __shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)]; #endif #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) #else #define NCCL_FUNC5(func, algo, devredop, type, nullify) \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) #define NCCL_FUNC4(func, devredop, type, nullify) \ NCCL_FUNC5(func, TREE, devredop, type, nullify), \ NCCL_FUNC5(func, RING, devredop, type, nullify), \ NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \ NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \ NCCL_FUNC5(func, NVLS, devredop, type, nullify), \ NCCL_FUNC5(func, NVLS_TREE, devredop, type, nullify) #if defined(__CUDA_BF16_TYPES_EXIST__) // Must be consistent with ncclDataType_t #define NCCL_FUNCS3A(func, devredop, nullForFloat) \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, uint8_t, 0), \ NCCL_FUNC4(func, devredop, int32_t, 0), \ NCCL_FUNC4(func, devredop, uint32_t, 0), \ NCCL_FUNC4(func, devredop, int64_t, 0), \ NCCL_FUNC4(func, devredop, uint64_t, 0), \ NCCL_FUNC4(func, devredop, half, nullForFloat), \ NCCL_FUNC4(func, devredop, float, nullForFloat), \ NCCL_FUNC4(func, devredop, double, nullForFloat), \ NCCL_FUNC4(func, devredop, __nv_bfloat16, nullForFloat) #define NCCL_FUNCS3B(func, devredop) \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0) #else // Must be consistent with ncclDataType_t #define NCCL_FUNCS3A(func, devredop, nullForFloat) \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, uint8_t, 0), \ NCCL_FUNC4(func, devredop, int32_t, 0), \ NCCL_FUNC4(func, devredop, uint32_t, 0), \ NCCL_FUNC4(func, devredop, int64_t, 0), \ NCCL_FUNC4(func, devredop, uint64_t, 0), \ NCCL_FUNC4(func, devredop, half, nullForFloat), \ NCCL_FUNC4(func, devredop, float, nullForFloat), \ NCCL_FUNC4(func, devredop, double, nullForFloat) #define NCCL_FUNCS3B(func, devredop) \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0), \ NCCL_FUNC4(func, devredop, int8_t, 0) #endif // Must be consistent with ncclRedOp_t #define NCCL_FUNCS2A(func) \ NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \ NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \ NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \ NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \ NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \ NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1) #define NCCL_FUNCS2B(func) \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum) // Must be consistent with the ncclFuncSet enum __device__ ncclKern_t ncclFuncs[1+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = { // Don't try to initialize the host shadow copy of this device-side global // variable. There is no host pointer to a device-side function, which // confuses clang. This will be fixed in the next clang release. #if __CUDA_ARCH__ NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t), NCCL_ONERANK_REDUCE_NAME(PreMulSum, half), NCCL_ONERANK_REDUCE_NAME(PreMulSum, float), NCCL_ONERANK_REDUCE_NAME(PreMulSum, double), #if defined(__CUDA_BF16_TYPES_EXIST__) NCCL_ONERANK_REDUCE_NAME(PreMulSum, __nv_bfloat16), #endif NCCL_FUNCS2B(Broadcast), NCCL_FUNCS2A(Reduce), NCCL_FUNCS2B(AllGather), NCCL_FUNCS2A(ReduceScatter), NCCL_FUNCS2A(AllReduce) #endif }; #endif // Workaround for https://reviews.llvm.org/D55580 __device__ void ncclWorkaroundClangD55580() {}