broadcast.h 8.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
/*************************************************************************
 * Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#include "devcomm.h"
#include "collectives.h"
#include "primitives.h"

namespace {
  // Ring broadcast of one work element, executed by the block handling one
  // channel.
  //
  // The element range [0, args->count) is walked in outer-loop steps of
  // loopSize = nChannels * chunkSize elements. Within each step this block
  // handles the slice that starts at gridOffset + bid * realChunkSize. What a
  // rank does with its slice depends on its ring position relative to root:
  //   - rank == root:     prims.send (sendbuff == recvbuff) or
  //                       prims.copySend (sendbuff != recvbuff)
  //   - nextRank == root: prims.recv only (the next hop is the root, so
  //                       nothing is forwarded)
  //   - otherwise:        prims.recvCopySend
  //
  // Template parameters:
  //   T     - element type of args->sendbuff / args->recvbuff.
  //   RedOp - forwarded to Primitives together with args->redOpArg.
  //   Proto - ProtoSimple / ProtoLL / ProtoLL128; selects the chunk sizing.
  //
  // The function is forced noinline unless USE_INDIRECT_FUNCTION_CALL is
  // defined and the target is not gfx940/gfx941/gfx942.
  // NOTE(review): the rationale for that toggle is not visible here
  // (presumably code size / register pressure) — confirm before changing.
  //
  // Optional instrumentation: ENABLE_NPKIT (tid 0 records events into
  // ncclShmem.comm.npKitEventCollectContexts[npKitCtxIdx]) and ENABLE_TIMELINE.
  template<typename T, typename RedOp, typename Proto>
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
  __device__ void runRing(ncclWorkElem *args) {
#else
  __device__ __attribute__((noinline)) void runRing(ncclWorkElem *args) {
#endif
    const int tid = threadIdx.x;
    const int nthreads = args->nWarps*WARP_SIZE;
    const int bid = args->bid;
    const int nChannels = args->nChannels;
    ncclRing *ring = &ncclShmem.channel.ring;
    // Elements per chunk: one protocol step (BROADCAST_CHUNKSTEPS steps for SIMPLE).
    const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? BROADCAST_CHUNKSTEPS : 1));
    // LL128 chunks are sized in multiples of one grain per thread.
    const ssize_t minChunkSizeLL128 = int(nthreads*(Proto::calcBytePerGrain()/sizeof(T)));
    // Elements covered by all channels together in one outer-loop iteration.
    const ssize_t loopSize = nChannels*chunkSize;
    const ssize_t size = args->count;
    // userRanks[0] is this rank, userRanks[1] the next rank in the ring.
    const int rank = ring->userRanks[0];
    const int nextRank = ring->userRanks[1];
    const int root = args->root;

#if defined(ENABLE_NPKIT)
    // Single NpKit context index used by every event recorded below.
    int npKitCtxIdx = bid;
#endif

#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
    if (tid == 0) {
      uint64_t* cpuTimestamp = ncclShmem.comm.cpuTimestamp;
      NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp,
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif

#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU)
    if (tid == 0) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(),
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif

#if defined (ENABLE_TIMELINE)
    int elems = 0, totalElems = 0;
    uint64_t clkStamp = 0ULL;
    struct ncclDevComm* comm = &ncclShmem.comm;
    uint64_t entryStamp = __builtin_amdgcn_s_memrealtime();
    Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_BROADCAST_ENTRY, 0, entryStamp, comm->cpuTimestamp);
#endif

    T *inputBuf = (T*)args->sendbuff;
    T *outputBuf = (T*)args->recvbuff;
    // One peer in each direction: ring->prev (receive) and ring->next (send).
    Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
      prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->redOpArg, 0, args->connIndex, args->connIndex);

#ifdef HYGON_SDMA_FEATURE
    prims.ringIx = ring->index;
    INIT_PRIMS_SDMA(prims, args);
#endif
#if defined(ENABLE_NPKIT)
    if (tid == 0) {
      prims.npKitCtxIdx = npKitCtxIdx;
    }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_BROADCAST_RING_ENTRY)
    if (tid == 0) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_BROADCAST_RING_ENTRY, size*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif

    for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
      // Chunk size for this iteration, shrunk per protocol near the buffer tail.
      ssize_t realChunkSize;
      if (Proto::Id == NCCL_PROTO_SIMPLE) {
        // Spread the remainder evenly over the channels, rounded up to a
        // multiple of nthreads * sizeof(uint64_t) bytes.
        realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels));
        realChunkSize = roundUp(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
      }
      // LL: the final (partial) iteration uses args->lastChunkSize.
      else if (Proto::Id == NCCL_PROTO_LL)
        realChunkSize = size-gridOffset < loopSize ? args->lastChunkSize : chunkSize;
      // LL128: keep the chunk a multiple of minChunkSizeLL128 elements.
      else if (Proto::Id == NCCL_PROTO_LL128)
        realChunkSize = min(chunkSize, divUp(size-gridOffset, nChannels*minChunkSizeLL128)*minChunkSizeLL128);
      realChunkSize = int(realChunkSize);

      // This channel's slice of the current iteration. nelem can be <= 0 when
      // the buffer ends before this channel's slice begins; the
      // instrumentation below clamps it with max(0, nelem).
      ssize_t offset = gridOffset + int(bid*realChunkSize);
      int nelem = min(realChunkSize, size-offset);

      if (rank == root) {
        if (inputBuf == outputBuf) {
          // In-place root: the data is already in the output buffer; just send.
#if defined (ENABLE_TIMELINE)
          elems = max(0, nelem);
          clkStamp = __builtin_amdgcn_s_memrealtime();
          Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_PRIM_SEND_ENTRY, elems*sizeof(T), clkStamp, comm->cpuTimestamp);
#endif
          prims.send(offset, nelem);
#if defined (ENABLE_TIMELINE)
          totalElems += elems;
          Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_PRIM_SEND_EXIT, elems*sizeof(T), __builtin_amdgcn_s_memrealtime() - clkStamp, comm->cpuTimestamp);
#endif
        } else {
          // Out-of-place root: copySend from input offset to output offset.
#if defined (ENABLE_TIMELINE)
          elems = max(0, nelem);
          clkStamp = __builtin_amdgcn_s_memrealtime();
          Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_PRIM_COPY_SEND_ENTRY, elems*sizeof(T), clkStamp, comm->cpuTimestamp);
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_BROADCAST_RING_COPY_SEND_ENTRY)
          // Record into the same context (npKitCtxIdx) as the ring ENTRY/EXIT
          // events and prims.npKitCtxIdx, not blockIdx.x.
          if (tid == 0) {
            NpKit::CollectGpuEvent(NPKIT_EVENT_BROADCAST_RING_COPY_SEND_ENTRY, max(0, nelem)*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
                ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
            prims.npKitDataProcessTotalTime = 0;
          }
#endif
          prims.copySend(offset, offset, nelem);
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_BROADCAST_RING_COPY_SEND_EXIT)
          if (tid == 0)
            NpKit::CollectGpuEvent(NPKIT_EVENT_BROADCAST_RING_COPY_SEND_EXIT, max(0, nelem)*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(),
                ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
#endif
#if defined (ENABLE_TIMELINE)
          totalElems += elems;
          Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_PRIM_COPY_SEND_EXIT, elems*sizeof(T), __builtin_amdgcn_s_memrealtime() - clkStamp, comm->cpuTimestamp);
#endif
        }
      } else if (nextRank == root) {
        // The next hop is the root, so receive without forwarding.
#if defined (ENABLE_TIMELINE)
        elems = max(0, nelem);
        clkStamp = __builtin_amdgcn_s_memrealtime();
        Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_PRIM_RECV_ENTRY, elems*sizeof(T), clkStamp, comm->cpuTimestamp);
#endif
        prims.recv(offset, nelem);
#if defined (ENABLE_TIMELINE)
        totalElems += elems;
        Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_PRIM_RECV_EXIT, elems*sizeof(T), __builtin_amdgcn_s_memrealtime() - clkStamp, comm->cpuTimestamp);
#endif
      } else {
        // Intermediate rank: receive, keep a copy, and forward.
#if defined (ENABLE_TIMELINE)
        elems = max(0, nelem);
        clkStamp = __builtin_amdgcn_s_memrealtime();
        Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_PRIM_RECV_COPY_SEND_ENTRY, elems*sizeof(T), clkStamp, comm->cpuTimestamp);
#endif
        prims.recvCopySend(offset, nelem);
#if defined (ENABLE_TIMELINE)
        totalElems += elems;
        Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_PRIM_RECV_COPY_SEND_EXIT, elems*sizeof(T), __builtin_amdgcn_s_memrealtime() - clkStamp, comm->cpuTimestamp);
#endif
      }
    }
#if defined (ENABLE_TIMELINE)
    Timeline::CollectGpuPrimEvent(comm->gpuEventContext, TIMELINE_EVENT_BROADCAST_EXIT, totalElems*sizeof(T), __builtin_amdgcn_s_memrealtime() - entryStamp, comm->cpuTimestamp);
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_BROADCAST_RING_EXIT)
    if (tid == 0) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_BROADCAST_RING_EXIT, size*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif
  }
}

// Broadcast / ring / SIMPLE: dispatches to runRing with ProtoSimple
// instantiated as <BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS>.
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    typedef ProtoSimple<BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS> SimpleProto;
    runRing<T, RedOp, SimpleProto>(args);
  }
};

// Broadcast / ring / LL: dispatches to runRing with the ProtoLL protocol.
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    typedef ProtoLL LLProto;
    runRing<T, RedOp, LLProto>(args);
  }
};

// Broadcast / ring / LL128: dispatches to runRing with the ProtoLL128 protocol.
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
  __device__ __forceinline__ void run(ncclWorkElem *args) {
    typedef ProtoLL128 LL128Proto;
    runRing<T, RedOp, LL128Proto>(args);
  }
};