/*************************************************************************
 * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
 * Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#ifndef NCCL_PRIMITIVES_H_
#define NCCL_PRIMITIVES_H_

#include <type_traits>
#include "reduce_kernel.h" // for reduction funcs
#include "common_kernel.h"
#include "common.h"

// Number of spin iterations between polls of the abort flag.
#define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000

// Barrier across the `nthreads` threads of the current group.
// When the whole block participates, a single hardware s_barrier suffices.
// Otherwise emulate a barrier through a global counter: lane 0 of each
// participating warp increments `barriers` and spins (s_sleep to reduce
// power/contention) until all participating warps have arrived.
// NOTE(review): relies on `nthreads`, `barriers` and `barrier_next` being
// in scope at the expansion site — confirm against prims_*.h users.
#define barrier_by_group() do { \
  if (nthreads == NCCL_MAX_NTHREADS) { \
    __asm__ __volatile__("s_waitcnt vmcnt(0) lgkmcnt(0)\ns_barrier\ns_waitcnt lgkmcnt(0)"); \
  } else { \
    const int w = threadIdx.x/WARP_SIZE; \
    const int wid = threadIdx.x%WARP_SIZE; \
    __threadfence(); \
    if (wid == 0) { \
      barrier_next[w] += nthreads/WARP_SIZE; \
      atomicAdd((unsigned long long *)barriers, 1); \
      while (atomicAdd((unsigned long long *)barriers, 0) < barrier_next[w]) __builtin_amdgcn_s_sleep(1); \
      __asm__ __volatile__("s_wakeup"); \
    } \
  } \
} while (0)

/* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128
 * We use these as template args to the Primitives class instead of integral
 * enums (e.g. NCCL_PROTO_LL) because for SIMPLE we need to carry a few extra
 * numbers. Also these types hold methods which let us compute numbers important
 * to how that protocol operates with a consistent interface so that our
 * algorithm code can operate protocol parametrically.
 */
template<int SlicePerChunk_1, int StepPerSlice_1, int Unroll_1 = COLL_UNROLL,
         int MultimemSrcs_1 = 0, int MultimemDsts_1 = 0>
struct ProtoSimple {
  static constexpr int Id = NCCL_PROTO_SIMPLE;
  static constexpr int SlicePerChunk = SlicePerChunk_1;
  static constexpr int StepPerSlice = StepPerSlice_1;
  static constexpr int Unroll = Unroll_1;
  static constexpr int MultimemSrcs = MultimemSrcs_1;
  static constexpr int MultimemDsts = MultimemDsts_1;

  // Data bytes (no flags etc) in one step of the fifo queue.
  __device__ static int calcBytePerStep() {
    return ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS;
  }
  // Granularity of data bytes transferred per thread.
  __device__ static int calcBytePerGrain() {
    return sizeof(uint64_t); // Bogus value? Nobody queries this metric for simple.
  }
  // Group width is how many consecutive group values a subchannel occupies.
  static constexpr int MaxGroupWidth = 1;
};

struct ProtoLL {
  static constexpr int Id = NCCL_PROTO_LL;

  // Data bytes (no flags etc) in one step of the fifo queue.
  __device__ static int calcBytePerStep() {
    return ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/2; // Half is data
  }
  // Granularity of data bytes transferred per thread.
  __device__ static int calcBytePerGrain() {
    return sizeof(uint64_t); // One 16-byte line has 8-bytes of data
  }
  // Group width is how many consecutive group values a subchannel occupies.
  static constexpr int MaxGroupWidth = 1;
};

struct ProtoLL128 {
  static constexpr int Id = NCCL_PROTO_LL128;

  // Data bytes (no flags etc) in one step of the fifo queue.
  __device__ static int calcBytePerStep() {
    return (ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS)*NCCL_LL128_DATAELEMS/NCCL_LL128_LINEELEMS;
  }
  // Granularity of data bytes transferred per thread.
  __device__ static int calcBytePerGrain() {
    return NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_DATAELEMS*sizeof(uint64_t)/NCCL_LL128_LINEELEMS;
  }
  // Group width is how many consecutive group values a subchannel occupies.
  static constexpr int MaxGroupWidth = 1;
};

/* Fan (as in fan-in & fan-out) classes hold recv and send counts. The template
 * arguments are static bounds on the maximum values. Asymmetric counts are
 * independent. Symmetric is a static guarantee that nrecv==nsend, so it only
 * stores one value at runtime. This optimization saves a 32-bit register, but more
 * importantly uses fewer predicate registers when unrolling loops.
*/ template struct FanAsymmetric { static constexpr int MaxRecv = MaxRecv_, MaxSend = MaxSend_; int nr, ns; FanAsymmetric() = default; __device__ FanAsymmetric(int nrecv, int nsend): nr(nrecv), ns(nsend) { // assert(nrecv <= MaxRecv && nsend <= MaxSend); } __device__ int nrecv() const { return MaxRecv ? nr : 0; } __device__ int nsend() const { return MaxSend ? ns : 0; } }; template struct FanSymmetric { static constexpr int MaxRecv = MaxArity, MaxSend = MaxArity; int n; FanSymmetric() = default; __device__ FanSymmetric(int nrecv, int nsend): n(nrecv) { // assert(nrecv == nsend && nrecv <= MaxArity); } __device__ int nrecv() const { return n; } __device__ int nsend() const { return n; } }; // The primitives class. Specialized per protocol in the other headers. template class Primitives; // Used by LL & LL128 to implement direct members in the naive way. template struct PrimitivesWithoutDirect { __device__ void directSend(intptr_t inpIx, intptr_t outIx, int eltN) { static_cast(this)->send(inpIx, eltN); } __device__ void directSendFromOutput(intptr_t outIx, int eltN) { static_cast(this)->sendFromOutput(outIx, eltN); } __device__ void directRecv(intptr_t outIx, int eltN) { static_cast(this)->recv(outIx, eltN, /*postOp=*/false); } __device__ void directCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { static_cast(this)->copySend(inpIx, outIx, eltN, postOp); } __device__ void directRecvCopySend(intptr_t outIx, int eltN) { static_cast(this)->recvCopySend(outIx, eltN, /*postOp=*/false); } __device__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { // Direct is only for the send part static_cast(this)->recvReduceCopySend(inpIx, outIx, eltN, postOp); } }; #include "prims_simple.h" #include "prims_ll.h" #include "prims_ll128.h" #endif