ring_sendrecv.cu 2.91 KB
Newer Older
yuguo's avatar
yuguo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// hipcc ring_sendrecv.cu -o ring_sendrecv -I /opt/mpi/include  -L /opt/mpi/lib/ -lmpi -L /opt/dtk/lib/ -lrccl
// mpirun -np 8 --allow-run-as-root  --oversubscribe --quiet ./ring_sendrecv
#include <cstdio>
#include <cstdlib>
#include <mpi.h>
#include <rccl.h>
#include <hip/hip_runtime.h>

// Number of int elements each rank sends to its successor and receives from
// its predecessor in the ring (also the length of the host/device buffers).
#define MSG_SIZE 4

// Evaluates a HIP runtime call exactly once. If it does not return hipSuccess,
// prints the failing file/line and HIP error string to stderr and tears down
// the WHOLE MPI job. MPI_Abort is used instead of a plain exit() so that the
// other ranks, which may be blocked in an MPI collective or an RCCL call
// waiting for this rank, are terminated too rather than hanging.
// Requires MPI to be initialized (use between MPI_Init and MPI_Finalize).
#define CUDACHECK(cmd) do { \
  hipError_t cudacheck_err_ = (cmd); \
  if (cudacheck_err_ != hipSuccess) { \
    fprintf(stderr, "HIP error %s:%d: '%s'\n", __FILE__, __LINE__, hipGetErrorString(cudacheck_err_)); \
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); \
    exit(EXIT_FAILURE); /* fallback in case MPI_Abort returns */ \
  } \
} while(0)

// Evaluates an RCCL/NCCL call exactly once. If it does not return ncclSuccess,
// prints the failing file/line and NCCL error string to stderr and tears down
// the WHOLE MPI job. MPI_Abort is used instead of a plain exit() because the
// peers of a failed collective/point-to-point operation would otherwise wait
// for this rank forever.
// Requires MPI to be initialized (use between MPI_Init and MPI_Finalize).
#define NCCLCHECK(cmd) do { \
  ncclResult_t ncclcheck_res_ = (cmd); \
  if (ncclcheck_res_ != ncclSuccess) { \
    fprintf(stderr, "NCCL error %s:%d: '%s'\n", __FILE__, __LINE__, ncclGetErrorString(ncclcheck_res_)); \
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); \
    exit(EXIT_FAILURE); /* fallback in case MPI_Abort returns */ \
  } \
} while(0)

int main(int argc, char* argv[]) {
  int rank, world_size;
  ncclUniqueId id;
  ncclComm_t comm;
  hipStream_t stream;

  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &world_size);

  int device_count = 0;
  CUDACHECK(hipGetDeviceCount(&device_count));
  if (world_size > device_count) {
    if (rank == 0)
      printf("Error: More ranks (%d) than available GPUs (%d)\n", world_size, device_count);
    MPI_Finalize();
    return -1;
  }

  CUDACHECK(hipSetDevice(rank));
  printf("Rank %d using device %d\n", rank, rank);

  if (rank == 0)
    NCCLCHECK(ncclGetUniqueId(&id));
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);

  NCCLCHECK(ncclCommInitRank(&comm, world_size, id, rank));
  CUDACHECK(hipStreamCreate(&stream));

  int* d_send = nullptr;
  int* d_recv = nullptr;
  CUDACHECK(hipMalloc(&d_send, MSG_SIZE * sizeof(int)));
  CUDACHECK(hipMalloc(&d_recv, MSG_SIZE * sizeof(int)));
  CUDACHECK(hipMemset(d_recv, 0, MSG_SIZE * sizeof(int)));

  int h_send[MSG_SIZE];
  for (int i = 0; i < MSG_SIZE; ++i)
    h_send[i] = rank * 100 + i;
  CUDACHECK(hipMemcpy(d_send, h_send, MSG_SIZE * sizeof(int), hipMemcpyHostToDevice));

  int next = (rank + 1) % world_size;
  int prev = (rank - 1 + world_size) % world_size;

  printf("Rank %d sending to Rank %d: ", rank, next);
  for (int i = 0; i < MSG_SIZE; ++i) printf("%d ", h_send[i]);
  printf("\n");

  NCCLCHECK(ncclGroupStart());
  NCCLCHECK(ncclSend(d_send, MSG_SIZE, ncclInt, next, comm, stream));
  NCCLCHECK(ncclRecv(d_recv, MSG_SIZE, ncclInt, prev, comm, stream));
  NCCLCHECK(ncclGroupEnd());

  CUDACHECK(hipStreamSynchronize(stream));
  CUDACHECK(hipGetLastError());
  MPI_Barrier(MPI_COMM_WORLD); // 打印

  int h_recv[MSG_SIZE];
  CUDACHECK(hipMemcpy(h_recv, d_recv, MSG_SIZE * sizeof(int), hipMemcpyDeviceToHost));
  printf("Rank %d received from Rank %d: ", rank, prev);
  for (int i = 0; i < MSG_SIZE; ++i) printf("%d ", h_recv[i]);
  printf("\n");

  CUDACHECK(hipFree(d_send));
  CUDACHECK(hipFree(d_recv));
  CUDACHECK(hipStreamDestroy(stream));
  NCCLCHECK(ncclCommDestroy(comm));
  MPI_Finalize();

  return 0;
}