#ifndef CUDA_STREAM_MANAGER_H #define CUDA_STREAM_MANAGER_H #include #include #include #include class CudaStreamManager { public: size_t num_expert; int device; cublasHandle_t* handles; cudaStream_t* streams; public: CudaStreamManager() : num_expert(0), streams(NULL) { int current_device; checkCudaErrors(cudaGetDevice(¤t_device)); #ifdef MOE_DEBUG printf("constructor at device %d\n", current_device); #endif } void setup(const size_t num_expert, const int device=-1); cudaStream_t stream(size_t=0); ~CudaStreamManager() { #ifdef MOE_DEBUG printf("destructor at device %d\n", device); #endif for (size_t i=0; i