#ifndef CUDA_STREAM_MANAGER_H #define CUDA_STREAM_MANAGER_H #include #include #include #include class CudaStreamManager { public: CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) { checkCudaErrors(cudaSetDevice(device)); printf("set device %d\n", device); streams = new cudaStream_t[num_expert]; checkCudaErrors(cublasCreate(&handle)); for (size_t i=0; i