msccl_struct.h 7.76 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
/*************************************************************************
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT License.
 ************************************************************************/

#ifndef MSCCL_STRUCT_H_
#define MSCCL_STRUCT_H_

#include <cstddef>
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>
#include "devcomm.h"
#include "msccl/msccl_scheduler.h"

// Capacity limits. Several of these size the fixed arrays of the structs declared in this header.
#define MSCCL_MAX_NUM_STEPS 64 // length of the per-step arrays in mscclThreadBlock
#define MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL 32 // send/recv peer slots in mscclChannelInfo
#define MSCCL_MAX_NUM_THREAD_BLOCKS 64 // length of mscclAlgo::mscclTBs and of the per-thread-block FIFO counters
#define MSCCL_MAX_COUNT 72 // max concurrent number of msccl chunk transmission
#define MSCCL_MAX_REDUCE_FUSION 16 // not referenced in this header; presumably caps fused reductions -- confirm at use sites
#define MSCCL_MAX_NUM_ALGOS 1024 // not referenced in this header; presumably caps the number of loaded algorithms -- confirm

// Fractions of NCCL_STEPS, which is defined outside this header (presumably via devcomm.h).
// NOTE(review): mscclAlgo and mscclStatus carry their own chunkSteps/sliceSteps fields;
// these look like the default values for them -- confirm at use sites.
#define MSCCL_SLICESTEPS (NCCL_STEPS/4)
#define MSCCL_CHUNKSTEPS (NCCL_STEPS/2)

// Buffer selectors. mscclTransmission::srcBuffer/dstBuffer are documented as
// "input/output/scratch", so these are presumably the values stored there;
// each fits in those 4-bit bit-fields.
#define MSCCL_INPUT_BUFFER 0
#define MSCCL_OUTPUT_BUFFER 1
#define MSCCL_SCRATCH_BUFFER 2

// Operation kinds, numbered 0..7. Presumably the values stored in mscclTransmission::type.
// mscclAlgo::typeMask is documented as a bit mask of used types with a maximum of 8,
// which matches the eight values defined here.
#define MSCCL_SEND 0
#define MSCCL_RECV 1
#define MSCCL_RECV_COPY_SEND 2
#define MSCCL_RECV_REDUCE_SEND 3
#define MSCCL_RECV_REDUCE_COPY 4
#define MSCCL_RECV_REDUCE_COPY_SEND 5
#define MSCCL_LOCAL_COPY 6
#define MSCCL_REDUCE 7

// Per-step descriptor: mscclThreadBlock stores one of these for every step.
// The members are ordered so that the whole struct packs into 16 bytes
// (six int16_t, one byte of bit-fields, three single-byte fields); the
// static_assert after the struct keeps that size from drifting silently.
struct mscclTransmission {
  int16_t dependencePointer; // index to the first dependence
  int16_t numDependencies; // dependencePointer+numDependencies indicate the last dependence
  int16_t reductionPointer; // where the reduction starts
  int16_t numReductions; // number of reductions with the same dst
  int16_t srcOffset; // offset within srcBuffer (units not shown here; presumably msccl chunks -- confirm)
  int16_t dstOffset; // offset within dstBuffer (same caveat as srcOffset)
  uint8_t srcBuffer : 4; // input/output/scratch
  uint8_t dstBuffer : 4; // input/output/scratch
  int8_t hasDependence; // presumably non-zero when this step has dependencies -- confirm
  uint8_t type; // operation kind; presumably one of the MSCCL_SEND..MSCCL_REDUCE values -- confirm
  uint8_t count; // transmission size; must be able to hold values up to MSCCL_MAX_COUNT
}; // 16 bytes

static_assert(sizeof(struct mscclTransmission) == 16, "mscclTransmission is expected to be exactly 16 bytes");

// The largest value the type of mscclTransmission::count can hold must strictly exceed MSCCL_MAX_COUNT.
static_assert((1ULL << (8*sizeof(mscclTransmission::count))) - 1 > MSCCL_MAX_COUNT, "MSCCL_MAX_COUNT must be representable by datatype of count");

