// cuda_stream_manager.h
// Per-device manager for a set of CUDA streams and a cuBLAS handle.
// (Authors per version-control history: Rick Ho, Jiezhong Qiu.)
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h> 

#include <cstdio>


class CudaStreamManager {
public:
Jiezhong Qiu's avatar
Jiezhong Qiu committed
13
14
    CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) {
        checkCudaErrors(cudaSetDevice(device));
15
        printf("set device %d\n", device);
Rick Ho's avatar
Rick Ho committed
16
17
18
19
20
21
22
23
24
25
26
27
28
        streams = new cudaStream_t[num_expert];
        checkCudaErrors(cublasCreate(&handle));
        for (size_t i=0; i<num_expert; ++i) {
            checkCudaErrors(cudaStreamCreate(streams+i));
        }
    }
    ~CudaStreamManager() {
        for (size_t i=0; i<num_expert; ++i) {
            checkCudaErrors(cudaStreamDestroy(*(streams+i)));
        }
        checkCudaErrors(cublasDestroy(handle));
    }
    const size_t num_expert;
Jiezhong Qiu's avatar
Jiezhong Qiu committed
29
    const int device;
Rick Ho's avatar
Rick Ho committed
30
31
32
33
    cublasHandle_t handle;
    cudaStream_t* streams;
}; 

// Returns a pointer to a CudaStreamManager for `num_expert` streams on CUDA
// device `device`.
// NOTE(review): the definition is not in this header. Confirm against the
// implementation whether instances are cached/shared across calls and who is
// responsible for deleting the returned pointer.
CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);

#endif  // CUDA_STREAM_MANAGER