// hipcc reducescatter.cu -o reducescatter -I /opt/mpi/include -L /opt/mpi/lib/ -lmpi -L /opt/dtk/lib/ -lrccl // mpirun -np 8 --allow-run-as-root --oversubscribe --quiet ./reducescatter #include #include #include #include #include #define CHUNK_SIZE 4 // 每个 rank 接收 CHUNK_SIZE 个元素 #define TOTAL_SIZE (CHUNK_SIZE * world_size) #define CUDACHECK(cmd) do { \ hipError_t e = cmd; \ if (e != hipSuccess) { \ printf("CUDA error %s:%d: '%s'\n", __FILE__, __LINE__, \ hipGetErrorString(e)); \ exit(EXIT_FAILURE); \ } \ } while(0) #define NCCLCHECK(cmd) do { \ ncclResult_t r = cmd; \ if (r != ncclSuccess) { \ printf("NCCL error %s:%d: '%s'\n", __FILE__, __LINE__, \ ncclGetErrorString(r)); \ exit(EXIT_FAILURE); \ } \ } while(0) int main(int argc, char* argv[]) { int rank, world_size; ncclComm_t comm; ncclUniqueId id; hipStream_t stream; MPI_Init(&argc, &argv); MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &world_size); CUDACHECK(hipSetDevice(rank)); if (rank == 0) NCCLCHECK(ncclGetUniqueId(&id)); MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); NCCLCHECK(ncclCommInitRank(&comm, world_size, id, rank)); CUDACHECK(hipStreamCreate(&stream)); // 每个 rank 分配 TOTAL_SIZE 的发送缓冲区(每 rank 发送全部数据) int* sendbuff; int* recvbuff; CUDACHECK(hipMalloc(&sendbuff, TOTAL_SIZE * sizeof(int))); CUDACHECK(hipMalloc(&recvbuff, CHUNK_SIZE * sizeof(int))); // 初始化发送数据(例如每 rank 的数据是 rank*100 + index) int* h_send = new int[TOTAL_SIZE]; for (int i = 0; i < TOTAL_SIZE; ++i) h_send[i] = rank * 100 + i; CUDACHECK(hipMemcpy(sendbuff, h_send, TOTAL_SIZE * sizeof(int), hipMemcpyHostToDevice)); // 打印发送数据 printf("Rank %d original data: ", rank); for (int i = 0; i < TOTAL_SIZE; ++i) printf("%d ", h_send[i]); printf("\n"); delete[] h_send; // 执行 reduce-scatter(sum) NCCLCHECK(ncclReduceScatter( sendbuff, recvbuff, CHUNK_SIZE, ncclInt, ncclSum, comm, stream)); CUDACHECK(hipStreamSynchronize(stream)); // 打印每个 rank 接收到的结果 int* h_recv = new int[CHUNK_SIZE]; CUDACHECK(hipMemcpy(h_recv, recvbuff, CHUNK_SIZE * sizeof(int), hipMemcpyDeviceToHost)); printf("Rank %d received reduced chunk: ", rank); for (int i = 0; i < CHUNK_SIZE; ++i) printf("%d ", h_recv[i]); printf("\n"); delete[] h_recv; // 清理资源 CUDACHECK(hipFree(sendbuff)); CUDACHECK(hipFree(recvbuff)); CUDACHECK(hipStreamDestroy(stream)); NCCLCHECK(ncclCommDestroy(comm)); MPI_Finalize(); return 0; }