// Program of a single thread block: fixed-size per-step arrays plus the peers and channel it uses.
// Byte counts in the comments below assume MSCCL_MAX_NUM_STEPS == 64 and a 16-byte mscclTransmission.
struct mscclThreadBlock {
  // step is used to index into these arrays
  struct mscclTransmission transmissions[MSCCL_MAX_NUM_STEPS]; // 1 KB (64 * 16 bytes)
  int8_t dependentBid[MSCCL_MAX_NUM_STEPS]; // -1 if not dependent on any thread block, 64 bytes
  int16_t dependentStep[MSCCL_MAX_NUM_STEPS]; // 128 bytes
  int16_t reductionSrcOffsets[MSCCL_MAX_NUM_STEPS]; // 128 bytes
  int16_t sendPeer;
  int16_t recvPeer;
  uint16_t nSteps; // presumably the number of valid entries in the per-step arrays -- confirm
  int16_t channelId; // associated channel. -1 indicates a thread block with only local copies
}; // 1352 bytes (1024 + 64 + 128 + 128 + 8)

// The struct size must be a whole number of 8-byte words (1352 == 169 * 8 with the current constants).
static_assert(sizeof(struct mscclThreadBlock) % sizeof(uint64_t) == 0, "Sanity check: sizeof(struct mscclThreadBlock) % sizeof(uint64_t) != 0");

// A 64-bit flag padded out to 32 bytes so that consecutive flags in an array are spaced apart.
struct mscclFlag {
  uint64_t flag;
  uint64_t align[3]; // padding only (3 * 8 bytes), to avoid false sharing
};

// Per-peer transmission statistics for one channel, bucketed by transmission count.
// Arrays have MSCCL_MAX_COUNT + 1 entries so that a count equal to MSCCL_MAX_COUNT is a valid index.
struct mscclChannelPeerInfo {
  int peer; // peer identifier (presumably a rank -- confirm)
  // nTransmissionsOfCount[i]: number of transmissions with count i (in terms of msccl chunks)
  int nTransmissionsOfCount[MSCCL_MAX_COUNT + 1];
  // NOTE(review): presumably the distinct count values i that actually occur
  // (nTransmissionsOfCount[i] > 0), stored compactly -- confirm against the code that fills it.
  int existingCounts[MSCCL_MAX_COUNT + 1];
  int nExistingCounts; // presumably the number of valid entries in existingCounts -- confirm
};

