Commit ffbef335 authored by yuguo's avatar yuguo
Browse files

[DCU] fix TORCH_COMM_CU_NUMS conflicts in 2.5

parent 067c2b3d
......@@ -27,8 +27,43 @@ cudaStream_t get_compute_stream(int idx) {
static std::vector<cudaStream_t> streams(num_streams);
static std::once_flag stream_init_flag;
auto init = [&]() {
int comm_cu_nums = getIntEnv("TORCH_COMM_CU_NUMS", 8, 4);
unsigned int cuMask[4];
unsigned int cuMaskSize = 4;
if (comm_cu_nums == 4) {
cuMask[0] = 0xfffffff0;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else if (comm_cu_nums == 8) {
cuMask[0] = 0xffffff00;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else if (comm_cu_nums == 16) {
cuMask[0] = 0xffff0000;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else if (comm_cu_nums == 32) {
cuMask[0] = 0x00000000;
cuMask[1] = 0xffffffff;
cuMask[2] = 0xffffffff;
cuMask[3] = 0xffffffff;
} else {
NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
}
const char *TORCH_COMM_CU_NUMS = std::getenv("TORCH_COMM_CU_NUMS");
for (size_t i = 0; i < num_streams; i++) {
#ifdef __HIP_PLATFORM_AMD__
if (TORCH_COMM_CU_NUMS != nullptr && TORCH_COMM_CU_NUMS[0] != '\0') {
NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&streams[i], cuMaskSize, cuMask));
} else {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1));
}
#else
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1));
#endif
}
};
std::call_once(stream_init_flag, init);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment