/*************************************************************************
 * Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
 * Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#ifndef NCCL_COMMON_KERNEL_H_
#define NCCL_COMMON_KERNEL_H_

#include "devcomm.h"
#include "op128.h"
#include "reduce_kernel.h"

#include <cstdio>
#include <cstdint>
#include <type_traits>

// __syncwarp is a CUDA-ism with no HIP equivalent; RCCL defines it away.
#define __syncwarp()

// Sentinel MinDsts value; remapped to a single destination on the gfx90a path below.
#define SDMA_SPEC_DST 0x55

// Define min for ssize_t
inline __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; }

// Atomic read of a 4-byte int (atomicAdd of 0 acts as a load).
inline __device__ int loadInt(int* ptr) {
  int v;
  v = atomicAdd((unsigned int *)ptr, 0u);
  return v;
}

template<typename RedFn, typename T, int Unroll, int BytePerPack,
         int MultimemSrcs, int MinSrcs, int MaxSrcs,
         int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
         typename IntBytes>
__device__ __forceinline__ void reduceCopyPacks(
    int nThreads, int &thread,
    uint64_t redArg, uint64_t *preOpArgs, bool postOp,
    int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs,
    IntBytes &nBytesBehind, IntBytes &nBytesAhead
  ) {
  static_assert(std::is_signed<IntBytes>::value, "IntBytes must be a signed integral type.");
  //if (BytePerPack == 0) __trap();

  // A hunk is the amount of contiguous data a warp consumes per loop iteration
  // assuming all threads partake.
  constexpr int BytePerHunk = Unroll*WARP_SIZE*BytePerPack;
  int nWarps = nThreads/WARP_SIZE;
  int warp = thread/WARP_SIZE;
  int lane = thread%WARP_SIZE;

  // This thread's initial position.
  IntBytes threadBytesBehind = nBytesBehind + (warp*BytePerHunk + lane*BytePerPack);
  IntBytes threadBytesAhead = nBytesAhead - (warp*BytePerHunk + lane*BytePerPack);
  // Number of hunks to be consumed over all warps.
  IntBytes nHunksAhead = nBytesAhead/(BytePerHunk + !BytePerHunk);
  // Advance collective position.
  nBytesBehind += nHunksAhead*BytePerHunk;
  nBytesAhead -= nHunksAhead*BytePerHunk;
  if (Unroll==1 && BytePerPack <= nBytesAhead) {
    // Only Unroll=1 can do partial hunks (where not all threads partake).
    nHunksAhead += 1;
    nBytesBehind += nBytesAhead - (nBytesAhead%(BytePerPack + !BytePerPack));
    nBytesAhead = nBytesAhead%(BytePerPack + !BytePerPack);
  }
  nHunksAhead -= warp;

  RedFn redFn(redArg);
  uintptr_t minSrcs[MinSrcs + !MinSrcs];
  uintptr_t minDsts[MinDsts + !MinDsts];
  #pragma unroll
  for (int s=0; s < MinSrcs; s++)
    minSrcs[s] = cvta_to_global(srcPtrs[s]) + threadBytesBehind;
  #pragma unroll
  for (int d=0; d < MinDsts; d++)
    minDsts[d] = cvta_to_global(dstPtrs[d]) + threadBytesBehind;

  // We dictate loop termination condition according to whether partial hunks
  // can be handled or not.
  while (Unroll==1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead)) {
    BytePack<BytePerPack> acc[Unroll];

    { RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0);
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        if (0 < MultimemSrcs) {
          // applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
          acc[u] = applyLoadMultimem<RedFn, BytePerPack>(preFn, minSrcs[0]);
        } else {
          // Use volatile loads in case credits are polled for with volatile (instead of acquire).
          acc[u] = ld_volatile_global<BytePerPack>(minSrcs[0]);
        }
        minSrcs[0] += WARP_SIZE*BytePerPack;
        if (0 < PreOpSrcs) acc[u] = applyPreOp(preFn, acc[u]);
      }
    }

    #pragma unroll Unroll
    for (int s=1; s < MinSrcs; s++) {
      BytePack<BytePerPack> tmp[Unroll];
      RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0);
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        if (s < MultimemSrcs) {
          // applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
          acc[u] = applyLoadMultimem<RedFn, BytePerPack>(preFn, minSrcs[s]);
        } else {
          // Use volatile loads in case credits are polled for with volatile (instead of acquire).
          tmp[u] = ld_volatile_global<BytePerPack>(minSrcs[s]);
        }
        minSrcs[s] += WARP_SIZE*BytePerPack;
      }
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]);
        acc[u] = applyReduce(redFn, acc[u], tmp[u]);
      }
    }

    // Remaining sources beyond MinSrcs.
    for (int s=MinSrcs; (MinSrcs < MaxSrcs) && (s < MaxSrcs) && (s < nSrcs); s++) {
      uintptr_t src = cvta_to_global(srcPtrs[s]) + threadBytesBehind;
      BytePack<BytePerPack> tmp[Unroll];
      RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0);
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        // Use volatile loads in case credits are polled for with volatile (instead of acquire).
        tmp[u] = ld_volatile_global<BytePerPack>(src);
        src += WARP_SIZE*BytePerPack;
      }
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]);
        acc[u] = applyReduce(redFn, acc[u], tmp[u]);
      }
    }

    if (postOp) {
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++)
        acc[u] = applyPostOp(redFn, acc[u]);
    }

    #pragma unroll Unroll
    for (int d=0; d < MinDsts; d++) {
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        if (d < MultimemDsts) {
          multimem_st_global(minDsts[d], acc[u]);
        } else {
          st_global(minDsts[d], acc[u]);
        }
        minDsts[d] += WARP_SIZE*BytePerPack;
      }
    }

    // Remaining destinations beyond MinDsts.
    for (int d=MinDsts; (MinDsts < MaxDsts) && (d < MaxDsts) && (d < nDsts); d++) {
      uintptr_t dst = cvta_to_global(dstPtrs[d]) + threadBytesBehind;
      #pragma unroll Unroll
      for (int u=0; u < Unroll; u++) {
        st_global(dst, acc[u]);
        dst += WARP_SIZE*BytePerPack;
      }
    }

    // Advance to this warp's next hunk.
    nWarps = nThreads/WARP_SIZE;
    #pragma unroll
    for (int s=0; s < MinSrcs; s++) minSrcs[s] += (nWarps-1)*BytePerHunk;
    #pragma unroll
    for (int d=0; d < MinDsts; d++) minDsts[d] += (nWarps-1)*BytePerHunk;
    threadBytesBehind += nWarps*BytePerHunk;
    threadBytesAhead -= nWarps*BytePerHunk;
    nHunksAhead -= nWarps;
  }

  nWarps = nThreads/WARP_SIZE;
  warp = thread/WARP_SIZE;
  lane = thread%WARP_SIZE;
  // The last loop iteration could have been partial, i.e. not taken by all
  // threads. The threads that weren't included need an extra subtraction to
  // make the value warp uniform.
  if (Unroll==1 && nHunksAhead > 0) nHunksAhead -= nWarps;
  // Rotate warps so the warp which got the least work here will be warp 0.
  // This effectively assigns: warp = (warp-nHunks+nWarps)%nWarps;
  warp = -nHunksAhead;
  thread = warp*WARP_SIZE + lane;
}

template<typename RedFn, typename T, int Unroll,
         int MultimemSrcs, int MinSrcs, int MaxSrcs,
         int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
         typename IntBytes>
__device__ __forceinline__ void reduceCopy(
    int thread, int nThreads,
    uint64_t redArg, uint64_t *preOpArgs, bool postOp,
    int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs,
    IntBytes nElts
  ) {
  static_assert(MultimemSrcs <= MinSrcs && MultimemDsts <= MinDsts,
                "Multimem pointers cannot exceed respective Min values.");

  //int nWarps = nThreads/WARP_SIZE;
  //int warp = thread/WARP_SIZE;
  int lane = thread%WARP_SIZE;

  // If a multimem src is present then our biggest pack size is limited to what
  // is supported for this redfn/type.
  constexpr int BigPackSize = (MultimemSrcs == 0) ? 16 : LoadMultimem_BigPackSize<RedFn>::BigPackSize;

  IntBytes nBytesBehind = 0;
  IntBytes nBytesAhead = nElts*sizeof(T);

#if __cpp_if_constexpr
  if constexpr (BigPackSize > sizeof(T)) {
#else
  if (BigPackSize > sizeof(T)) {
#endif
    // Check that all pointers are BigPackSize aligned.
    bool aligned = true;
    if (lane < nSrcs) aligned &= 0 == cvta_to_global(srcPtrs[lane]) % (BigPackSize + !BigPackSize);
    if (lane < nDsts) aligned &= 0 == cvta_to_global(dstPtrs[lane]) % (BigPackSize + !BigPackSize);
    aligned = !(__any(!aligned));
    if (aligned) {
#if defined(__gfx90a__)
      // gfx90a: use a smaller unroll when reducing multiple sources, and remap
      // the SDMA_SPEC_DST sentinel to a single destination.
      reduceCopyPacks<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
                      MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts,
                      (MinDsts == SDMA_SPEC_DST ? 1 : MinDsts), MaxDsts, PreOpSrcs>
        (nThreads, thread, redArg, preOpArgs, postOp,
         nSrcs, srcPtrs, nDsts, dstPtrs, nBytesBehind, nBytesAhead);
#else
      reduceCopyPacks<RedFn, T, Unroll, BigPackSize,
                      MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
        (nThreads, /*&*/thread, redArg, preOpArgs, postOp,
         nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
#endif
      if (nBytesAhead == 0) return;

      // Unroll=1 pass picks up the partial hunk left over by the unrolled pass.
      reduceCopyPacks<RedFn, T, /*Unroll=*/1, BigPackSize,
                      MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
        (nThreads, /*&*/thread, redArg, preOpArgs, postOp,
         nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesAhead == 0 ? nBytesBehind : nBytesBehind, /*&*/nBytesAhead);
      if (nBytesAhead == 0) return;
    }
  }

  // Fallback to element-sized packs (unaligned pointers, small BigPackSize, or leftover bytes).
#if defined(__gfx90a__)
  if (MinSrcs > 1) {
    reduceCopyPacks<RedFn, T, /*Unroll=*/2, /*BytePerPack=*/sizeof(T),
                    MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
      (nThreads, thread, redArg, preOpArgs, postOp,
       nSrcs, srcPtrs, nDsts, dstPtrs, nBytesBehind, nBytesAhead);
  } else {
    reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
                    MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
      (nThreads, /*&*/thread, redArg, preOpArgs, postOp,
       nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
  }
#else
  reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
                  MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
    (nThreads, /*&*/thread, redArg, preOpArgs, postOp,
     nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
#endif
  if (nBytesAhead == 0) return;

  reduceCopyPacks<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
                  MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
    (nThreads, /*&*/thread, redArg, preOpArgs, postOp,
     nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
}

#endif // NCCL_COMMON_KERNEL_H_
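/*
 * Usage sketch (illustrative only, not part of the header): how a collective
 * kernel might instantiate reduceCopy above to sum two device buffers into a
 * single destination. FuncSum<float> is assumed to come from reduce_kernel.h,
 * and the Unroll value is an arbitrary example; the real collectives pick
 * their template arguments in the callers of this header.
 */
#if 0
__device__ void exampleSumTwoIntoOne(
    int thread, int nThreads, float* src0, float* src1, float* dst, ssize_t nElts) {
  void* srcs[2] = { src0, src1 };  // exactly MinSrcs==MaxSrcs==2 sources
  void* dsts[1] = { dst };         // exactly MinDsts==MaxDsts==1 destination
  reduceCopy<FuncSum<float>, float, /*Unroll=*/4,
             /*MultimemSrcs=*/0, /*MinSrcs=*/2, /*MaxSrcs=*/2,
             /*MultimemDsts=*/0, /*MinDsts=*/1, /*MaxDsts=*/1,
             /*PreOpSrcs=*/0>
    (thread, nThreads, /*redArg=*/0, /*preOpArgs=*/nullptr, /*postOp=*/false,
     /*nSrcs=*/2, srcs, /*nDsts=*/1, dsts, nElts);
}
#endif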