// sendrecv_compute_overlap.cu
// Committed by: yuguo
// hipcc sendrecv_compute_overlap.cu -o sendrecv_compute_overlap -I /opt/mpi/include  -L /opt/mpi/lib/ -lmpi -L /opt/dtk/lib/ -lrccl
// mpirun -np 8 --allow-run-as-root  --oversubscribe --quiet ./sendrecv_compute_overlap
#include "hip/hip_runtime.h"
#include <cstdio>
#include <cstdlib>
#include <mpi.h>
#include <rccl.h>
#include <hip/hip_runtime.h>

// Number of int elements exchanged between ranks: 1M ints = 4 MiB.
// Parenthesized so the macro behaves as a single value inside larger
// expressions (e.g. `bytes / MSG_SIZE`).
#define MSG_SIZE (1024 * 1024)
// Evaluate a HIP runtime call exactly once. On any status other than
// hipSuccess, report file/line and the HIP error string on stderr, then
// terminate the process with EXIT_FAILURE.
// The local has a macro-specific name so it cannot capture a caller variable
// of the same name referenced inside `cmd`, and `cmd` is parenthesized.
#define CUDACHECK(cmd) do { \
  hipError_t cudacheck_err_ = (cmd); \
  if (cudacheck_err_ != hipSuccess) { \
    fprintf(stderr, "CUDA error %s:%d: '%s'\n", __FILE__, __LINE__, \
            hipGetErrorString(cudacheck_err_)); \
    exit(EXIT_FAILURE); \
  } \
} while (0)

// Evaluate an RCCL/NCCL call exactly once. On any result other than
// ncclSuccess, report file/line and the library's error string on stderr,
// then terminate the process with EXIT_FAILURE.
// The local has a macro-specific name so it cannot capture a caller variable
// of the same name referenced inside `cmd`, and `cmd` is parenthesized.
#define NCCLCHECK(cmd) do { \
  ncclResult_t ncclcheck_res_ = (cmd); \
  if (ncclcheck_res_ != ncclSuccess) { \
    fprintf(stderr, "NCCL error %s:%d: '%s'\n", __FILE__, __LINE__, \
            ncclGetErrorString(ncclcheck_res_)); \
    exit(EXIT_FAILURE); \
  } \
} while (0)

// In-place element-wise transform: data[i] = 2 * data[i] + 1 for i < size.
// Launch layout: 1-D grid, one element per thread; the grid must provide at
// least `size` threads in total. Threads whose index is past the end return
// without touching memory. No shared memory is used.
__global__ void compute_kernel(int* data, int size) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= size) return;  // tail threads of the last block

  const int value = data[i];
  data[i] = 2 * value + 1;
}

