/**
 * This is a standalone test for custom allreduce.
 * To compile, make sure you have MPI and NCCL installed in your system:
 *   export MPI_HOME=XXX
 *   nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o \
 *     custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
 *
 * Warning: this C++ test is not designed to be very readable and was used
 * during the rapid prototyping process.
 *
 * To run:
 *   mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test
 */
 #include <cuda.h>
 #include <curand_kernel.h>
 #include <stdio.h>
 #include <stdlib.h>
 
 #include <cmath>
 #include <iostream>
 #include <limits>
 #include <string>
 #include <type_traits>
 #include <vector>
 
 #include "cuda_profiler_api.h"
 #include "mpi.h"
 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
 typedef __hip_bfloat16 nv_bfloat16;
   #include "rccl/rccl.h"
   #include "custom_all_reduce_hip.cuh"
 #else
   #include "nccl.h"
   #include "custom_all_reduce.cuh"
 #endif
 
 // Evaluates an MPI call once. If it does not return MPI_SUCCESS, prints the
 // file, line and raw MPI error code, then exits with EXIT_FAILURE.
 // The local has a macro-specific name and `cmd` is parenthesized, so a caller
 // expression that mentions a variable (e.g. `e`) cannot be captured by, or
 // mis-parsed against, the macro's own declaration.
 #define MPICHECK(cmd)                                                  \
   do {                                                                 \
     int mpicheck_err_ = (cmd);                                         \
     if (mpicheck_err_ != MPI_SUCCESS) {                                \
       printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__,     \
              mpicheck_err_);                                           \
       exit(EXIT_FAILURE);                                              \
     }                                                                  \
   } while (0)
 
 // Evaluates an NCCL call once. If it does not return ncclSuccess, prints the
 // file, line and ncclGetErrorString() text, then exits with EXIT_FAILURE.
 // The local has a macro-specific name and `cmd` is parenthesized so the macro
 // cannot capture a caller variable named like its local.
 #define NCCLCHECK(cmd)                                              \
   do {                                                              \
     ncclResult_t ncclcheck_res_ = (cmd);                            \
     if (ncclcheck_res_ != ncclSuccess) {                            \
       printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
              ncclGetErrorString(ncclcheck_res_));                   \
       exit(EXIT_FAILURE);                                           \
     }                                                               \
   } while (0)
 
 // Single-purpose delay kernel: occupies the stream for a while without
 // touching memory. Intended for a <<<1, 1>>> launch.
 __global__ void dummy_kernel() {
 #ifdef USE_ROCM
   // 100 busy-waits of 100 wall_clock64() ticks each. The tick rate is
   // device dependent, so the total delay here is not calibrated.
   for (int rep = 0; rep < 100; ++rep) {
     const uint64_t t0 = wall_clock64();
     while (wall_clock64() - t0 < 100) {
     }
   }
 #else
   // 100 sleeps of ~1ms each (a single __nanosleep is capped at roughly 1ms;
   // requires SM70+), i.e. about 100ms in total.
   int rep = 0;
   while (rep < 100) {
     __nanosleep(1000000);
     ++rep;
   }
 #endif
 }
 
 // Fills data[0, size) with the value myRank * 0.11f converted to T.
 // Grid-stride loop, so any launch configuration covers every element.
 template <typename T>
 __global__ void set_data(T* data, int size, int myRank) {
   const float fill = myRank * 0.11f;
   const unsigned stride = gridDim.x * blockDim.x;
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   while (idx < size) {
     data[idx] = fill;
     idx += stride;
   }
 }
 
 // Widens two length-`size` arrays of T element-wise into double arrays
 // (data1 -> fdata1, data2 -> fdata2). Grid-stride loop over the elements.
 template <typename T>
 __global__ void convert_data(const T* data1, const T* data2, double* fdata1,
                              double* fdata2, int size) {
   const unsigned stride = gridDim.x * blockDim.x;
   int i = blockIdx.x * blockDim.x + threadIdx.x;
   while (i < size) {
     fdata1[i] = data1[i];
     fdata2[i] = data2[i];
     i += stride;
   }
 }
 
 // Initializes size * nRanks curand states, laid out as state[idx * nRanks + r].
 // State (idx, r) uses seed r + 1 and subsequence idx. The seeds do not depend
 // on the calling process, so every rank builds the same set of states.
 __global__ void init_rand(curandState_t* state, int size, int nRanks) {
   const unsigned stride = gridDim.x * blockDim.x;
   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
        idx += stride) {
     curandState_t* row = &state[idx * nRanks];
     for (int r = 0; r < nRanks; ++r) {
       curand_init(r + 1, idx, 0, row + r);
     }
   }
 }
 
 // For each element idx, draws one uniform sample in (0, 4] per rank from
 // state[idx * nRanks + r]. It stores this rank's sample into data[idx] and the
 // sum over all ranks into ground_truth[idx]. Each sample is rounded to T
 // before summing, so ground_truth is the exact sum of what the ranks feed in.
 template <typename T>
 __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
                          int myRank, int nRanks, int size) {
   const unsigned stride = gridDim.x * blockDim.x;
   for (int elem = blockIdx.x * blockDim.x + threadIdx.x; elem < size;
        elem += stride) {
     curandState_t* elem_states = state + elem * nRanks;
     double total = 0.0;
     for (int r = 0; r < nRanks; ++r) {
       const double sample = curand_uniform_double(&elem_states[r]) * 4;
       const T rounded = sample;  // downcast before accumulating
       total += static_cast<double>(rounded);
       if (r == myRank) {
         data[elem] = rounded;
       }
     }
     ground_truth[elem] = total;
   }
 }
 
 /**
  * Prints the first index at which the NCCL and custom allreduce results
  * differ by at least `tol`. A NaN difference also counts as a mismatch.
  *
  * All three arrays must hold `size` host-readable doubles.
  * Returns true if a mismatch was reported.
  */
 static bool report_first_mismatch(int myRank, const double* nccl_result,
                                   const double* my_result,
                                   const double* ground_truth, int size,
                                   double tol) {
   for (int j = 0; j < size; j++) {
     const double diff = std::fabs(nccl_result[j] - my_result[j]);
     // Written as !(diff < tol) rather than diff >= tol so that NaNs fail.
     if (!(diff < tol)) {
       printf("Rank %d: Verification mismatch at %d: %f != (my) %f, gt=%f\n",
              myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
       return true;
     }
   }
   return false;
 }
 
 /**
  * Runs one custom-allreduce test case on `data_size` elements of type T.
  *
  * Each rank generates deterministic random input (init_rand/gen_data) and
  * registers it with vllm::CustomAllreduce through CUDA IPC. It then does one
  * of two things:
  *  - performance_test == true: times num_iters NCCL allreduces and num_iters
  *    custom allreduces, prints both on rank 0, then checks the custom result
  *    against an NCCL reference once.
  *  - performance_test == false: runs 100 custom allreduces, checks each one
  *    against an NCCL reference, and reports pass/fail (over all ranks) on
  *    rank 0.
  *
  * `threads` and `block_limit` are forwarded to CustomAllreduce::allreduce.
  * Issues MPI and NCCL collectives, so every rank must call it with the same
  * arguments. At most kMaxRanks (8) ranks are supported.
  */
 template <typename T>
 void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
          int data_size, bool performance_test) {
   // Capacity of the fixed IPC-handle array below; more ranks would overflow
   // it during MPI_Allgather.
   constexpr int kMaxRanks = 8;
   if (nRanks > kMaxRanks) {
     fprintf(stderr, "Rank %d: at most %d ranks are supported, got %d\n",
             myRank, kMaxRanks, nRanks);
     MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
   }
   // Maximum allowed per-element |nccl - custom| difference.
   constexpr double kTolerance = 4e-2;
   const size_t data_bytes = static_cast<size_t>(data_size) * sizeof(T);
   // IPC buffer layout: [vllm::Signal][temporary: data_bytes][input: data_bytes]
   const size_t input_offset = sizeof(vllm::Signal) + data_bytes;
   const size_t buffer_bytes = input_offset + data_bytes;
 
   T* result;
   cudaStream_t stream;
   CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
   CUDACHECK(cudaMalloc(&result, data_bytes));
   CUDACHECK(cudaMemset(result, 0, data_bytes));
 
   cudaIpcMemHandle_t self_data_handle;
   cudaIpcMemHandle_t data_handles[kMaxRanks];
   vllm::Signal* buffer;
   T* self_data_copy;
   /**
    * Allocate IPC buffer
    *
    * After the Signal header, the first section is a temporary buffer for
    * storing intermediate allreduce results, if a particular algorithm
    * requires it. The second section is for the input to the allreduce. The
    * actual API takes the input pointer as an argument (that is, they can and
    * usually should be allocated separately). But since the input pointers and
    * the temporary buffer all require IPC registration, they are allocated and
    * registered together in the test for convenience.
    */
 #ifdef USE_ROCM
   CUDACHECK(hipExtMallocWithFlags((void**)&buffer, buffer_bytes,
                                   hipDeviceMallocUncached));
 #else
   CUDACHECK(cudaMalloc(&buffer, buffer_bytes));
 #endif
   CUDACHECK(cudaMemset(buffer, 0, buffer_bytes));
   // Private, non-IPC buffer that receives the NCCL reference result.
   CUDACHECK(cudaMalloc(&self_data_copy, data_bytes));
   CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));
 
   MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
                          MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
                          MPI_BYTE, MPI_COMM_WORLD));
 
   void* rank_data;
   size_t rank_data_sz = 16 * 1024 * 1024;
   CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
   std::vector<int64_t> offsets(nRanks, 0);
   vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                            offsets, myRank);
   auto* self_data =
       reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) + input_offset);
   // hack buffer registration: reuse the handles of the combined allocation
   // and point every rank at its input section through an offset.
   {
     std::vector<std::string> handles;
     handles.reserve(nRanks);
     for (int i = 0; i < nRanks; i++) {
       const char* begin = reinterpret_cast<const char*>(&data_handles[i]);
       handles.emplace_back(begin, begin + sizeof(cudaIpcMemHandle_t));
     }
     std::vector<int64_t> input_offsets(nRanks,
                                        static_cast<int64_t>(input_offset));
     fa.register_buffer(handles, input_offsets, self_data);
   }
 
   // ground_truth, nccl_result and my_result are pinned host buffers that the
   // kernels below write directly. This relies on cudaMallocHost memory being
   // device-accessible.
   double* ground_truth;
   CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
   curandState_t* states;
   CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
   init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
   CUDACHECK(cudaGetLastError());
   gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
                                         nRanks, data_size);
   CUDACHECK(cudaGetLastError());
   cudaEvent_t start, stop;
   CUDACHECK(cudaEventCreate(&start));
   CUDACHECK(cudaEventCreate(&stop));
 
   ncclDataType_t ncclDtype;
   if (std::is_same<T, half>::value) {
     ncclDtype = ncclFloat16;
   } else if (std::is_same<T, nv_bfloat16>::value) {
     ncclDtype = ncclBfloat16;
   } else {
     ncclDtype = ncclFloat;
   }
   double *nccl_result, *my_result;
   CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
   CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
   if (performance_test) {
     constexpr int warmup_iters = 5;
     constexpr int num_iters = 100;
 
     // The sleeping kernel presumably lets the host enqueue the timed launches
     // before the GPU reaches them.
     dummy_kernel<<<1, 1, 0, stream>>>();
     CUDACHECK(cudaGetLastError());
     // warmup
     for (int i = 0; i < warmup_iters; i++) {
       NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
                               comm, stream));
     }
     CUDACHECK(cudaEventRecord(start, stream));
     for (int i = 0; i < num_iters; i++) {
       NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
                               comm, stream));
     }
     CUDACHECK(cudaEventRecord(stop, stream));
     CUDACHECK(cudaStreamSynchronize(stream));
     float allreduce_ms = 0;
     CUDACHECK(cudaEventElapsedTime(&allreduce_ms, start, stop));
 
     dummy_kernel<<<1, 1, 0, stream>>>();
     CUDACHECK(cudaGetLastError());
     // warm up
     for (int i = 0; i < warmup_iters; i++) {
       fa.allreduce<T>(stream, self_data, result, data_size, threads,
                       block_limit);
     }
     CUDACHECK(cudaEventRecord(start, stream));
     for (int i = 0; i < num_iters; i++) {
       fa.allreduce<T>(stream, self_data, result, data_size, threads,
                       block_limit);
     }
     CUDACHECK(cudaEventRecord(stop, stream));
     CUDACHECK(cudaStreamSynchronize(stream));
 
     float duration_ms = 0;
     CUDACHECK(cudaEventElapsedTime(&duration_ms, start, stop));
     if (myRank == 0)
       printf(
           "Rank %d done, nGPUs:%d, sz (kb): %zu, %d, %d, my time:%.2fus, nccl "
           "time:%.2fus\n",
           myRank, nRanks, data_bytes / 1024, threads, block_limit,
           duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
 
     // NCCL reference: reduce the unchanged input into the private copy rather
     // than into the IPC-registered self_data that peer ranks read from.
     NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
                             ncclSum, comm, stream));
     convert_data<T><<<108, 1024, 0, stream>>>(self_data_copy, result,
                                               nccl_result, my_result,
                                               data_size);
     CUDACHECK(cudaGetLastError());
     CUDACHECK(cudaStreamSynchronize(stream));
 
     report_first_mismatch(myRank, nccl_result, my_result, ground_truth,
                           data_size, kTolerance);
     long double nccl_diffs = 0.0;
     long double my_diffs = 0.0;
     for (int j = 0; j < data_size; j++) {
       nccl_diffs += std::fabs(nccl_result[j] - ground_truth[j]);
       my_diffs += std::fabs(my_result[j] - ground_truth[j]);
     }
     if (myRank == 0)
       std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
                 << " me: " << my_diffs / data_size << std::endl;
   } else {
     int local_failed = 0;
     for (int i = 0; i < 100; i++) {
       fa.allreduce<T>(stream, self_data, result, data_size, threads,
                       block_limit);
       CUDACHECK(cudaStreamSynchronize(stream));
       NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
                               ncclSum, comm, stream));
       convert_data<T><<<108, 1024, 0, stream>>>(
           self_data_copy, result, nccl_result, my_result, data_size);
       CUDACHECK(cudaGetLastError());
       CUDACHECK(cudaStreamSynchronize(stream));
 
       if (report_first_mismatch(myRank, nccl_result, my_result, ground_truth,
                                 data_size, kTolerance)) {
         local_failed = 1;
       }
     }
     // Combine the per-rank outcome so rank 0 reports success only when no
     // rank saw a mismatch.
     int any_failed = 0;
     MPICHECK(MPI_Allreduce(&local_failed, &any_failed, 1, MPI_INT, MPI_MAX,
                            MPI_COMM_WORLD));
     if (myRank == 0)
       printf("Test %s: nGPUs:%d, sz (kb): %zu, %d, %d\n",
              any_failed ? "failed" : "passed", nRanks, data_bytes / 1024,
              threads, block_limit);
   }
 
   CUDACHECK(cudaEventDestroy(start));
   CUDACHECK(cudaEventDestroy(stop));
   CUDACHECK(cudaFree(result));
   CUDACHECK(cudaFree(self_data_copy));
   CUDACHECK(cudaFree(rank_data));
   CUDACHECK(cudaFree(buffer));
   CUDACHECK(cudaFree(states));
   CUDACHECK(cudaFreeHost(ground_truth));
   CUDACHECK(cudaFreeHost(nccl_result));
   CUDACHECK(cudaFreeHost(my_result));
   CUDACHECK(cudaStreamDestroy(stream));
 }
 
 // Test driver: one MPI rank per GPU. Sets up an NCCL communicator and runs the
 // half-precision custom allreduce test for sizes from 512 to 8M elements
 // (plus a fixed 8 * 47 extra elements), doubling each step.
 int main(int argc, char** argv) {
   int nRanks, myRank;
   MPICHECK(MPI_Init(&argc, &argv));
   MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
   MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
   // Uses the MPI rank directly as the device index, so all ranks are assumed
   // to run on a single node.
   CUDACHECK(cudaSetDevice(myRank));
   ncclUniqueId id;
   ncclComm_t comm;
   if (myRank == 0) NCCLCHECK(ncclGetUniqueId(&id));
   MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
                      MPI_COMM_WORLD));
   NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
 
   bool performance_test = true;
 #ifdef USE_ROCM
   constexpr int block_limit = 16;
 #else
   constexpr int block_limit = 36;
 #endif
   cudaProfilerStart();
   // for (int threads : {256, 512}) {
   //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
   //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
   //   }
   // }
   for (int sz = 512; sz <= (8 << 20); sz *= 2) {
     // The extra 8 * 47 elements presumably exercise non-power-of-two sizes.
     run<half>(myRank, nRanks, comm, 512, block_limit, sz + 8 * 47,
               performance_test);
   }
 
   cudaProfilerStop();
   NCCLCHECK(ncclCommDestroy(comm));
   MPICHECK(MPI_Finalize());
   return EXIT_SUCCESS;
 }