"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "ab122dace70041349b3404bc2556e8f1c6f16c73"
Commit 4927d10e authored by yuguo's avatar yuguo
Browse files

[DCU] fix

parent 2e870ed9
...@@ -117,13 +117,17 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl ...@@ -117,13 +117,17 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32"); NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
} }
const char *NVTE_UB_COMM_CU_NUMS = std::getenv("NVTE_UB_COMM_CU_NUMS");
static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS]; static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
if (compute_streams[i] == nullptr) { if (compute_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority)); if (NVTE_UB_COMM_CU_NUMS != nullptr && NVTE_UB_COMM_CU_NUMS[0] != '\0') {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams[i], cuMaskSize, cuMask)); NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams[i], cuMaskSize, cuMask));
#endif #endif
} else {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
}
} }
_stream_compute.push_back(compute_streams[i]); _stream_compute.push_back(compute_streams[i]);
} }
...@@ -359,14 +363,18 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType ...@@ -359,14 +363,18 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32"); NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
} }
const char *NVTE_UB_COMM_CU_NUMS = std::getenv("NVTE_UB_COMM_CU_NUMS");
static cudaStream_t comm_stream; static cudaStream_t comm_stream;
if (comm_stream == nullptr) { if (comm_stream == nullptr) {
NVTE_CHECK_CUDA( if (NVTE_UB_COMM_CU_NUMS != nullptr && NVTE_UB_COMM_CU_NUMS[0] != '\0') {
cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA( NVTE_CHECK_CUDA(
hipExtStreamCreateWithCUMask(&comm_stream, cuMaskSize, cuMask)); hipExtStreamCreateWithCUMask(&comm_stream, cuMaskSize, cuMask));
#endif #endif
} else {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
}
} }
_stream_comm = comm_stream; _stream_comm = comm_stream;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment