// !!! This is a file automatically generated by hipify!!! #include #include "hip/hip_runtime.h" #include #include #include #include #include #include #include #include "../hip/fastermoe/status.h" #include "../hip/stream_manager.h" #define SMGR_N_STREAMS 16 hipStream_t CudaStreamManager::stream(size_t idx) { if (this->use_default) { return c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); } return this->streams[idx % SMGR_N_STREAMS]; } hipStream_t CudaStreamManager::torchStream() { return c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); } hipblasHandle_t CudaStreamManager::handle(size_t idx) { if (this->use_default) { return at::cuda::getCurrentCUDABlasHandle(); } return this->handles[idx % SMGR_N_STREAMS]; } void CudaStreamManager::syncTorch() { hipStreamSynchronize(this->torchStream()); } void CudaStreamManager::sync(int idx) { if (this->use_default) { return; } for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) { hipStreamSynchronize(streams[i]); } } void CudaStreamManager::setup(const int device) { #ifdef FMOE_USE_NCCL this->ncclgood = 0; #endif this->device = device; checkCudaErrors(hipSetDevice(device)); streams = new hipStream_t[SMGR_N_STREAMS]; handles = new hipblasHandle_t[SMGR_N_STREAMS]; for (size_t i = 0; i < SMGR_N_STREAMS; ++i) { // SHOULD NOT USE: hipStreamCreate(...) // more details in // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html checkCudaErrors(hipStreamCreateWithFlags(streams + i, hipStreamNonBlocking)); checkCudaErrors(hipblasCreate(handles + i)); hipblasSetStream(handles[i], streams[i]); } } void CudaStreamManager::destroy() { for (size_t i = 0; i < SMGR_N_STREAMS; ++i) { checkCudaErrors(hipStreamDestroy(streams[i])); checkCudaErrors(hipblasDestroy(handles[i])); } delete[] streams; delete[] handles; } std::unordered_map smgrs; std::mutex smgr_mtx; CudaStreamManager* getCudaStreamManager(const int device) { auto it = smgrs.find(device); if (it == smgrs.end()) { smgr_mtx.lock(); it = smgrs.find(device); if (it == smgrs.end()) { auto smgr = new CudaStreamManager(device); smgrs.insert(std::pair(device, smgr)); smgr_mtx.unlock(); return smgr; } else { smgr_mtx.unlock(); } } return it->second; }