// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPContext.h>

#include "../hip/fastermoe/status.h"
#include "../hip/stream_manager.h"

#define SMGR_N_STREAMS 16


// Returns the stream to use for slot `idx`.
// When the manager is in default mode (`use_default`), the stream currently
// selected in PyTorch is returned and `idx` is ignored. Otherwise `idx` is
// wrapped modulo SMGR_N_STREAMS and used to pick a stream from the pool.
hipStream_t CudaStreamManager::stream(size_t idx) {
    if (!this->use_default) {
        const size_t slot = idx % SMGR_N_STREAMS;
        return this->streams[slot];
    }
    return c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
}

// Returns the raw stream that PyTorch currently reports as its current stream.
hipStream_t CudaStreamManager::torchStream() {
    hipStream_t current = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
    return current;
}

// Returns the BLAS handle to use for slot `idx`.
// In default mode (`use_default`) the handle reported by
// at::cuda::getCurrentCUDABlasHandle() is returned and `idx` is ignored.
// Otherwise `idx` is wrapped modulo SMGR_N_STREAMS and used to pick a handle
// from the pool.
hipblasHandle_t CudaStreamManager::handle(size_t idx) {
    if (!this->use_default) {
        const size_t slot = idx % SMGR_N_STREAMS;
        return this->handles[slot];
    }
    return at::cuda::getCurrentCUDABlasHandle();
}


// Blocks the calling host thread until the stream returned by torchStream()
// has finished its queued work.
// NOTE(review): the status returned by hipStreamSynchronize is not inspected.
void CudaStreamManager::syncTorch() {
    const auto torch_stream = this->torchStream();
    hipStreamSynchronize(torch_stream);
}

// Waits for the first `idx` pool streams (at most SMGR_N_STREAMS of them) to
// finish their queued work. Does nothing in default mode (`use_default`) or
// when `idx` is not positive.
// NOTE(review): the status returned by hipStreamSynchronize is not inspected.
void CudaStreamManager::sync(int idx) {
    if (!this->use_default) {
        // Never walk past the end of the pool, whatever the caller asks for.
        const int limit = idx < SMGR_N_STREAMS ? idx : SMGR_N_STREAMS;
        int slot = 0;
        while (slot < limit) {
            hipStreamSynchronize(this->streams[slot]);
            ++slot;
        }
    }
}

// Initializes the manager for `device`.
//
// Makes `device` the current device, then allocates a pool of SMGR_N_STREAMS
// streams and the same number of BLAS handles, binding handle i to stream i.
// When built with FMOE_USE_NCCL, the `ncclgood` flag is reset to 0 first.
// Every runtime / BLAS call is routed through checkCudaErrors, so a failure
// in any step is reported rather than silently ignored.
void CudaStreamManager::setup(const int device) {
#ifdef FMOE_USE_NCCL
    this->ncclgood = 0;
#endif
    this->device = device;
    checkCudaErrors(hipSetDevice(device));
    streams = new hipStream_t[SMGR_N_STREAMS];
    handles = new hipblasHandle_t[SMGR_N_STREAMS];
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        // Do NOT use hipStreamCreate(...) here: the streams must be created
        // with the non-blocking flag. See the stream synchronization
        // behavior notes at
        // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
        checkCudaErrors(hipStreamCreateWithFlags(streams + i,
                        hipStreamNonBlocking));
        checkCudaErrors(hipblasCreate(handles + i));
        // Check the bind as well: if it failed silently, the handle would
        // keep issuing work on a stream other than streams[i].
        checkCudaErrors(hipblasSetStream(handles[i], streams[i]));
    }
}

void CudaStreamManager::destroy() {
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(hipStreamDestroy(streams[i]));
        checkCudaErrors(hipblasDestroy(handles[i]));
    }
    delete[] streams;
    delete[] handles;
}

// Registry of stream managers keyed by device index. Entries are created on
// demand by getCudaStreamManager(); nothing in the code visible here removes
// them.
std::unordered_map<int, CudaStreamManager*> smgrs;
// Guards every access to `smgrs` made by getCudaStreamManager().
std::mutex smgr_mtx;

// Returns the stream manager for `device`, creating and registering it on the
// first request for that device.
//
// The whole lookup-or-create sequence runs under `smgr_mtx`. Reading the map
// without the lock would race with a concurrent insertion (which may rehash
// the table), so no unlocked fast path is attempted. The lock_guard also
// releases the mutex on every exit path, including when constructing the
// manager throws.
//
// The returned pointer is owned by the registry; callers must not delete it.
CudaStreamManager* getCudaStreamManager(const int device) {
    std::lock_guard<std::mutex> guard(smgr_mtx);

    const auto it = smgrs.find(device);
    if (it != smgrs.end()) {
        return it->second;
    }

    auto smgr = new CudaStreamManager(device);
    smgrs.emplace(device, smgr);
    return smgr;
}

