Commit 2f11bd2e authored by yuguo's avatar yuguo
Browse files
parents 9d26d942 4927d10e
......@@ -117,13 +117,17 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
}
const char *NVTE_UB_COMM_CU_NUMS = std::getenv("NVTE_UB_COMM_CU_NUMS");
static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
if (compute_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
if (NVTE_UB_COMM_CU_NUMS != nullptr && NVTE_UB_COMM_CU_NUMS[0] != '\0') {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams[i], cuMaskSize, cuMask));
NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams[i], cuMaskSize, cuMask));
#endif
} else {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
}
}
_stream_compute.push_back(compute_streams[i]);
}
......@@ -359,14 +363,18 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
}
const char *NVTE_UB_COMM_CU_NUMS = std::getenv("NVTE_UB_COMM_CU_NUMS");
static cudaStream_t comm_stream;
if (comm_stream == nullptr) {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
if (NVTE_UB_COMM_CU_NUMS != nullptr && NVTE_UB_COMM_CU_NUMS[0] != '\0') {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(
hipExtStreamCreateWithCUMask(&comm_stream, cuMaskSize, cuMask));
NVTE_CHECK_CUDA(
hipExtStreamCreateWithCUMask(&comm_stream, cuMaskSize, cuMask));
#endif
} else {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
}
}
_stream_comm = comm_stream;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment