/************************************************************************* * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #ifndef TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_ #define TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_ #include "multi_stream.h" #include #include #include #include #include "cuda_runtime.h" #include "logging.h" static inline int getIntEnv(const char *name, int defval, int minval) { int val = defval; const char* env = std::getenv(name); if (env != nullptr && env[0] != '\0') { val = atoi(env); if (val < minval) { val = minval; } } return val; } namespace transformer_engine::detail { cudaStream_t get_compute_stream(int idx) { const size_t num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams, " streams)"); static std::vector streams(num_streams); static std::once_flag stream_init_flag; auto init = [&]() { int comm_cu_nums = getIntEnv("TORCH_COMM_CU_NUMS", 8, 4); unsigned int cuMask[4]; unsigned int cuMaskSize = 4; if (comm_cu_nums == 4) { cuMask[0] = 0xfffffff0; cuMask[1] = 0xffffffff; cuMask[2] = 0xffffffff; cuMask[3] = 0xffffffff; } else if (comm_cu_nums == 8) { cuMask[0] = 0xffffff00; cuMask[1] = 0xffffffff; cuMask[2] = 0xffffffff; cuMask[3] = 0xffffffff; } else if (comm_cu_nums == 16) { cuMask[0] = 0xffff0000; cuMask[1] = 0xffffffff; cuMask[2] = 0xffffffff; cuMask[3] = 0xffffffff; } else if (comm_cu_nums == 32) { cuMask[0] = 0x00000000; cuMask[1] = 0xffffffff; cuMask[2] = 0xffffffff; cuMask[3] = 0xffffffff; } else { NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32"); } const char *TORCH_COMM_CU_NUMS = std::getenv("TORCH_COMM_CU_NUMS"); for (size_t i = 0; i < num_streams; i++) { #ifdef __HIP_PLATFORM_AMD__ if (TORCH_COMM_CU_NUMS != nullptr && TORCH_COMM_CU_NUMS[0] != '\0') { NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&streams[i], cuMaskSize, cuMask)); } else { NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1)); } #else NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1)); #endif } }; std::call_once(stream_init_flag, init); return streams[idx]; } cudaEvent_t get_compute_stream_event(int idx) { const size_t num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams, " streams)"); static std::vector events(num_streams); static std::once_flag event_init_flag; auto init = [&]() { for (size_t i = 0; i < num_streams; i++) { NVTE_CHECK_CUDA(cudaEventCreate(&events[i])); } }; std::call_once(event_init_flag, init); return events[idx]; } int get_num_compute_streams() { #ifdef __HIP_PLATFORM_AMD__ static constexpr int num_compute_streams = compute_num_streams; #else static constexpr int num_compute_streams = 4; #endif return num_compute_streams; } } // namespace transformer_engine::detail int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); } #endif // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_