// hipcc ring_sendrecv.cu -o ring_sendrecv -I /opt/mpi/include -L /opt/mpi/lib/ -lmpi -L /opt/dtk/lib/ -lrccl // mpirun -np 8 --allow-run-as-root --oversubscribe --quiet ./ring_sendrecv #include #include #include #include #include #define MSG_SIZE 4 #define CUDACHECK(cmd) do { \ hipError_t e = cmd; \ if (e != hipSuccess) { \ printf("HIP error %s:%d: '%s'\n", __FILE__, __LINE__, hipGetErrorString(e)); \ exit(EXIT_FAILURE); \ } \ } while(0) #define NCCLCHECK(cmd) do { \ ncclResult_t r = cmd; \ if (r != ncclSuccess) { \ printf("NCCL error %s:%d: '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \ exit(EXIT_FAILURE); \ } \ } while(0) int main(int argc, char* argv[]) { int rank, world_size; ncclUniqueId id; ncclComm_t comm; hipStream_t stream; MPI_Init(&argc, &argv); MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &world_size); int device_count = 0; CUDACHECK(hipGetDeviceCount(&device_count)); if (world_size > device_count) { if (rank == 0) printf("Error: More ranks (%d) than available GPUs (%d)\n", world_size, device_count); MPI_Finalize(); return -1; } CUDACHECK(hipSetDevice(rank)); printf("Rank %d using device %d\n", rank, rank); if (rank == 0) NCCLCHECK(ncclGetUniqueId(&id)); MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); NCCLCHECK(ncclCommInitRank(&comm, world_size, id, rank)); CUDACHECK(hipStreamCreate(&stream)); int* d_send = nullptr; int* d_recv = nullptr; CUDACHECK(hipMalloc(&d_send, MSG_SIZE * sizeof(int))); CUDACHECK(hipMalloc(&d_recv, MSG_SIZE * sizeof(int))); CUDACHECK(hipMemset(d_recv, 0, MSG_SIZE * sizeof(int))); int h_send[MSG_SIZE]; for (int i = 0; i < MSG_SIZE; ++i) h_send[i] = rank * 100 + i; CUDACHECK(hipMemcpy(d_send, h_send, MSG_SIZE * sizeof(int), hipMemcpyHostToDevice)); int next = (rank + 1) % world_size; int prev = (rank - 1 + world_size) % world_size; printf("Rank %d sending to Rank %d: ", rank, next); for (int i = 0; i < MSG_SIZE; ++i) printf("%d ", h_send[i]); printf("\n"); NCCLCHECK(ncclGroupStart()); NCCLCHECK(ncclSend(d_send, MSG_SIZE, ncclInt, next, comm, stream)); NCCLCHECK(ncclRecv(d_recv, MSG_SIZE, ncclInt, prev, comm, stream)); NCCLCHECK(ncclGroupEnd()); CUDACHECK(hipStreamSynchronize(stream)); CUDACHECK(hipGetLastError()); MPI_Barrier(MPI_COMM_WORLD); // 打印 int h_recv[MSG_SIZE]; CUDACHECK(hipMemcpy(h_recv, d_recv, MSG_SIZE * sizeof(int), hipMemcpyDeviceToHost)); printf("Rank %d received from Rank %d: ", rank, prev); for (int i = 0; i < MSG_SIZE; ++i) printf("%d ", h_recv[i]); printf("\n"); CUDACHECK(hipFree(d_send)); CUDACHECK(hipFree(d_recv)); CUDACHECK(hipStreamDestroy(stream)); NCCLCHECK(ncclCommDestroy(comm)); MPI_Finalize(); return 0; }