cuda_stream_manager.h 929 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h> 


Rick Ho's avatar
Rick Ho committed
9
10
11
12
13
struct CudaStreamManager {
    const size_t num_expert;
    cublasHandle_t* handles;
    cudaStream_t* streams;

Rick Ho's avatar
Rick Ho committed
14
15
    CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) {
        streams = new cudaStream_t[num_expert];
Rick Ho's avatar
Rick Ho committed
16
		handles = new cublasHandle_t[num_expert];
Rick Ho's avatar
Rick Ho committed
17
        for (size_t i=0; i<num_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
18
19
20
21
			checkCudaErrors(cublasCreate(handles + i));
			checkCudaErrors(cudaStreamCreate(streams + i));
			checkCudaErrors(cublasSetStream(handles[i], streams[i]));
		}
Rick Ho's avatar
Rick Ho committed
22
    }
Rick Ho's avatar
Rick Ho committed
23

Rick Ho's avatar
Rick Ho committed
24
25
    ~CudaStreamManager() {
        for (size_t i=0; i<num_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
26
27
28
            checkCudaErrors(cudaStreamDestroy(streams[i]));
			checkCudaErrors(cublasDestroy(handles[i]));
		}
Rick Ho's avatar
Rick Ho committed
29
30
31
32
33
34
    }
}; 

// Returns a pointer to a CudaStreamManager for the given number of experts.
// Only the declaration lives in this header; the definition is elsewhere.
// NOTE(review): the name suggests a shared/cached instance — confirm in the
// definition whether the caller owns the returned pointer or must not free it.
CudaStreamManager* getCudaStreamManager(const size_t num_expert);

#endif  // CUDA_STREAM_MANAGER_H