// Build:
//   hipcc sendrecv_compute_overlap.cu -o sendrecv_compute_overlap -I /opt/mpi/include -L /opt/mpi/lib/ -lmpi -L /opt/dtk/lib/ -lrccl
// Run:
//   mpirun -np 8 --allow-run-as-root --oversubscribe --quiet ./sendrecv_compute_overlap
//
// Demo: rank 0 runs compute_kernel on its buffer and then ncclSend()s the result
// to rank 1; rank 1 ncclRecv()s the buffer and then runs compute_kernel on it.
// Compute and communication are issued on separate streams and ordered with
// hipEventRecord + hipStreamWaitEvent. Within each rank the two steps are strictly
// dependent, so the events enforce ordering; they do not create real
// compute/communication overlap. Ranks >= 2 join the communicator but do no work.
// Every rank needs its own GPU (one RCCL rank per device).

#include "hip/hip_runtime.h"
#include <mpi.h>
#include <rccl/rccl.h>  // NOTE(review): some RCCL installs expose <rccl.h> instead — adjust if needed
#include <cstdio>
#include <cstdlib>
#include <vector>

// Number of int elements exchanged: 1M ints = 4 MB. Parenthesized so the macro
// expands safely inside larger expressions.
#define MSG_SIZE (1024 * 1024)

// Threads per block for compute_kernel launches.
constexpr int kThreadsPerBlock = 256;

// Check a HIP runtime call; on failure report it and abort the WHOLE MPI job
// (a plain exit() on one rank could leave the other ranks blocked forever in
// MPI/RCCL calls). Only used after MPI_Init.
#define CUDACHECK(cmd) do {                                               \
    hipError_t e_ = (cmd);                                                \
    if (e_ != hipSuccess) {                                               \
        fprintf(stderr, "CUDA error %s:%d: '%s'\n", __FILE__, __LINE__,   \
                hipGetErrorString(e_));                                   \
        MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);                          \
    }                                                                     \
} while (0)

// Same as CUDACHECK, for RCCL/NCCL calls.
#define NCCLCHECK(cmd) do {                                               \
    ncclResult_t r_ = (cmd);                                              \
    if (r_ != ncclSuccess) {                                              \
        fprintf(stderr, "NCCL error %s:%d: '%s'\n", __FILE__, __LINE__,   \
                ncclGetErrorString(r_));                                  \
        MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);                          \
    }                                                                     \
} while (0)

// In-place element-wise update: data[i] = data[i] * 2 + 1 for i in [0, size).
// Expects a 1-D launch with at least `size` threads in total; surplus threads
// are filtered out by the bounds check. No shared memory is used.
__global__ void compute_kernel(int* data, int size) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < size) {
        data[idx] = data[idx] * 2 + 1;
    }
}

// Returns this process's rank among the MPI processes sharing its node, so that
// GPU selection also works when ranks are spread over several nodes.
static int local_rank_on_node() {
    MPI_Comm node_comm;
    MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &node_comm);
    int local_rank = 0;
    MPI_Comm_rank(node_comm, &local_rank);
    MPI_Comm_free(&node_comm);
    return local_rank;
}