// Send-side and receive-side peer tables of one channel.
// Each table has room for MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL peers.
struct mscclChannelInfo {
  struct mscclChannelPeerInfo sendPeerInfo[MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
  int nSendPeers; // presumably the number of used entries in sendPeerInfo -- confirm
  struct mscclChannelPeerInfo recvPeerInfo[MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
  int nRecvPeers; // presumably the number of used entries in recvPeerInfo -- confirm
};

// Description of an algorithm file: where it lives and the selection criteria
// (function, rank count, message-size window, in-place/out-of-place support).
// Most fields have a same-named counterpart in mscclAlgo.
struct mscclAlgoMeta {
  // Path to algorithm file
  std::string filePath;
  // number of chunks of input/output in each MSCCL algorithm loop
  int nChunksPerLoop;
  // number of ranks required by this algorithm
  int nRanks;
  // size multiplier; the size needs to be multiplied by nRanks for all-gather, reduce-scatter and all-to-all
  int sizeMultiplier;
  // MSCCL function type
  mscclFunc_t func;
  // Min message size allowed for this algorithm.
  int64_t minBytes;
  // Max message size allowed for this algorithm, 0 for no limit.
  int64_t maxBytes;
  // Whether this algorithm is suitable for in-place.
  bool inPlace;
  // Whether this algorithm is suitable for out-of-place.
  bool outOfPlace;
};

// A fully described MSCCL algorithm: global parameters, the per-thread-block
// programs (mscclTBs) and the per-channel peer tables (mscclChannels).
// The struct embeds MSCCL_MAX_NUM_THREAD_BLOCKS thread-block programs and
// MAXCHANNELS channel tables by value, so it is large; mscclStatus refers to
// instances through pointers.
struct mscclAlgo {
  // number of chunks of input/output in each MSCCL algorithm loop
  int nChunksPerLoop;
  // the protocol that the algorithm needs to use
  int protocol;
  // number of channels needed by MSCCL algorithm
  int nChannels;
  // number of ranks required by this algorithm
  int nRanks;
  // number of necessary thread blocks
  int nBlocks;
  // number of scratch chunks that MSCCL will use
  int nScratchChunks;
  // size multiplier; the size needs to be multiplied by nRanks for all-gather, reduce-scatter and all-to-all
  int sizeMultiplier;
  // number of steps per chunk for this algorithm
  int chunkSteps;
  // number of steps per slice for this algorithm
  int sliceSteps;
  // bid is used as an index into this array
  struct mscclThreadBlock mscclTBs[MSCCL_MAX_NUM_THREAD_BLOCKS];
  // used to calculate proxy info (MAXCHANNELS is defined outside this header)
  struct mscclChannelInfo mscclChannels[MAXCHANNELS];
  // Whether the algorithm requires reduce operation
  bool hasReduce;
  // MSCCL function type
  mscclFunc_t func;
  // Min message size allowed for this algorithm.
  int64_t minBytes;
  // Max message size allowed for this algorithm, 0 for no limit.
  int64_t maxBytes;
  // Whether this algorithm is suitable for in-place.
  bool inPlace;
  // Whether this algorithm is suitable for out-of-place.
  bool outOfPlace;
  // Keep a bit mask of used types (max 8 at present)
  // NOTE(review): presumably bit i corresponds to operation kind i (MSCCL_SEND..MSCCL_REDUCE) -- confirm.
  uint8_t typeMask;
};

// Group state tracked per thread (see mscclThreadLocalStatus::groupStatus).
// NOTE(review): the meanings below are read off the enumerator names only -- confirm at use sites.
enum mscclGroupStatus {
  mscclNoGroup,            // presumably: not inside a group call
  mscclGroupSupportedOp,   // presumably: inside a group whose operations MSCCL can handle
  mscclGroupUnsupportedOp  // presumably: inside a group containing an operation MSCCL cannot handle
};

// A scheduler parameter set saved together with the communicator and stream it belongs to
// (mscclThreadLocalStatus keeps a vector of these).
// NOTE(review): the four vectors presumably hold owned copies of the send/receive counts and
// displacements that `p` refers to, so they outlive the original call -- confirm against the
// code that fills them.
struct mscclSavedSchedulerParam {
  struct mscclSchedulerParam p; // declared outside this header (presumably msccl/msccl_scheduler.h)
  std::vector<size_t> savedSendCounts;
  std::vector<size_t> savedSDisPls; // presumably send displacements
  std::vector<size_t> savedRecvCounts;
  std::vector<size_t> savedRDisPls; // presumably receive displacements
  ncclComm_t comm;
  hipStream_t stream;
};

// Capture state tracked per thread (see mscclThreadLocalStatus::captureStatus).
// NOTE(review): the meanings below are read off the enumerator names only -- confirm at use sites.
enum mscclCaptureStatus {
  mscclNoCapture,       // presumably: the stream is not being captured into a graph
  mscclNewCapture,      // presumably: a capture that has not been seen before
  mscclExistingCapture  // presumably: a capture that is already being tracked
};

// Pairs a host-side algorithm pointer with the communicator it is used on.
// Both members are plain handles: the struct has no destructor and releases nothing.
struct mscclProxyArg {
  struct mscclAlgo* hostAlgo;
  ncclComm_t comm;

  // Stores the two handles exactly as given.
  mscclProxyArg(struct mscclAlgo* algo, ncclComm_t communicator)
    : hostAlgo{algo}, comm{communicator} {}
};

// Lists of proxy arguments keyed by an unsigned 64-bit id
// (presumably the graph capture id, cf. mscclThreadLocalStatus::captureId -- confirm).
using mscclSavedProxyArgs = std::map<unsigned long long, std::vector<struct mscclProxyArg>>;

// MSCCL state kept per thread (per the struct name): caller flag, group tracking,
// saved scheduler parameters and graph-capture tracking.
struct mscclThreadLocalStatus {
  bool mscclIsCallerFlag; // NOTE(review): presumably set while MSCCL itself is the caller of the API -- confirm
  mscclGroupStatus groupStatus;
  int groupDepth; // presumably the nesting depth of group calls -- confirm
  std::vector<struct mscclSavedSchedulerParam> savedSchedulerParams;
  unsigned long long captureId; // same type as the key of the mscclSaved* maps in this header
  mscclCaptureStatus captureStatus;
  hipGraph_t graph;
};

// Bookkeeping for one work FIFO: its storage, the completion array and the sent/acknowledged counters.
// mscclStatus holds one default instance plus a map of further instances keyed by an unsigned 64-bit id.
struct mscclWorkFifoStatus {
  uint64_t workFifoDepth; // presumably the number of mscclWork slots in workFifo -- confirm
  // mscclWork is defined further down in this header; unless an included header already
  // declares it, the elaborated `struct` specifier here also acts as its forward
  // declaration, so keep the keyword.
  struct mscclWork* workFifo;
  uint32_t* workFifoDone; // mscclWork carries a pointer of the same name and type
  uint32_t workFifoSent;
  uint32_t workFifoSentPerThreadBlock[MSCCL_MAX_NUM_THREAD_BLOCKS]; // one counter per thread block
  uint32_t workFifoAckdMin; // presumably the smallest acknowledged index across thread blocks -- confirm
};

// Work-FIFO bookkeeping keyed by an unsigned 64-bit id
// (presumably the graph capture id, cf. mscclThreadLocalStatus::captureId -- confirm).
using mscclSavedGraphWorkFifoStatus = std::map<unsigned long long, mscclWorkFifoStatus>;

// Aggregate MSCCL runtime state: loaded algorithms, shared buffers, parameters of the
// current operation, scheduler plug-in handles and work-FIFO bookkeeping.
// NOTE(review): the scope of an instance (per process / per communicator) is not visible
// in this header -- confirm where it is instantiated.
struct mscclStatus {
  std::vector<mscclAlgoHandle_t> freeAlgoHandles; // presumably handles available for (re)use -- confirm
  std::map<mscclAlgoHandle_t, mscclAlgo *> hostAlgos; // handle -> algorithm (host copy, per the name)
  std::map<mscclAlgoHandle_t, mscclAlgo *> devAlgos; // handle -> algorithm (device copy, per the name)
  struct mscclFlag* syncFlags; // mscclWork carries a pointer of the same name
  void *scratchBuffer; // mscclWork carries a pointer of the same name
  uint64_t scratchBufferSize;
  size_t nBytes;
  int stepSize;
  int chunkSteps;
  int sliceSteps;
  int chunkSize;
  int chunkEffectiveSize;
  uint32_t workIndex; // mscclWork carries a field of the same name
  uint32_t maxAllowedCount; // mscclWork carries a field of the same name
  ncclDataType_t dataType;
  std::map<ncclComm_t, std::set<mscclAlgoHandle_t>> connectedAlgos; // per communicator: set of algorithm handles
  hipStream_t lastStream;
  void* mscclSchedulerLib; // opaque handle (presumably a dynamically loaded scheduler library -- confirm)
  mscclSchedulerInterface* mscclSchedulerPtr;
  std::vector<mscclAlgoMeta> algoMetas;
  // NOTE(review): presumably parallel to algoMetas, each map going from rank to the handle
  // of the algorithm loaded for that rank -- confirm.
  std::vector<std::map<int, mscclAlgoHandle_t>> rankToAlgoHandles;
  bool graphEnabled;
  bool graphFirstKernel;
  bool needsProxy;
  mscclWorkFifoStatus defaultWorkFifoStatus; // FIFO state used outside of the keyed map below
  mscclSavedGraphWorkFifoStatus graphWorkFifoStatus; // FIFO state keyed by an unsigned 64-bit id
};

// Save the current packing and cap member alignment at 8 bytes until the matching pop.
#pragma pack(push, 8)

// One work descriptor. Several members have same-named counterparts in
// mscclStatus (syncFlags, scratchBuffer, workIndex, maxAllowedCount) and in
// mscclWorkFifoStatus (workFifoDone).
// Members are ordered widest-first: pointers and 64-bit values, then 32-bit
// values, then the two bools, with fnIndex closing the struct.
struct mscclWork {
  // --- pointer-sized and 64-bit members ---
  volatile struct mscclFlag *syncFlags;
  void *scratchBuffer;
  const void *sendBuff;
  void *recvBuff;
  uint32_t* workFifoDone;
  size_t sizePerMscclChunk;
  uint64_t redOpArg;
  // --- 32-bit members ---
  uint32_t workIndex;
  uint32_t maxAllowedCount;
  uint32_t workFifoDoneAck;
  int nChunksPerLoop;
  // --- flags ---
  bool hasReduce;
  bool redOpArgIsPtr;
  // --- trailing 32-bit member ---
  uint32_t fnIndex;
};
// What is actually verified is that the size is a multiple of 16 bytes.
static_assert(sizeof(mscclWork) % 16 == 0, "mscclWork needs to be 16B aligned");

// Restore the packing that was in effect before the push.
#pragma pack(pop)

// A thread-block program followed by a work descriptor, stored back to back.
// NOTE(review): the name suggests this is the layout placed in GPU shared memory -- confirm at the use site.
struct mscclShmemData {
  struct mscclThreadBlock mscclTB;
  // mscclWork is declared under #pragma pack(8), so its natural alignment is at most 8;
  // alignas(16) forces this member onto a 16-byte boundary regardless of sizeof(mscclTB).
  alignas(16) struct mscclWork work;
};
// Guards the member offset itself, which depends on the size of mscclThreadBlock and on the alignas above.
static_assert(offsetof(struct mscclShmemData, work) % 16 == 0, "mscclShmemData.work needs to be 16B aligned");

#endif