// Demo of ordering compute and RCCL point-to-point communication across two
// HIP streams with events:
//   rank 0: compute_kernel on stream_comp -> event -> ncclSend on stream_comm
//   rank 1: ncclRecv on stream_comm -> event -> compute_kernel on stream_comp
// On each rank the event dependency serializes compute and communication, so
// this shows cross-stream ordering rather than concurrent overlap.
// Ranks >= 2 join the communicator but do no work. Every rank prints the time
// between its start/stop events; rank 1 also prints its first 5 values.
int main(int argc, char* argv[]) {
  int rank, size;
  ncclUniqueId id;
  ncclComm_t comm;
  hipStream_t stream_comm, stream_comp;

  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  if (size < 2) {
    if (rank == 0) printf("This demo needs at least 2 processes.\n");
    MPI_Finalize();
    return 0;
  }

  // Select the GPU by node-local rank. Using the global rank requests a
  // nonexistent device ordinal as soon as the job spans several nodes; the
  // modulo keeps the ordinal valid if a node runs more ranks than it has GPUs.
  MPI_Comm node_comm;
  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, rank,
                      MPI_INFO_NULL, &node_comm);
  int local_rank;
  MPI_Comm_rank(node_comm, &local_rank);
  MPI_Comm_free(&node_comm);

  int num_devices = 0;
  CUDACHECK(hipGetDeviceCount(&num_devices));
  if (num_devices < 1) {
    fprintf(stderr, "Rank %d: no HIP devices visible\n", rank);
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
  }
  CUDACHECK(hipSetDevice(local_rank % num_devices));

  // Rank 0 creates the communicator id and shares it with everyone over MPI.
  if (rank == 0)
    NCCLCHECK(ncclGetUniqueId(&id));
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
  NCCLCHECK(ncclCommInitRank(&comm, size, id, rank));

  CUDACHECK(hipStreamCreate(&stream_comm));
  CUDACHECK(hipStreamCreate(&stream_comp));

  int* d_buf;
  CUDACHECK(hipMalloc(&d_buf, MSG_SIZE * sizeof(int)));

  // Initialize data: rank 0 fills its buffer with 0, 1, 2, ...; rank 1's
  // buffer is overwritten by the receive.
  if (rank == 0) {
    int* h_data = new int[MSG_SIZE];
    for (int i = 0; i < MSG_SIZE; ++i)
      h_data[i] = i;
    CUDACHECK(hipMemcpy(d_buf, h_data, MSG_SIZE * sizeof(int), hipMemcpyHostToDevice));
    delete[] h_data;
  }

  // start/stop bracket the timed region. comp_done/comm_done only express
  // cross-stream dependencies, so timing is disabled on them.
  hipEvent_t start, stop, comp_done, comm_done;
  CUDACHECK(hipEventCreate(&start));
  CUDACHECK(hipEventCreate(&stop));
  CUDACHECK(hipEventCreateWithFlags(&comp_done, hipEventDisableTiming));
  CUDACHECK(hipEventCreateWithFlags(&comm_done, hipEventDisableTiming));

  CUDACHECK(hipEventRecord(start));

  const int threads = 256;
  const int blocks = (MSG_SIZE + threads - 1) / threads;

  if (rank == 0) {
    // Compute (simulated workload), then send once the kernel has finished.
    compute_kernel<<<blocks, threads, 0, stream_comp>>>(d_buf, MSG_SIZE);
    CUDACHECK(hipGetLastError());  // surface launch-configuration errors
    CUDACHECK(hipEventRecord(comp_done, stream_comp));  // marks compute completion
    CUDACHECK(hipStreamWaitEvent(stream_comm, comp_done, 0));
    NCCLCHECK(ncclSend(d_buf, MSG_SIZE, ncclInt, 1, comm, stream_comm));
  } else if (rank == 1) {
    // Receive, then compute on the received data.
    NCCLCHECK(ncclRecv(d_buf, MSG_SIZE, ncclInt, 0, comm, stream_comm));
    CUDACHECK(hipEventRecord(comm_done, stream_comm));  // marks receive completion
    CUDACHECK(hipStreamWaitEvent(stream_comp, comm_done, 0));
    compute_kernel<<<blocks, threads, 0, stream_comp>>>(d_buf, MSG_SIZE);
    CUDACHECK(hipGetLastError());  // surface launch-configuration errors
  }

  // Wait for both streams before stopping the timer.
  CUDACHECK(hipStreamSynchronize(stream_comm));
  CUDACHECK(hipStreamSynchronize(stream_comp));
  CUDACHECK(hipEventRecord(stop));
  CUDACHECK(hipEventSynchronize(stop));

  float ms;
  CUDACHECK(hipEventElapsedTime(&ms, start, stop));
  printf("Rank %d: Total time = %.2f ms\n", rank, ms);

  // Inspect part of the result. Rank 0 sends 2*i + 1 and rank 1 applies the
  // kernel again, so element i should be 4*i + 3 (3 7 11 15 19 ...).
  if (rank == 1) {
    int h_out[5];
    CUDACHECK(hipMemcpy(h_out, d_buf, sizeof(h_out), hipMemcpyDeviceToHost));
    printf("Rank 1: first 5 received and computed values: ");
    for (int v : h_out) printf("%d ", v);
    printf("\n");
  }

  // Cleanup.
  CUDACHECK(hipEventDestroy(start));
  CUDACHECK(hipEventDestroy(stop));
  CUDACHECK(hipEventDestroy(comp_done));
  CUDACHECK(hipEventDestroy(comm_done));
  CUDACHECK(hipFree(d_buf));
  CUDACHECK(hipStreamDestroy(stream_comm));
  CUDACHECK(hipStreamDestroy(stream_comp));
  NCCLCHECK(ncclCommDestroy(comm));
  MPI_Finalize();

  return 0;
}