// custom_all_reduce_test.cu
/**
 * This is a standalone test for custom allreduce.
 * To compile, make sure you have MPI and NCCL installed on your system.
 * export MPI_HOME=xxx
 * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
 * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
 *
 * Warning: this C++ test is not designed to be very readable and was used
 * during the rapid prototyping process.
 *
 * To run:
 * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test
 */
#include <cuda.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>

#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "nccl.h"

// Evaluates the MPI call `cmd` exactly once. If the returned status is not
// MPI_SUCCESS, prints the source location together with the numeric error
// code and terminates the process with EXIT_FAILURE.
#define MPICHECK(cmd)                                                   \
  do {                                                                  \
    const int mpi_check_status_ = (cmd);                                \
    if (mpi_check_status_ != MPI_SUCCESS) {                             \
      printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__,      \
             mpi_check_status_);                                        \
      exit(EXIT_FAILURE);                                               \
    }                                                                   \
  } while (0)

// Evaluates the NCCL call `cmd` exactly once. If the result is not
// ncclSuccess, prints the source location together with the text returned by
// ncclGetErrorString() and terminates the process with EXIT_FAILURE.
#define NCCLCHECK(cmd)                                                \
  do {                                                                \
    const ncclResult_t nccl_check_status_ = (cmd);                    \
    if (nccl_check_status_ != ncclSuccess) {                          \
      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__,   \
             ncclGetErrorString(nccl_check_status_));                 \
      exit(EXIT_FAILURE);                                             \
    }                                                                 \
  } while (0)

/**
 * Delay kernel: does no useful work, it only occupies the stream it is
 * launched on for a while. run() launches it with a single thread right
 * before each timed section (presumably so the host can enqueue the timed
 * work while the device is still busy -- confirm).
 */
__global__ void dummy_kernel() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  // sm_70 and newer: 100 sleeps of 1,000,000 ns each.
  for (int nap = 0; nap < 100; ++nap) {
    __nanosleep(1000000);  // 100ms
  }
#else
  // Older architectures have no __nanosleep, so spin on the cycle counter.
  // NOTE(review): the inner wait is annotated as ~98.4 ms by itself and it is
  // repeated 100 times, which looks ~100x longer than the branch above --
  // confirm the intended total delay.
  for (int round = 0; round < 100; ++round) {
    const long long int t0 = clock64();
    while (clock64() - t0 < 150000000) {
      // approximately 98.4ms on P40
    }
  }
#endif
}

/**
 * Fills data[0, size) with the constant myRank * 0.11f (converted to T).
 *
 * Each thread starts at its global thread index and advances by the total
 * number of launched threads, so any 1D launch configuration covers the
 * whole array.
 *
 * @param data    output array with at least `size` elements
 * @param size    number of elements to write
 * @param myRank  rank whose scaled value is stored in every element
 */
template <typename T>
__global__ void set_data(T* data, int size, int myRank) {
  const int total_threads = gridDim.x * blockDim.x;
  int pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < size) {
    data[pos] = myRank * 0.11f;
    pos += total_threads;
  }
}

/**
 * Widens two arrays of T to double, element by element:
 * fdata1[i] = data1[i] and fdata2[i] = data2[i] for i in [0, size).
 *
 * Each thread starts at its global thread index and advances by the total
 * number of launched threads.
 *
 * @param data1   first input array (`size` elements)
 * @param data2   second input array (`size` elements)
 * @param fdata1  receives the values of data1 as double
 * @param fdata2  receives the values of data2 as double
 * @param size    number of elements to convert
 */
template <typename T>
__global__ void convert_data(const T* data1, const T* data2, double* fdata1,
                             double* fdata2, int size) {
  const int total_threads = gridDim.x * blockDim.x;
  int pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < size) {
    fdata1[pos] = data1[pos];
    fdata2[pos] = data2[pos];
    pos += total_threads;
  }
}

/**
 * Initializes one cuRAND state per (element, rank) pair.
 *
 * The state for element `e` and rank `r` lives at state[e * nRanks + r] and
 * is initialized with curand_init(r + 1, e, 0, ...). Nothing here depends on
 * the calling process, so every process that runs this kernel with the same
 * arguments builds the same table of states.
 *
 * @param state   array of at least size * nRanks states
 * @param size    number of elements
 * @param nRanks  number of ranks to create states for
 */