int main(int argc, char* argv[]) {
    int rank = 0, size = 0;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    if (size < 2) {
        if (rank == 0) printf("This demo needs at least 2 processes.\n");
        MPI_Finalize();
        return 0;
    }

    // Pick one GPU per process using the node-local rank (the global rank would
    // be out of range on any node but the first). RCCL does not allow two ranks
    // of one communicator on the same device, so fail clearly if GPUs run out.
    const int local_rank = local_rank_on_node();
    int device_count = 0;
    CUDACHECK(hipGetDeviceCount(&device_count));
    if (local_rank >= device_count) {
        fprintf(stderr,
                "Rank %d: local rank %d has no GPU (only %d visible); one GPU per rank is required.\n",
                rank, local_rank, device_count);
        MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
    }
    CUDACHECK(hipSetDevice(local_rank));

    // Rank 0 creates the RCCL unique id and broadcasts it over MPI.
    ncclUniqueId id;
    if (rank == 0) NCCLCHECK(ncclGetUniqueId(&id));
    MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);
    ncclComm_t comm;
    NCCLCHECK(ncclCommInitRank(&comm, size, id, rank));

    hipStream_t stream_comm, stream_comp;
    CUDACHECK(hipStreamCreate(&stream_comm));
    CUDACHECK(hipStreamCreate(&stream_comp));

    const size_t bytes = static_cast<size_t>(MSG_SIZE) * sizeof(int);
    int* d_buf = nullptr;
    CUDACHECK(hipMalloc(&d_buf, bytes));

    // Initialize data: rank 0 fills its buffer with 0, 1, 2, ...
    // Rank 1 receives into d_buf; other ranks never read it.
    if (rank == 0) {
        std::vector<int> h_data(MSG_SIZE);
        for (int i = 0; i < MSG_SIZE; ++i) h_data[i] = i;
        CUDACHECK(hipMemcpy(d_buf, h_data.data(), bytes, hipMemcpyHostToDevice));
    }

    // Events for timing (start/stop) and for cross-stream ordering (comp_done/comm_done).
    hipEvent_t start, stop, comp_done, comm_done;
    CUDACHECK(hipEventCreate(&start));
    CUDACHECK(hipEventCreate(&stop));
    CUDACHECK(hipEventCreate(&comp_done));
    CUDACHECK(hipEventCreate(&comm_done));

    const int blocks = (MSG_SIZE + kThreadsPerBlock - 1) / kThreadsPerBlock;

    // start/stop are recorded on the null stream; the measured time includes any
    // time rank 1 spends waiting for rank 0's data.
    CUDACHECK(hipEventRecord(start));

    if (rank == 0) {
        // Compute (simulated workload), then send the result.
        compute_kernel<<<blocks, kThreadsPerBlock, 0, stream_comp>>>(d_buf, MSG_SIZE);
        CUDACHECK(hipGetLastError());  // surface launch-configuration errors
        // Mark the end of the compute, and make the send wait for it.
        CUDACHECK(hipEventRecord(comp_done, stream_comp));
        CUDACHECK(hipStreamWaitEvent(stream_comm, comp_done, 0));
        NCCLCHECK(ncclSend(d_buf, MSG_SIZE, ncclInt, 1, comm, stream_comm));
    } else if (rank == 1) {
        // Receive, then run the local compute on the received data.
        NCCLCHECK(ncclRecv(d_buf, MSG_SIZE, ncclInt, 0, comm, stream_comm));
        // Mark the end of the receive, and make the compute wait for it.
        CUDACHECK(hipEventRecord(comm_done, stream_comm));
        CUDACHECK(hipStreamWaitEvent(stream_comp, comm_done, 0));
        compute_kernel<<<blocks, kThreadsPerBlock, 0, stream_comp>>>(d_buf, MSG_SIZE);
        CUDACHECK(hipGetLastError());  // surface launch-configuration errors
    }

    // Wait for both streams before stopping the timer.
    CUDACHECK(hipStreamSynchronize(stream_comm));
    CUDACHECK(hipStreamSynchronize(stream_comp));
    CUDACHECK(hipEventRecord(stop));
    CUDACHECK(hipEventSynchronize(stop));

    float ms = 0.0f;
    CUDACHECK(hipEventElapsedTime(&ms, start, stop));
    printf("Rank %d: Total time = %.2f ms\n", rank, ms);

    // Show a few results. Rank 0 sends 2*i+1, and rank 1 applies the kernel
    // again, so rank 1 expects 4*i+3, i.e. 3 7 11 15 19.
    if (rank == 1) {
        int h_out[5];
        CUDACHECK(hipMemcpy(h_out, d_buf, sizeof(h_out), hipMemcpyDeviceToHost));
        printf("Rank 1: first 5 received and computed values: ");
        for (int v : h_out) printf("%d ", v);
        printf("\n");
    }

    // Cleanup: events (previously leaked), device buffer, streams, communicator.
    CUDACHECK(hipEventDestroy(start));
    CUDACHECK(hipEventDestroy(stop));
    CUDACHECK(hipEventDestroy(comp_done));
    CUDACHECK(hipEventDestroy(comm_done));
    CUDACHECK(hipFree(d_buf));
    CUDACHECK(hipStreamDestroy(stream_comm));
    CUDACHECK(hipStreamDestroy(stream_comp));
    NCCLCHECK(ncclCommDestroy(comm));
    MPI_Finalize();
    return 0;
}