cuda_stream_manager.h 974 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h> 


class CudaStreamManager {
public:
Jiezhong Qiu's avatar
Jiezhong Qiu committed
11
12
    CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) {
        checkCudaErrors(cudaSetDevice(device));
Rick Ho's avatar
Rick Ho committed
13
14
15
16
17
18
19
20
21
22
23
24
25
        streams = new cudaStream_t[num_expert];
        checkCudaErrors(cublasCreate(&handle));
        for (size_t i=0; i<num_expert; ++i) {
            checkCudaErrors(cudaStreamCreate(streams+i));
        }
    }
    ~CudaStreamManager() {
        for (size_t i=0; i<num_expert; ++i) {
            checkCudaErrors(cudaStreamDestroy(*(streams+i)));
        }
        checkCudaErrors(cublasDestroy(handle));
    }
    const size_t num_expert;
Jiezhong Qiu's avatar
Jiezhong Qiu committed
26
    const int device;
Rick Ho's avatar
Rick Ho committed
27
28
29
30
    cublasHandle_t handle;
    cudaStream_t* streams;
}; 

Jiezhong Qiu's avatar
Jiezhong Qiu committed
31
// Returns a pointer to a CudaStreamManager for the given stream count and
// CUDA device ordinal. Only the declaration lives in this header.
// NOTE(review): whether the returned manager is cached/shared, and who owns
// it, cannot be told from this declaration -- confirm against the definition
// before deleting the pointer.
CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
Rick Ho's avatar
Rick Ho committed
32
33

#endif  // CUDA_STREAM_MANAGER_H