__global__ void init_rand(curandState_t* state, int size, int nRanks) {
  const int total_threads = gridDim.x * blockDim.x;
  for (int elem = blockIdx.x * blockDim.x + threadIdx.x; elem < size;
       elem += total_threads) {
    curandState_t* elem_states = state + elem * nRanks;
    for (int rank = 0; rank < nRanks; ++rank) {
      curand_init(rank + 1, elem, 0, &elem_states[rank]);
    }
  }
}

/**
 * Generates this rank's input and the expected allreduce sum.
 *
 * For each element, one uniform sample scaled by 4 is drawn from the state
 * of every rank (laid out as state[elem * nRanks + rank]). Each sample is
 * first converted to T and then accumulated in double, so the reference sum
 * is built from the values as they are stored in T. The sample belonging to
 * `myRank` is written to `data`.
 *
 * @param state         states prepared by init_rand (size * nRanks entries)
 * @param data          receives this rank's input values (`size` elements)
 * @param ground_truth  receives the per-element sum over all ranks; run()
 *                      passes memory obtained from cudaMallocHost here
 * @param myRank        rank whose samples are stored in `data`
 * @param nRanks        number of ranks contributing to each sum
 * @param size          number of elements
 */
template <typename T>
__global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
                         int myRank, int nRanks, int size) {
  const int total_threads = gridDim.x * blockDim.x;
  for (int elem = blockIdx.x * blockDim.x + threadIdx.x; elem < size;
       elem += total_threads) {
    curandState_t* elem_states = state + elem * nRanks;
    double total = 0.0;
    for (int rank = 0; rank < nRanks; ++rank) {
      const double sample = curand_uniform_double(&elem_states[rank]) * 4;
      T rounded = sample;  // downcast first
      total += static_cast<double>(rounded);
      if (rank == myRank) {
        data[elem] = rounded;
      }
    }
    ground_truth[elem] = total;
  }
}

/**
 * Compares the custom allreduce output with the NCCL output on the host.
 *
 * Reports the first element whose absolute difference is not strictly below
 * the 4e-2 tolerance (a NaN difference therefore counts as a mismatch) and
 * stops scanning there.
 *
 * @return true when every element is within tolerance, false otherwise.
 */
