/*************************************************************************
 * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
 * Modifications Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved.
 * Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#if defined(ENABLE_NPKIT)
#include "npkit/npkit.h"
#endif

#include "msccl/msccl_struct.h"

#ifdef HYGON_SDMA_FEATURE
#include "op128.h"
//#define GC_COPY_DATA

typedef enum {
  SDMA_TRANS_MODE_NON = 0,
  SDMA_TRANS_MODE_REDUCE_SEND = 1,
  SDMA_TRANS_MODE_SRC_SEND = 2,
  SDMA_TRANS_MODE_RECV_SEND = 3
} sdma_trans_mode_t;

template<int BytePerPack>
__device__ __forceinline__ void memsetData(int ringIx, uint64_t srcAddr, uint32_t dataSize, int val) {
  BytePack<BytePerPack> setVal;

  if (srcAddr == 0 || dataSize == 0) {
    PRINT_ERR("memsetData error input, ringIx:%d bid:%d copy data srcAddr:0x%lx len %d\n",
              ringIx, (int)blockIdx.x, srcAddr, dataSize);
    return;
  }

  setVal.native = val;
  for (int i = 0; i < dataSize; i++) {
    st_global<BytePerPack>(srcAddr, setVal);
    srcAddr += BytePerPack;
  }
  return;
}

template<int BytePerPack>
__device__ __forceinline__ int compareData(int ringIx, uint64_t srcAddr, uint64_t dstAddr, uint32_t dataSize) {
  BytePack<BytePerPack> srcVal;
  BytePack<BytePerPack> dstVal;
  int miscompare = 0;

  if (srcAddr == 0 || dstAddr == 0 || dataSize == 0) {
    PRINT_ERR("compareData error input, ringIx:%d bid:%d copy data srcAddr:0x%lx dstAddr:0x%lx len %d\n",
              ringIx, (int)blockIdx.x, srcAddr, dstAddr, dataSize);
    return 0;
  }

  for (int i = 0; i < dataSize; i++) {
    srcVal = ld_volatile_global<BytePerPack>(srcAddr);
    dstVal = ld_volatile_global<BytePerPack>(dstAddr);
    if (srcVal.native != dstVal.native) {
      PRINT_INFO("compareData, ringIx:%d bid:%d miscompare index:%d srcVal[0x%lx]:%d dstVal[0x%lx]:%d\n",
                 ringIx, (int)blockIdx.x, i, srcAddr, srcVal.native, dstAddr, dstVal.native);
      miscompare++;
    }
    srcAddr += BytePerPack;
    dstAddr += BytePerPack;
  }

  if (miscompare) {
    PRINT_INFO("compareData end error, ringIx:%d bid:%d miscompare count:%d dataSize:%d\n",
               ringIx, (int)blockIdx.x, miscompare, dataSize);
  } else {
    PRINT_INFO("compareData end ok, ringIx:%d bid:%d same data, last:%d\n",
               ringIx, (int)blockIdx.x, srcVal.native);
  }
  return miscompare;
}

// BytePerPack is only used by the GC_COPY_DATA debug path; the default lets
// regular call sites omit it.
template<int BytePerPack = 16>
__device__ __forceinline__ int startSdmaTask(struct sdmaQueueInfo *sdmaQueue, int ringIx,
                                             uint64_t srcAddr, uint64_t dstAddr, uint32_t dataLen) {
  if (srcAddr == 0 || dstAddr == 0) {
    PRINT_ERR("startSdma error input, ringIx:%d bid:%d srcAddr:0x%lx dstAddr:0x%lx len:%d\n",
              ringIx, (int)blockIdx.x, srcAddr, dstAddr, dataLen);
    return -1;
  }

#ifdef GC_COPY_DATA
  BytePack<BytePerPack> val;
  for (int i = 0; i < dataLen / BytePerPack; i++) {
    val = ld_volatile_global<BytePerPack>(srcAddr);
    st_global<BytePerPack>(dstAddr, val);
    srcAddr += BytePerPack;
    dstAddr += BytePerPack;
  }
  return 0;
#endif

  uint32_t sdmaIndex = atomicAdd(sdmaQueue->pkgIndex, 1) % sdmaQueue->sdmaDepth;
  volatile hsa_sdma_info_t *sdmaInfo = &sdmaQueue->sdmaInfo[sdmaIndex];

  if (*sdmaInfo->wptr == *sdmaInfo->rptr) {
    PRINT_ERR("ringIx:%d bid:%d sdma pkg is empty\n", ringIx, (int)blockIdx.x);
  }

  sdmaInfo->completion_signal = 1;
  sdmaInfo->src_addr = srcAddr;
  sdmaInfo->dst_addr = dstAddr;
  sdmaInfo->data_size = dataLen;
  //sdmaInfo->flag = NPKIT_GET_GPU_TIMESTAMP();
  sdmaInfo->dep_signal = 1;

  return sdmaIndex;
}
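
// Submit/wait pairing for the SDMA engine path. A minimal usage sketch
// (illustrative only; `queue` stands for the per-channel
// ncclShmem.channel.sdmaQueue, and a single submitting thread is assumed):
//
//   int idx = startSdmaTask(&queue, ringIx, src, dst, bytes); // claim a slot, arm the copy
//   /* ... overlap GPU-side work while the SDMA engine transfers ... */
//   uint64_t ns = waitSdmaTaskComplete(&queue, idx);          // spin until completion_signal clears
//
// startSdmaTask returns the descriptor slot it claimed (or -1 on bad input),
// and the same index must be handed back to waitSdmaTaskComplete below.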
// sdmaIndex is signed so the -1 error value from startSdmaTask is detectable.
inline __device__ uint64_t waitSdmaTaskComplete(struct sdmaQueueInfo *sdmaQueue, int sdmaIndex) {
  if (sdmaIndex < 0) {
    PRINT_ERR("waitSdmaTaskComplete bid:%d sdmaIndex:%d invalid sdma index\n", (int)blockIdx.x, sdmaIndex);
    return 0;
  }

#ifdef GC_COPY_DATA
  return 0;
#endif

  volatile hsa_sdma_info_t *sdmaInfo = &sdmaQueue->sdmaInfo[sdmaIndex];
  while (sdmaInfo->completion_signal) {
    __builtin_amdgcn_s_sleep(1);
  }
  __asm__ __volatile__("s_wakeup");

  return sdmaInfo->end_ts - sdmaInfo->start_ts;
}
#endif

template<typename T, typename RedOp, typename Fan, int Direct,
         int SlicePerChunk, int StepPerSlice, int Unroll, int P2p>
class Primitives<
    T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, Unroll>, P2p
  > {
  static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
  static constexpr int Input=0, Output=1;
  static constexpr int RoleInput = 0x01,
                       RoleOutput = 0x02,
                       RoleWaitRecv = 0x04,
                       RoleWaitSend = 0x08,
                       RolePostSend = 0x10,
                       RolePostRecv = 0x20,
                       Aborted = 0x40,
                       OffsFifoEnabled = 0x80,
                       SizesFifoEnabled = 0x100,
                       DirectWrite = 0x200,
                       DirectRead = 0x400,
                       ThreadsSynced = 0x800,
                       NvlsMinPolling = 0x1000;
  const int tid, tidInBlock;
  const int nthreads;
  int nworkers;
  const int stepSize;
  Fan fan;
  int index;  // Peer index I'm responsible for
  int flags;
  int group;
  uint64_t step;
  int *connOffsFifoPtr;  // (flags & OffsFifoEnabled)
  union {
    T *userBuff;      //  (flags & (RoleInput|RoleOutput))
    T *connEltsFifo;  // !(flags & (RoleInput|RoleOutput))
  };
  union {
    int volatile *connSizesFifoPtr;  //  (flags & SizesFifoEnabled)
    T *directBuff;                   // !(flags & SizesFifoEnabled)
  };
  uint64_t *connStepPtr;
  uint64_t connStepCache;  // Cache last seen value of (*connStepPtr)
  uint64_t* barriers;
  uint64_t* barrier_next;
  uint32_t* next_hdp_reg;

#ifdef HYGON_SDMA_FEATURE
public:
  uint32_t ringIx;
  uint32_t useSdmaCopy;
  uint32_t sdmaMinCopySize;
  uint32_t sdmaCountEnabe;
  uint32_t sdmaCopyCount;
  uint32_t allCopyCount;
private:
#endif

#if defined(ENABLE_NPKIT)
public:
  int npKitCtxIdx = 0;
  uint64_t npKitDataProcessEntryTime = 0;
  uint64_t npKitDataProcessExitTime = 0;
  uint64_t npKitDataProcessTotalTime = 0;
private:
#endif

  // Don't use barrier 0 as it's used by the final sync
  inline __device__ void barrier() {
    flags |= ThreadsSynced;
    if (nthreads == WARP_SIZE) __syncwarp();
    else barrier_by_group();
  }
  inline __device__ void subBarrier() {
    barrier();
  }

  inline __device__ bool checkAbort(int &spins) {
    spins++;
    if (!(flags & Aborted) && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
      if (atomicAdd_system((unsigned int *)ncclShmem.comm.abortFlag, 0)) {
        flags |= Aborted;
        ncclShmem.aborted = 1;
      }
      spins = 0;
    }
    return flags & Aborted;
  }

  inline __device__ uint64_t loadStepValue(uint64_t* ptr) {
#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
    if (flags & NvlsMinPolling) {
      uint64_t ans;
      asm("multimem.ld_reduce.acquire.sys.global.min.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr)));
      return ans;
    }
#endif
    // volatile is faster than acquire but not as correct. Make sure reduceCopy
    // loads data using volatile so it doesn't see stale data in L1.
#ifdef __GFX9__
    return atomicAdd((unsigned long long *)ptr, 0);
#else
    return __atomic_load_n(ptr, __ATOMIC_SEQ_CST);
#endif
  }
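
  // waitPeer below implements a credit scheme over the step counters: a
  // receiver may consume a slice once the sender's tail has advanced to it,
  // while a sender gets NCCL_STEPS of credit so it can run a full FIFO ahead
  // of the receiver's head. The wait conditions reduce to (sketch):
  //
  //   // receiver: peer must have produced our slice
  //   while (*tail < step + StepPerSlice) poll();
  //   // sender: a FIFO slot NCCL_STEPS back must be free
  //   while (*head + NCCL_STEPS < step + StepPerSlice) poll();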
  template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
  __device__ __forceinline__ void waitPeer(intptr_t srcIx, intptr_t dstIx, int offset, int nelts) {
    const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
    const bool noRecvWait = DirectRecv && Src && (flags & DirectRead);         // no wait when directly reading from remote input
    const bool noSendWait = DirectSend && (flags & (DirectRead|DirectWrite));  // no wait in empty send (e.g. directScatter) or direct remote write
    if (((flags & (Recv*RoleWaitRecv)) && !noRecvWait) ||
        ((flags & (Send*RoleWaitSend)) && !noSendWait)) {
      int spins = 0;
      while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
        __builtin_amdgcn_s_sleep(1);
        connStepCache = loadStepValue(connStepPtr);
        if (checkAbort(spins)) break;
        //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
        if (spins == 0) traceData(__LINE__, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
      }
      __asm__ __volatile__("s_wakeup");
    }

    if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) {
      if (isSendNotRecv && (flags & SizesFifoEnabled))
        __atomic_store_n(connSizesFifoPtr+step%NCCL_STEPS, nelts*sizeof(T), __ATOMIC_SEQ_CST);

      void **ptrs = isSendNotRecv ? (ncclShmem.groups[group].dsts + Dst)
                                  : (ncclShmem.groups[group].srcs + Src);
      if (flags & OffsFifoEnabled)
        ptrs[index] = connEltsFifo + loadInt(connOffsFifoPtr + (step%NCCL_STEPS))/sizeof(T);
      else if (isSendNotRecv && DirectSend) {
        if (flags & DirectWrite) {
          ptrs[index] = directBuff + dstIx + offset;
        } else if (flags & DirectRead) {  // empty send
          ptrs[index] = nullptr;
        } else {
          ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
        }
      } else if (!isSendNotRecv && DirectRecv) {
        if (flags & DirectRead) {
          ptrs[index] = directBuff + srcIx + offset;
        } else if (flags & DirectWrite) {
          ptrs[index] = directBuff + dstIx + offset;  // send to next from my output buffer
        } else {
          ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
        }
      } else {
        ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
      }
      step += StepPerSlice;
    }
  }

  template<int Recv, int Send>
  inline __device__ void postPeer(bool dataStored) {
    if (Send && (flags & RolePostSend) && dataStored)
#ifdef __GFX9__
      __builtin_amdgcn_buffer_wbinvl1();
#else
      __threadfence_system();
#endif
    if ((flags & Send*RolePostSend) && next_hdp_reg)
      STORE((unsigned int *)next_hdp_reg, 0x1);
    if (flags & (Recv*RolePostRecv | Send*RolePostSend)) {
      step += StepPerSlice;
      STORE(connStepPtr, step);
    }
  }
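
  // genericOp is the single engine behind every public primitive below; the
  // template flags pick the data path at compile time. DirectRecv1/DirectSend1
  // permit touching peer buffers directly, Recv/Send enable the FIFO legs, and
  // SrcBuf/DstBuf select the local user buffer (Input, Output, or -1 for
  // none). For example, recvReduceCopySend is genericOp<0,0,1,1,Input,Output>:
  // receive, reduce with the local input, store to the local output, and
  // forward to the next peer.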
  template <int DirectRecv1, int DirectSend1, int Recv, int Send, int SrcBuf, int DstBuf>
  __device__ __forceinline__ void genericOp(
      intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp
    ) {
    constexpr int DirectRecv = 1 && Direct && DirectRecv1;
    constexpr int DirectSend = 1 && Direct && DirectSend1;
    constexpr int Src = SrcBuf != -1;
    constexpr int Dst = DstBuf != -1;

    nelem = nelem < 0 ? 0 : nelem;
    int sliceSize = stepSize*StepPerSlice;
    sliceSize = max(divUp(nelem, 16*SlicePerChunk)*16, sliceSize/32);
    int slice = 0;
    int offset = 0;

#ifdef HYGON_SDMA_FEATURE
    uint64_t srcSdmaAddr = 0;
    uint64_t dstSdmaAddr = 0;
    int sendToNextRankMode = SDMA_TRANS_MODE_NON;
    int needSdmaCopy = 0;

    if (useSdmaCopy) {
      if (Send == 1 && fan.nsend() == 1) {
        if (Src && Recv) {
          sendToNextRankMode = SDMA_TRANS_MODE_REDUCE_SEND;  // 1
        } else if (Src && !Recv) {
          sendToNextRankMode = SDMA_TRANS_MODE_SRC_SEND;     // 2
        } else if (!Src && Recv) {
          sendToNextRankMode = SDMA_TRANS_MODE_RECV_SEND;    // 3
        }
        if (sendToNextRankMode) needSdmaCopy = 1;
      }
    }
#endif

    PRINT_DEBUG("genericOp-1- ringIx:%d bid:%d sliceSize:%d nelem:%d SlicePerChunk:%d stepSize:%d StepPerSlice:%d slicesize0:%d max(val1:%d val2:%d) mode:%d send:%d %d sizeofT:%d\n",
                ringIx, (int)blockIdx.x, sliceSize*sizeof(T), nelem*sizeof(T), SlicePerChunk, stepSize*sizeof(T), StepPerSlice,
                stepSize*StepPerSlice*sizeof(T), divUp(nelem, 16*SlicePerChunk)*16*sizeof(T), sizeof(T)*stepSize*StepPerSlice/32,
                sendToNextRankMode, Send, MaxSend, sizeof(T));

    if (tid < nworkers && offset < nelem) {
      // Worker-only loop for non-empty slices. Non-workers and empty slices are
      // processed in the loop following this if block. The benefit of splitting
      // the loop like this is we pull two branches out of the critical path.
      // Using "number of branch insns (taken or not) encountered dynamically"
      // as the performance metric, then:
      //   perf_orig = 2*numslices
      //   perf_new = 2+numslices
      // So the new code and old code behave the same for numslices=2, and for
      // numslices>2 the new code is superior. And note that in the case
      // numslices=1, the loop is trivially unrollable (single iteration) so we
      // don't incur that tail branch and we still have perf_new=2.
      //
      // ORIGINAL CODE:
      //   unrolled for(slices) {
      //     if(worker) { // This branch removed
      //       wait();
      //       subBarrier();
      //       if(slice not empty) // This branch removed
      //         ReduceCopyMulti();
      //     }
      //     barrier();
      //     post();
      //   }
      // Since we no longer unroll, new branch added here
      #pragma unroll 1
      do {
        sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
        if (Src && (flags & (SrcBuf==Input ? RoleInput : RoleOutput)))
          ncclShmem.groups[group].srcs[0] = userBuff + srcIx + offset;
        if (Dst && (flags & (DstBuf==Input ? RoleInput : RoleOutput)))
          ncclShmem.groups[group].dsts[0] = userBuff + dstIx + offset;
        waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(srcIx, dstIx, offset, sliceSize);
        subBarrier();
        /* if user abort the kernel, we don't need to actually perform copy/reduce; just set size
         * to 0 to avoid unnecessary workload. */
        int workSize = ncclShmem.aborted ? 0 : sliceSize;

#ifdef HYGON_SDMA_FEATURE
        if (tid == 0 && sdmaCountEnabe) allCopyCount++;
        if (needSdmaCopy && workSize*sizeof(T) < sdmaMinCopySize) {
          sendToNextRankMode = 0;
          needSdmaCopy = 0;
          PRINT_DEBUG("genericOp-sdma- ringIx:%d bid:%d workSize:%d minCopySize:%d\n",
                      ringIx, (int)blockIdx.x, workSize*sizeof(T), sdmaMinCopySize);
        }
        if (tid == 0 && sendToNextRankMode && useSdmaCopy) {
          // The SDMA copy source is the Src or Recv address. With only Src, Src occupies srcs[0];
          // with only Recv, Recv occupies srcs[0]; with both, Src occupies srcs[0] and Recv occupies srcs[1].
          srcSdmaAddr = (uint64_t)ncclShmem.groups[group].srcs[Src*Recv];
          // The SDMA copy destination is the Send address. With only Dst, Dst occupies dsts[0];
          // with only Send, Send occupies dsts[0]; with both, Dst occupies dsts[0] and Send occupies dsts[1].
          dstSdmaAddr = (uint64_t)ncclShmem.groups[group].dsts[Dst];
          if (sendToNextRankMode == SDMA_TRANS_MODE_REDUCE_SEND) {
            // With both Src and Recv, redirect the reduce output to srcs[1], which then
            // serves as the source of the SDMA copy.
            ncclShmem.groups[group].dsts[Dst] = ncclShmem.groups[group].srcs[Src*Recv];
          }
        }
        if (useSdmaCopy) subBarrier();
#endif
        if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) {
          // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
          if (Send) {
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY)
            if (tid == 0) {
              NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY, sliceSize*sizeof(T), 0,
                                     NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
            }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
            if (tid == 0) {
              npKitDataProcessEntryTime = NPKIT_GET_GPU_TIMESTAMP();
            }
#endif

            reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, MaxSend, /*PreOpSrcs=*/0>
              (tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false,
               1, ncclShmem.groups[group].srcs,
               fan.nsend(), ncclShmem.groups[group].dsts+1,
               workSize);

#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
            if (tid == 0) {
              npKitDataProcessExitTime = NPKIT_GET_GPU_TIMESTAMP();
              npKitDataProcessTotalTime += npKitDataProcessExitTime - npKitDataProcessEntryTime;
            }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT)
            if (tid == 0) {
              NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT, sliceSize*sizeof(T), 0,
                                     NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
            }
#endif
          }
        } else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem.groups[group].dsts[Dst] == nullptr) {
          // For broadcast in CollNet to do empty send
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY)
          if (tid == 0) {
            NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY, sliceSize*sizeof(T), 0,
                                   NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
          }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
          if (tid == 0) {
            npKitDataProcessEntryTime = NPKIT_GET_GPU_TIMESTAMP();
          }
#endif

          reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
            (tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp,
             Recv, ncclShmem.groups[group].srcs,
             Dst, ncclShmem.groups[group].dsts,
             workSize);

#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
          if (tid == 0) {
            npKitDataProcessExitTime = NPKIT_GET_GPU_TIMESTAMP();
            npKitDataProcessTotalTime += npKitDataProcessExitTime - npKitDataProcessEntryTime;
          }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT)
          if (tid == 0) {
            NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT, sliceSize*sizeof(T), 0,
                                   NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
          }
#endif
        } else {
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY)
          if (tid == 0) {
            NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY, sliceSize*sizeof(T), 0,
                                   NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
          }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
          if (tid == 0) {
            npKitDataProcessEntryTime = NPKIT_GET_GPU_TIMESTAMP();
          }
#endif
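          // PreOpSrcs counts the leading sources that still need the pre-op
          // scaler applied: none when the source is not user input, all
          // 1+NCCL_MAX_DIRECT_ARITY inputs when directly reading every peer's
          // registered input, and otherwise just the local input.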
          constexpr int PreOpSrcs = SrcBuf != Input ? 0 :
                                    DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1;

          PRINT_DEBUG("genericOp-4-ringIx:%d bid:%d MaxRecv:%d MaxSend:%d PreOpSrcs:%d nworkers:%d Recv:%d fan.nrecv:%d Src:%d Send:%d fan.nsend:%d Dst:%d "
                      "workSize:%d group:%d Unroll:%d mode:%d sdma:%d sq:%p src:0x%lx dst:0x%lx mins:%d\n",
                      ringIx, (int)blockIdx.x, MaxRecv, MaxSend, PreOpSrcs, nworkers, Recv, fan.nrecv(), Src, Send, fan.nsend(), Dst,
                      workSize*sizeof(T), group, Unroll, sendToNextRankMode, useSdmaCopy, ncclShmem.channel.sdmaQueue.sdmaInfo,
                      srcSdmaAddr, dstSdmaAddr, sdmaMinCopySize);

#ifdef HYGON_SDMA_FEATURE
          if (sendToNextRankMode <= SDMA_TRANS_MODE_REDUCE_SEND) {
            reduceCopy<Unroll, RedOp, T, 0, Recv+Src, Recv*MaxRecv+Src, 0, Send+Dst, Send*MaxSend+Dst, PreOpSrcs>
              (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
               Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs,
               Send*fan.nsend()+Dst, ncclShmem.groups[group].dsts,
               workSize);
          } else if (Dst) {
            int sdmaIndex;
            uint64_t delta_ts;
            if (tid == 0) {
              NPKIT_SET_GPU_EVENT(NPKIT_EVENT_PRIM_SIMPLE_SDMA_COPY_PAL_ENTRY, workSize*sizeof(T), 0);
              sdmaIndex = startSdmaTask(&ncclShmem.channel.sdmaQueue, ringIx, srcSdmaAddr, dstSdmaAddr, workSize * sizeof(T));
            }

            reduceCopy<Unroll, RedOp, T, 0, Recv+Src, Recv*MaxRecv+Src, 0, 1, 1, PreOpSrcs>
              (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
               Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs,
               Dst, ncclShmem.groups[group].dsts,
               workSize);
            needSdmaCopy = 0;
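
            // tid 0 armed the SDMA transfer before the reduceCopy above, so the
            // engine-driven send to the next rank overlapped the GPU-side copy
            // into the local destination; tid 0 now reaps the completion before
            // the group-wide barrier.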
            if (tid == 0) {
              delta_ts = waitSdmaTaskComplete(&ncclShmem.channel.sdmaQueue, sdmaIndex);
#if defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_SDMA_COST)
              NPKIT_SET_GPU_EVENT_TM(NPKIT_EVENT_PRIM_SIMPLE_SDMA_COST_ENTRY, workSize * sizeof(T), 0,
                                     ncclShmem.channel.sdmaQueue.sdmaInfo[sdmaIndex].start_ts);
              NPKIT_SET_GPU_EVENT_TM(NPKIT_EVENT_PRIM_SIMPLE_SDMA_COST_EXIT, workSize * sizeof(T), 0,
                                     ncclShmem.channel.sdmaQueue.sdmaInfo[sdmaIndex].end_ts);
#endif
              NPKIT_SET_GPU_EVENT(NPKIT_EVENT_PRIM_SIMPLE_SDMA_COPY_PAL_EXIT, workSize*sizeof(T), delta_ts);
              if (sdmaCountEnabe) sdmaCopyCount++;
            }
          }
#else
          reduceCopy<Unroll, RedOp, T, 0, Recv+Src, Recv*MaxRecv+Src, 0, Send+Dst, Send*MaxSend+Dst, PreOpSrcs>
            (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
             Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs,
             Send*fan.nsend()+Dst, ncclShmem.groups[group].dsts,
             workSize);
#endif

#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
          if (tid == 0) {
            npKitDataProcessExitTime = NPKIT_GET_GPU_TIMESTAMP();
            npKitDataProcessTotalTime += npKitDataProcessExitTime - npKitDataProcessEntryTime;
          }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT)
          if (tid == 0) {
            NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT, sliceSize*sizeof(T), 0,
                                   NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
          }
#endif
        }

        barrier();  // This barrier has a counterpart in following loop

#ifdef HYGON_SDMA_FEATURE
        if (tid == 0 && sendToNextRankMode && useSdmaCopy && needSdmaCopy) {
          PRINT_DEBUG("genericOp-5-ringIx:%d bid:%d MaxRecv:%d MaxSend:%d Recv:%d fan.nrecv:%d Src:%d Send:%d fan.nsend:%d Dst:%d "
                      "workSize:%d mode:%d sdma:%d need:%d sq:%p src:0x%lx dst:0x%lx\n",
                      ringIx, (int)blockIdx.x, MaxRecv, MaxSend, Recv, fan.nrecv(), Src, Send, fan.nsend(), Dst,
                      workSize*sizeof(T), sendToNextRankMode, useSdmaCopy, needSdmaCopy,
                      ncclShmem.channel.sdmaQueue.sdmaInfo, srcSdmaAddr, dstSdmaAddr);
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
          npKitDataProcessEntryTime = NPKIT_GET_GPU_TIMESTAMP();
#endif
          NPKIT_SET_GPU_EVENT(NPKIT_EVENT_PRIM_SIMPLE_SDMA_COPY_ENTRY, workSize*sizeof(T), 0);
          int sdmaIndex = startSdmaTask(&ncclShmem.channel.sdmaQueue, ringIx, srcSdmaAddr, dstSdmaAddr, workSize * sizeof(T));
          uint64_t delta_ts = waitSdmaTaskComplete(&ncclShmem.channel.sdmaQueue, sdmaIndex);
#if defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_SDMA_COST)
          NPKIT_SET_GPU_EVENT_TM(NPKIT_EVENT_PRIM_SIMPLE_SDMA_COST_ENTRY, workSize * sizeof(T), 0,
                                 ncclShmem.channel.sdmaQueue.sdmaInfo[sdmaIndex].start_ts);
          NPKIT_SET_GPU_EVENT_TM(NPKIT_EVENT_PRIM_SIMPLE_SDMA_COST_EXIT, workSize * sizeof(T), 0,
                                 ncclShmem.channel.sdmaQueue.sdmaInfo[sdmaIndex].end_ts);
#endif
          NPKIT_SET_GPU_EVENT(NPKIT_EVENT_PRIM_SIMPLE_SDMA_COPY_EXIT, workSize*sizeof(T), delta_ts);
          if (sdmaCountEnabe) sdmaCopyCount++;
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
          npKitDataProcessExitTime = NPKIT_GET_GPU_TIMESTAMP();
          npKitDataProcessTotalTime += npKitDataProcessExitTime - npKitDataProcessEntryTime;
#endif
        }
        if (useSdmaCopy) barrier();
#endif

        postPeer<Recv, Send>(0 < sliceSize);
        offset += sliceSize;
        slice += 1;
      } while (slice < SlicePerChunk && offset < nelem);
    }

    // Non-workers come straight here. Workers too but only once the remaining
    // slices are all empty. Since empty slices are the uncommon case, and
    // worker perf is the limiter, perf-wise this loop is effectively unentered,
    // hence just a single branch insn.
    #pragma unroll 1
    while (slice < SlicePerChunk) {
      sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
      { // Only workers could have Wait roles so we know the slice must be empty
        // since we've exited the loop above.
        waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(0, 0, 0, 0);
      }
      barrier();  // Has counterpart in preceding worker-only loop.
      postPeer<Recv, Send>(0 < sliceSize);
      offset += sliceSize;
      slice += 1;
    }
  }
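
  // mscclGenericOp backs the MSCCL interpreter's local operations: REDUCE
  // appends dsts[0] to the source list so the destination's current contents
  // join the reduction, COPY is a plain per-pair copy, and MULTISRCS widens
  // either form to several buffers.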
  template <int REDUCE, int COPY, int MULTISRCS, int MULTIDSTS>
  __device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) {
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY)
    if (tid == 0) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY, nelem*sizeof(T), 0,
                             NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif
    nelem = nelem < 0 ? 0 : nelem;
    if (tid < nworkers) {
      if (REDUCE) {
        srcs[nsrcs] = dsts[0];
        nsrcs++;
        if (MULTISRCS) {
          reduceCopy<Unroll, RedOp, T, 0, 3, MSCCL_MAX_REDUCE_FUSION+1, 0, 1, 1, /*PreOpSrcs=*/0>
            (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false,
             nsrcs, (void **)srcs, 1, (void **)dsts, nelem);
        } else {
          reduceCopy<Unroll, RedOp, T, 0, 2, 2, 0, 1, 1, /*PreOpSrcs=*/0>
            (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false,
             2, (void **)srcs, 1, (void **)dsts, nelem);
        }
      }
      if (COPY) {
        reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
          (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false,
           1, (void **)srcs, 1, (void **)dsts, nelem);
        if (MULTISRCS) {
          for (int i = 1; i < nsrcs; i++) {
            reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
              (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false,
               1, (void **)&srcs[i], 1, (void **)&dsts[i], nelem);
          }
        }
      }
    }
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT)
    if (tid == 0) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT, nelem*sizeof(T), 0,
                             NPKIT_GET_GPU_TIMESTAMP(), ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif
    barrier();
  }

  // Scatter/Gather generic op
  //  skip: my own rank order in the buffer chunks
  //  shift: peer offset to avoid all ranks sending to or receiving from same peer
  template <int DirectRecv1, int DirectSend1, int Recv, int Send>
  __device__ __forceinline__ void
  ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int peerOffset, int skip, int shift, bool postOp) {
    constexpr int DirectRecv = 1 && Direct && DirectRecv1;
    constexpr int DirectSend = 1 && Direct && DirectSend1;
    int offset = 0;  // slice offset
    int sliceSize = stepSize*StepPerSlice;
    int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32);  // per-peer slice size

    #pragma unroll 1
    for (int slice=0; slice<SlicePerChunk; ++slice) {
      int realSize = max(0, min(dataSize, peerElem-offset));
      bool fenceNeeded = false;
      if (tid < nworkers) {
        if (Send) {
          if (flags & RoleInput) ncclShmem.groups[group].srcs[0] = userBuff + inpIx + offset;
          // realSize is not accurate here; but intra-node does not rely on sizes FIFO
          waitPeer<0, DirectSend, 0, 1, 1, 0>(0, inpIx, offset, realSize);
          subBarrier();
          #pragma unroll 1
          // Loop over peers
          for (int j=0; j<fan.nsend(); j++) {
            int i = (j+shift)%fan.nsend();
            int pOffset = i*peerOffset;
            // Skip the data I am responsible of reducing myself
            if (skip >= 0 && i >= skip) pOffset += peerElem;
            void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset;
            int realPeerSize = min(realSize, totalElem-pOffset);
            if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) {
              reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/1>
                (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false,
                 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize);
              // Mark for threadfence at the end
              fenceNeeded |= true;
            }
          }
        } else if (Recv) {
          if (flags & RoleOutput) ncclShmem.groups[group].dsts[0] = userBuff + outIx + offset;
          int pOffset = index*peerOffset;
          if (skip >= 0 && index >= skip) pOffset += peerElem;
          // Adjust remote index with peer offset in case we are directly pulling from peer's output buffer
          waitPeer<DirectRecv, 0, 1, 0, 0, 1>(outIx, outIx+pOffset, offset, realSize);
          subBarrier();
          #pragma unroll 1
          for (int j=0; j<fan.nrecv(); j++) {
            int i = (j+shift)%fan.nrecv();
            int pOffset = i*peerOffset;
            if (skip >= 0 && i >= skip) pOffset += peerElem;
            void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset;
            int realPeerSize = min(realSize, totalElem-pOffset);
            if (DirectRecv && ncclShmem.groups[group].srcs[i] == dst0) realPeerSize = 0;
            if (realPeerSize > 0)
              reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
                (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
                 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize);
          }
        }
      }
      fenceNeeded = __any(fenceNeeded);
      postPeer<Recv, Send>(fenceNeeded);
      offset += realSize;
    }
  }
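
  // Indexing sketch for ScatterGatherOp: with nsend peers, peer j is served
  // in order i = (j+shift)%nsend so different ranks fan out to different
  // peers first, and peer i's chunk starts at i*peerOffset. The skip rank
  // marks the chunk this rank reduces itself; peers at or beyond it shift by
  // one chunk (pOffset += peerElem). E.g. with 4 peers, peerElem = 1024,
  // peerOffset = 1024, skip = 2: chunks land at offsets 0, 1024, 3072 and
  // 4096, leaving 2048 for the local rank.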
  __device__ __forceinline__ void loadRecvConn(ncclDevChannelPeer *peer, int connIndex, struct ncclWorkElem* e) {
    if (flags & (RoleWaitRecv|RolePostRecv)) {
      auto *conn = &peer->recv[connIndex];
      step = conn->step;
      step = roundUp(step, SlicePerChunk*StepPerSlice);
      if (flags & RolePostRecv) {
        connStepPtr = conn->head;
        STORE(connStepPtr, step);  // Return credits in case we rounded up.
      }
      if (flags & RoleWaitRecv) {
        ncclShmem.groups[group].recvConns[index] = conn;  // WaitRecv role saves since that's who needs it in setDataPtrs()
        flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0;
        connStepPtr = conn->tail;
        connStepCache = loadStepValue(connStepPtr);
        flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0;
        if (Direct) {
          // User buffers have been registered
          if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) {
            if (connIndex == 1 && P2p == 0) {
              flags |= DirectRead;  // scatter-reduce use direct pull
            } else {
              flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite :
                       (e->direct & NCCL_DIRECT_READ)  ? DirectRead  : 0;
            }
          } else if (conn->flags & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) {
            if (connIndex == 1 && P2p == 0) {
              flags |= DirectRead;  // scatter-reduce use direct pull
            } else {
              // direct read not allowed in non-register case
              // otherwise, in one-to-multi send, we could mix empty send and intermediate send
              flags |= (conn->flags & NCCL_DIRECT_WRITE) ? DirectWrite : 0;
            }
          }
        }
        if (flags & OffsFifoEnabled)
          connOffsFifoPtr = conn->offsFifo;
        connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
      }
    }
  }
  __device__ __forceinline__ void loadSendConn(ncclDevChannelPeer *peer, int connIndex, struct ncclWorkElem* e) {
    if (flags & (RoleWaitSend|RolePostSend)) {
      auto *conn = &peer->send[connIndex];
      step = conn->step;
      step = roundUp(step, SlicePerChunk*StepPerSlice);
      if (flags & RolePostSend) {
        connStepPtr = conn->tail;
        next_hdp_reg = conn->next_hdp_reg;
      }
      if (flags & RoleWaitSend) {
        ncclShmem.groups[group].sendConns[index] = conn;  // WaitSend role saves since that's who needs it in setDataPtrs()
        flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0;
        connStepPtr = conn->head;
        connStepCache = loadStepValue(connStepPtr);
        flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0;
        if (flags & OffsFifoEnabled)
          connOffsFifoPtr = conn->offsFifo;
        connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];

        if (conn->sizesFifo != nullptr) {
          flags |= SizesFifoEnabled;
          connSizesFifoPtr = conn->sizesFifo;
        } else if (Direct) {
          // User buffers have been registered
          if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) {
            if (connIndex == 1 && P2p == 0) {
              flags |= DirectRead;  // scatter-reduce use direct pull
            } else {
              flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite :
                       (e->direct & NCCL_DIRECT_READ)  ? DirectRead  : 0;
            }
          } else if (conn->flags & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) {
            if (connIndex == 1 && P2p == 0) {
              flags |= DirectRead;  // scatter-reduce use direct pull
            } else {
              // direct read not allowed in non-register case
              // otherwise, in one-to-multi send, we could mix empty send and intermediate send
              flags |= (conn->flags & NCCL_DIRECT_WRITE) ? DirectWrite : 0;
            }
          }
        }
      }
    }
  }

public:
  __forceinline__ __device__ Primitives(
      int tid, int nthreads, int const *recvPeers, int const *sendPeers,
      void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group=0,
      uint8_t connIndexRecv = 0, uint8_t connIndexSend = 0, struct ncclWorkElem* e = nullptr
    ):
    tid(tid), nthreads(nthreads), tidInBlock(threadIdx.x), group(group),
    stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)) {

    // For send operations, we need an extra warp to overlap the threadfence and the copy
    barriers = &ncclShmem.groups[group].barrier;
    barrier_next = ncclShmem.groups[group].barrier_next;
    this->nworkers = nthreads;

    int nrecv=0, nsend=0;
    while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
    while (nsend < MaxSend && sendPeers[nsend] != -1) nsend++;
    this->fan = Fan(nrecv, nsend);

    constexpr int ThreadPerSync = 8;
    static_assert(MaxSend <= ThreadPerSync && MaxRecv <= ThreadPerSync, "Not enough threads to cover all peers");

    int g = tid / ThreadPerSync;
    int ng = nthreads / ThreadPerSync;
    index = tid % ThreadPerSync;
    flags = 0;
    if (g == 0) {
      if (index < nrecv) flags |= RoleWaitRecv;
      if (index == nrecv) flags |= RoleInput;
    } else if (g == 1) {
      if (index < nsend) flags |= RoleWaitSend;
      if (index == nsend) flags |= RoleOutput;
    } else if (g == ng - 2) {
      if (index < nrecv) flags |= RolePostRecv;
    } else if (g == ng - 1) {
      if (index < nsend) flags |= RolePostSend;
    }

    int peer = 0;
    if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index];
    if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index];

    loadRecvConn(ncclShmem.channel.peers[peer], connIndexRecv, e);
    loadSendConn(ncclShmem.channel.peers[peer], connIndexSend, e);

    setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclWorkElemReg*)e);
  }

  __forceinline__ __device__ ~Primitives() {
    // Ensure ncclShmem.groups[].send/recvConns are available
    if (!(flags & ThreadsSynced))
      barrier();
    // Save steps for the next operation
    if (flags & (RolePostSend|RolePostRecv)) {
      auto *conns = (flags & RolePostSend) ? ncclShmem.groups[group].sendConns : ncclShmem.groups[group].recvConns;
      conns[index]->step = step;
    }
    // Make sure all threads are done writing back conn->step and done using
    // ncclShmem.groups[group]
    barrier();
  }
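
  // Thread-role layout set up by the constructor, with ThreadPerSync = 8 and
  // ng = nthreads/8 sync groups:
  //   group 0    : index < nrecv -> RoleWaitRecv, index == nrecv -> RoleInput
  //   group 1    : index < nsend -> RoleWaitSend, index == nsend -> RoleOutput
  //   group ng-2 : index < nrecv -> RolePostRecv
  //   group ng-1 : index < nsend -> RolePostSend
  // All nworkers threads, role-holding or not, drive reduceCopy.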
  __device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclWorkElemReg* e) {
    if (flags & RoleInput) {
      userBuff = (T*)inputBuf;
      ncclShmem.redOpArgs[0] = redOpArg;  // scaler for local input
    }
    if (flags & RoleOutput) userBuff = (T*)outputBuf;

    bool recvProvider = flags == (flags|RoleWaitRecv|DirectWrite);
    bool sendAcceptor = flags == (flags|RoleWaitSend|DirectWrite);
    bool sendProvider = flags == (flags|RoleWaitSend|DirectRead);  // sender provides direct buffer (to be fetched)
    bool recvAcceptor = flags == (flags|RoleWaitRecv|DirectRead);  // receiver accepts direct buffer
    int regUsed = e != nullptr ? e->elem.regUsed : 0;

    if (Direct && recvProvider) {
      int spins = 0;
      void *volatile *slot = ncclShmem.groups[group].recvConns[index]->ptrExchange;
      // Wait for consumer to consume previous value before trampling it.
      while ((void *)atomicAdd((unsigned long long *)slot, 0) != nullptr && !checkAbort(spins));
      directBuff = (T*)outputBuf;
      // Encode pointer by XOR'ing against some address they definitely wouldn't send
      // since we want to allow them sending us nullptr while not colliding with
      // the empty slot value.
      *slot = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(directBuff) ^ reinterpret_cast<uintptr_t>(slot));
    }
    if (Direct && sendAcceptor) {
      int spins = 0;
      void *volatile *slot = ncclShmem.groups[group].sendConns[index]->ptrExchange;
      void *ptr;
      while (true) {
        ptr = (void *)atomicAdd((unsigned long long *)slot, 0);
        if (ptr != nullptr || checkAbort(spins)) break;
      }
      directBuff = regUsed ? (T*)(e->dnOutputs[index])
                           : reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(ptr) ^ reinterpret_cast<uintptr_t>(slot));
      *slot = nullptr;
    }
    if (Direct && sendProvider) {
      int spins = 0;
      void *volatile *slot = ncclShmem.groups[group].sendConns[index]->ptrExchange;
      volatile uint64_t* argSlot0 = ncclShmem.groups[group].sendConns[index]->redOpArgExchange;
      volatile uint64_t* argSlot1 = ncclShmem.groups[group].sendConns[index]->redOpArgExchange+1;
      // Wait for consumer to consume previous value before trampling it.
      while (((void *)atomicAdd((unsigned long long *)slot, 0) != nullptr || *argSlot0 != 0 || *argSlot1 != 0) && !checkAbort(spins));
      // If there is no recv, then we are directly pulling from input buffer (e.g. directScatter)
      // Otherwise, we are pulling from output buffer (e.g. recvCopyDirectSend)
      directBuff = MaxRecv == 0 ? (T*)inputBuf : (T*)outputBuf;
      // Exchange pre-scalers for use in direct pull
      *argSlot0 = (uint64_t(1)<<32) | (uint32_t)redOpArg;
      *argSlot1 = (uint64_t(1)<<32) | (uint32_t)(redOpArg>>32);
      // Encode pointer by XOR'ing against some address they definitely wouldn't send
      // since we want to allow them sending us nullptr while not colliding with
      // the empty slot value.
      *slot = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(directBuff) ^ reinterpret_cast<uintptr_t>(slot));
    }
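    // Decode side of the pointer exchange used above: the provider published
    // buff ^ slot, so the acceptor recovers buff = ptr ^ slot. A provider
    // sending nullptr thus arrives as the value `slot`, which can never
    // collide with the empty-slot marker nullptr, keeping "no buffer" and
    // "slot free" distinguishable.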
    if (Direct && recvAcceptor) {
      int spins = 0;
      void *volatile *slot = ncclShmem.groups[group].recvConns[index]->ptrExchange;
      volatile uint64_t* argSlot0 = ncclShmem.groups[group].recvConns[index]->redOpArgExchange;
      volatile uint64_t* argSlot1 = ncclShmem.groups[group].recvConns[index]->redOpArgExchange+1;
      void *ptr;
      while (true) {
        ptr = (void *)atomicAdd((unsigned long long *)slot, 0);
        if (ptr != nullptr || checkAbort(spins)) break;
      }
      directBuff = regUsed ? (T*)(MaxSend == 0 ? e->upOutputs[index] : e->dnInputs[index])
                           : reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(ptr) ^ reinterpret_cast<uintptr_t>(slot));
      if (MaxSend != 0) {  // reduce group rather than gather group
        // Store scalers for remote inputs
        uint64_t arg0, arg1;
        while (true) {
          arg0 = *argSlot0;
          arg1 = *argSlot1;
          if ((arg0 != 0 && arg1 != 0) || checkAbort(spins)) break;
        }
        ncclShmem.redOpArgs[1+index] = ((arg1 & 0xffffffff)<<32) | (arg0 & 0xffffffff);
      }
      *argSlot0 = 0;
      *argSlot1 = 0;
      *slot = nullptr;
    }
  }

  __device__ void moveDataPtrs(intptr_t delta) {
    if (flags & (RoleInput|RoleOutput))
      userBuff += delta;
  }

  // Set MSCCL data pointers
  __device__ __forceinline__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
    if (flags & RoleInput) userBuff = (T*)inputBuf;
    if (flags & RoleOutput) userBuff = (T*)outputBuf;
  }

  __device__ __forceinline__ void send(intptr_t inpIx, int eltN) {
    genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, eltN, false);
  }
  __device__ __forceinline__ void sendFromOutput(intptr_t outIx, int eltN) {
    genericOp<0, 0, 0, 1, Output, -1>(outIx, -1, eltN, false);
  }
  __device__ __forceinline__ void directSend(intptr_t inpIx, intptr_t outIx, int eltN) {
    genericOp<0, 1, 0, 1, Input, -1>(inpIx, outIx, eltN, false);
  }
  __device__ __forceinline__ void directSendFromOutput(intptr_t outIx, int eltN) {
    genericOp<0, 1, 0, 1, Output, -1>(outIx, outIx, eltN, false);
  }

  __device__ __forceinline__ void recv(intptr_t outIx, int eltN, bool postOp=false) {
    genericOp<0, 0, 1, 0, -1, Output>(-1, outIx, eltN, postOp);
  }
  __device__ __forceinline__ void directRecv(intptr_t outIx, int eltN) {
    genericOp<1, 0, 1, 0, -1, Output>(-1, outIx, eltN, /*postOp=*/false);
  }

  __device__ __forceinline__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
    genericOp<0, 0, 0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
  }
  __device__ __forceinline__ void directCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
    genericOp<0, 1, 0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
  }

  __device__ __forceinline__ void recvSend(int eltN, bool postOp=false) {
    genericOp<0, 0, 1, 1, -1, -1>(-1, -1, eltN, postOp);
  }
  __device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) {
    genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, eltN, postOp);
  }
  __device__ __forceinline__ void directRecvCopySend(intptr_t outIx, int eltN) {
    genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, eltN, false);
  }
  __device__ __forceinline__ void recvCopyDirectSend(intptr_t outIx, int eltN, bool postOp=false) {
    genericOp<0, 1, 1, 1, -1, Output>(-1, outIx, eltN, postOp);
  }

  __device__ __forceinline__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
    genericOp<0, 0, 1, 0, Input, Output>(inpIx, outIx, eltN, postOp);
  }
  __device__ __forceinline__ void recvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) {
    genericOp<0, 0, 1, 1, Input, -1>(inpIx, -1, eltN, postOp);
  }
  __device__ __forceinline__ void directRecvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) {
    genericOp<1, 0, 1, 1, Input, -1>(inpIx, -1, eltN, postOp);
  }
  __device__ __forceinline__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
    genericOp<0, 0, 1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
  }
  __device__ __forceinline__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
    // Direct is only for the send part
    genericOp<0, 1, 1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
  }
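
  // Usage sketch (illustrative): the primitives above compose into one ring
  // all-reduce step per chunk, assuming `prims` was built with one recv and
  // one send peer:
  //
  //   prims.send(offset, chunk);                              // inject my chunk
  //   prims.recvReduceSend(offset, chunk);                    // accumulate along the ring
  //   prims.directRecvReduceCopySend(offset, offset, chunk);  // fold + start broadcast
  //   prims.directRecvCopySend(offset, chunk);                // broadcast hops
  //   prims.directRecv(offset, chunk);                        // final hop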
  __device__ __forceinline__ void scatter(intptr_t inpIx, int totalElem, int peerElem, int peerOffset, int skip, int shift) {
    ScatterGatherOp<0, 0, 0, 1>(inpIx, -1, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false);
  }
  __device__ __forceinline__ void directScatter(intptr_t inpIx, int totalElem, int peerElem, int peerOffset, int skip, int shift) {
    ScatterGatherOp<0, 1, 0, 1>(inpIx, -1, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false);
  }

  __device__ __forceinline__ void gather(intptr_t outIx, int totalElem, int peerElem, int peerOffset, int skip, int shift, bool postOp=false) {
    ScatterGatherOp<0, 0, 1, 0>(-1, outIx, totalElem, peerElem, peerOffset, skip, shift, postOp);
  }
  __device__ __forceinline__ void directGather(intptr_t outIx, int totalElem, int peerElem, int peerOffset, int skip, int shift) {
    ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false);
  }

  // MSCCL primitives
  __device__ __forceinline__ void sendWithBarrier(intptr_t inpIx, int eltN) {
    send(inpIx, eltN);
  }
  __device__ __forceinline__ void localCopy(T* srcs, T* dsts, int eltN) {
    return mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN);
  }
  __device__ __forceinline__ void reduce(T** srcs, int nsrcs, T** dsts, int ndsts, int eltN) {
    if (nsrcs == 1) {
      return mscclGenericOp<1,0,0,0>(srcs, 1, dsts, 1, eltN);
    } else {
      return mscclGenericOp<1,0,1,0>(srcs, nsrcs, dsts, 1, eltN);
    }
  }
};