static bool verify_against_nccl(int myRank, const double* nccl_result,
                                const double* my_result,
                                const double* ground_truth, int data_size) {
  for (int j = 0; j < data_size; j++) {
    const double diff = std::fabs(nccl_result[j] - my_result[j]);
    if (!(diff < 4e-2)) {
      printf("Rank %d: Verification mismatch at %d: %f != (my) %f, gt=%f\n",
             myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
      return false;
    }
  }
  return true;
}

/**
 * Runs one allreduce test case for element type T.
 *
 * Allocates and IPC-registers the buffers, generates per-rank input plus a
 * double-precision reference sum, and then either
 *   - (performance_test == true) times NCCL and the custom allreduce,
 *     verifies one custom result against NCCL and prints the average
 *     absolute error of both against the reference, or
 *   - (performance_test == false) runs 100 custom allreduces, verifying each
 *     one against NCCL.
 * All resources created here are released before returning.
 *
 * @param myRank            rank of the calling process
 * @param nRanks            number of participating ranks (1..8; the IPC
 *                          handle table below has room for 8 entries)
 * @param comm              initialized NCCL communicator
 * @param threads           thread count forwarded to CustomAllreduce::allreduce
 * @param block_limit       block limit forwarded to CustomAllreduce::allreduce
 * @param data_size         number of elements of T to reduce
 * @param performance_test  selects the timing path or the repeated
 *                          correctness path
 */
template <typename T>
void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
         int data_size, bool performance_test) {
  // MPI_Allgather below writes nRanks handles into a fixed-size table.
  constexpr int kMaxRanks = 8;
  if (nRanks < 1 || nRanks > kMaxRanks) {
    fprintf(stderr, "run: nRanks=%d is outside the supported range [1, %d]\n",
            nRanks, kMaxRanks);
    exit(EXIT_FAILURE);
  }

  const size_t data_bytes = static_cast<size_t>(data_size) * sizeof(T);
  const size_t buffer_bytes = 2 * data_bytes + sizeof(vllm::Signal);
  // Byte offset of the allreduce input inside the IPC buffer.
  const size_t input_offset = sizeof(vllm::Signal) + data_bytes;

  T* result;
  cudaStream_t stream;
  CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  CUDACHECK(cudaMalloc(&result, data_bytes));
  CUDACHECK(cudaMemset(result, 0, data_bytes));

  cudaIpcMemHandle_t self_data_handle;
  cudaIpcMemHandle_t data_handles[kMaxRanks];
  vllm::Signal* buffer;
  T* self_data_copy;
  /**
   * Allocate IPC buffer
   *
   * The first section is a temporary buffer for storing intermediate allreduce
   * results, if a particular algorithm requires it. The second section is for
   * the input to the allreduce. The actual API takes the input pointer as an
   * argument (that is, they can and usually should be allocated separately).
   * But since the input pointers and the temporary buffer all require IPC
   * registration, they are allocated and registered together in the test for
   * convenience.
   */
  CUDACHECK(cudaMalloc(&buffer, buffer_bytes));
  CUDACHECK(cudaMemset(buffer, 0, buffer_bytes));
  CUDACHECK(cudaMalloc(&self_data_copy, data_bytes));
  CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));

  MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, MPI_COMM_WORLD));

  void* rank_data;
  size_t rank_data_sz = 16 * 1024 * 1024;
  CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
  std::vector<int64_t> signal_offsets(nRanks, 0);
  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                           signal_offsets, myRank);
  auto* self_data =
      reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) + input_offset);
  // hack buffer registration
  {
    // register_buffer takes each IPC handle as a raw byte string.
    std::vector<std::string> handles;
    handles.reserve(nRanks);
    for (int i = 0; i < nRanks; i++) {
      const char* begin = reinterpret_cast<const char*>(&data_handles[i]);
      handles.emplace_back(begin, begin + sizeof(cudaIpcMemHandle_t));
    }
    std::vector<int64_t> input_offsets(nRanks,
                                       static_cast<int64_t>(input_offset));
    fa.register_buffer(handles, input_offsets, self_data);
  }

  double* ground_truth;
  CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
  curandState_t* states;
  CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
  init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
  CUDACHECK(cudaGetLastError());
  gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
                                        nRanks, data_size);
  CUDACHECK(cudaGetLastError());
  CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_bytes,
                            cudaMemcpyDeviceToDevice, stream));
  cudaEvent_t start, stop;
  CUDACHECK(cudaEventCreate(&start));
  CUDACHECK(cudaEventCreate(&stop));

  ncclDataType_t ncclDtype;
  if (std::is_same<T, half>::value) {
    ncclDtype = ncclFloat16;
  } else if (std::is_same<T, nv_bfloat16>::value) {
    ncclDtype = ncclBfloat16;
  } else {
    ncclDtype = ncclFloat;
  }
  double *nccl_result, *my_result;
  CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
  CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
  if (performance_test) {
    constexpr int warmup_iters = 5;
    constexpr int num_iters = 100;

    // --- NCCL timing (in place on the zero-filled `result` buffer) ---
    dummy_kernel<<<1, 1, 0, stream>>>();
    CUDACHECK(cudaGetLastError());
    // warmup
    for (int i = 0; i < warmup_iters; i++) {
      NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
                              comm, stream));
    }
    CUDACHECK(cudaEventRecord(start, stream));
    for (int i = 0; i < num_iters; i++) {
      NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
                              comm, stream));
    }
    CUDACHECK(cudaEventRecord(stop, stream));
    CUDACHECK(cudaStreamSynchronize(stream));
    float allreduce_ms = 0;
    CUDACHECK(cudaEventElapsedTime(&allreduce_ms, start, stop));

    // --- custom allreduce timing ---
    dummy_kernel<<<1, 1, 0, stream>>>();
    CUDACHECK(cudaGetLastError());
    // warm up
    for (int i = 0; i < warmup_iters; i++) {
      fa.allreduce<T>(stream, self_data, result, data_size, threads,
                      block_limit);
    }
    CUDACHECK(cudaEventRecord(start, stream));
    for (int i = 0; i < num_iters; i++) {
      fa.allreduce<T>(stream, self_data, result, data_size, threads,
                      block_limit);
    }
    CUDACHECK(cudaEventRecord(stop, stream));
    CUDACHECK(cudaStreamSynchronize(stream));

    float duration_ms = 0;
    CUDACHECK(cudaEventElapsedTime(&duration_ms, start, stop));
    if (myRank == 0)
      printf(
          "Rank %d done, nGPUs:%d, sz (kb): %zu, %d, %d, my time:%.2fus, nccl "
          "time:%.2fus\n",
          myRank, nRanks, data_bytes / 1024, threads, block_limit,
          duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);

    // And wait for all the queued up work to complete
    CUDACHECK(cudaStreamSynchronize(stream));

    // NCCL reference: reduce the saved copy of the input into self_data.
    NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
                            ncclSum, comm, stream));

    convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
                                              my_result, data_size);
    CUDACHECK(cudaGetLastError());
    CUDACHECK(cudaStreamSynchronize(stream));

    // A mismatch is reported by the helper; the error summary below is
    // printed regardless.
    (void)verify_against_nccl(myRank, nccl_result, my_result, ground_truth,
                              data_size);

    long double nccl_diffs = 0.0;
    long double my_diffs = 0.0;
    for (int j = 0; j < data_size; j++) {
      nccl_diffs += std::fabs(nccl_result[j] - ground_truth[j]);
      my_diffs += std::fabs(my_result[j] - ground_truth[j]);
    }
    if (myRank == 0)
      std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
                << " me: " << my_diffs / data_size << std::endl;
  } else {
    // Tracks whether every iteration verified cleanly on this rank.
    bool passed = true;
    for (int i = 0; i < 100; i++) {
      fa.allreduce<T>(stream, self_data, result, data_size, threads,
                      block_limit);
      CUDACHECK(cudaStreamSynchronize(stream));
      NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
                              ncclSum, comm, stream));
      convert_data<T><<<108, 1024, 0, stream>>>(
          self_data_copy, result, nccl_result, my_result, data_size);
      CUDACHECK(cudaGetLastError());
      CUDACHECK(cudaStreamSynchronize(stream));

      if (!verify_against_nccl(myRank, nccl_result, my_result, ground_truth,
                               data_size)) {
        passed = false;
      }
    }
    // Only rank 0 prints the summary, and it reflects rank 0's own
    // comparisons; other ranks report their mismatches individually above.
    if (myRank == 0)
      printf("Test %s: nGPUs:%d, sz (kb): %zu, %d, %d\n",
             passed ? "passed" : "FAILED", nRanks, data_bytes / 1024, threads,
             block_limit);
    // long double nccl_diffs = 0.0;
    // long double my_diffs = 0.0;
    // for (int j = 0; j < data_size; j++) {
    //   nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
    //   my_diffs += abs(my_result[j] - ground_truth[j]);
    // }
    // if (myRank == 0)
    //   std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
    //             << " me: " << my_diffs / data_size << std::endl;
  }

  CUDACHECK(cudaEventDestroy(start));
  CUDACHECK(cudaEventDestroy(stop));
  CUDACHECK(cudaFree(result));
  CUDACHECK(cudaFree(self_data_copy));
  CUDACHECK(cudaFree(rank_data));
  CUDACHECK(cudaFree(buffer));
  CUDACHECK(cudaFree(states));
  CUDACHECK(cudaFreeHost(ground_truth));
  CUDACHECK(cudaFreeHost(nccl_result));
  CUDACHECK(cudaFreeHost(my_result));
  CUDACHECK(cudaStreamDestroy(stream));
}

/**
 * Test driver: one MPI process per GPU.
 *
 * Initializes MPI, selects the CUDA device whose index equals the MPI rank,
 * creates a NCCL communicator across all ranks, and then runs the half
 * precision test over a sweep of sizes before tearing everything down.
 */
int main(int argc, char** argv) {
  int nRanks, myRank;
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
  MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
  CUDACHECK(cudaSetDevice(myRank));

  // Rank 0 creates the NCCL id and broadcasts it to everyone else.
  ncclUniqueId id;
  ncclComm_t comm;
  if (myRank == 0) {
    NCCLCHECK(ncclGetUniqueId(&id));
  }
  MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
                     MPI_COMM_WORLD));
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

  bool performance_test = true;
  CUDACHECK(cudaProfilerStart());
  // Uncomment to scan through different block size configs.
  // for (int threads : {256, 512, 1024}) {
  //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
  //     run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
  //     performance_test);
  //   }
  // }
  // Scan through different sizes to test performance.
  for (int sz = 512; sz <= (8 << 20); sz *= 2) {
    // The extra 8 * 47 elements keep the size from being a power of two.
    run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
  }

  CUDACHECK(cudaProfilerStop());
  // Release the communicator before shutting MPI down.
  NCCLCHECK(ncclCommDestroy(comm));
  MPICHECK(MPI_Finalize());
  return EXIT_SUCCESS;
}