// cuda_stream_manager.h
// (Original author per version-control history: Rick Ho.)
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h> 


// Size of the stream / cuBLAS-handle pool created by CudaStreamManager;
// indices passed to getStream/getHandle are reduced modulo this value.
#define MAX_STREAMS 16
struct CudaStreamManager {
    const size_t num_expert;
    cublasHandle_t* handles;
    cudaStream_t* streams;

Rick Ho's avatar
Rick Ho committed
17
    CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) {
Rick Ho's avatar
Rick Ho committed
18
19
20
        streams = new cudaStream_t[MAX_STREAMS];
		handles = new cublasHandle_t[MAX_STREAMS];
        for (size_t i=0; i<MAX_STREAMS; ++i) {
Rick Ho's avatar
Rick Ho committed
21
22
23
24
			checkCudaErrors(cublasCreate(handles + i));
			checkCudaErrors(cudaStreamCreate(streams + i));
			checkCudaErrors(cublasSetStream(handles[i], streams[i]));
		}
Rick Ho's avatar
Rick Ho committed
25
    }
Rick Ho's avatar
Rick Ho committed
26

Rick Ho's avatar
Rick Ho committed
27
    ~CudaStreamManager() {
Rick Ho's avatar
Rick Ho committed
28
        for (size_t i=0; i<MAX_STREAMS; ++i) {
Rick Ho's avatar
Rick Ho committed
29
30
31
            checkCudaErrors(cudaStreamDestroy(streams[i]));
			checkCudaErrors(cublasDestroy(handles[i]));
		}
Rick Ho's avatar
Rick Ho committed
32
    }
Rick Ho's avatar
Rick Ho committed
33
34
35
36
37
38
39
40

	inline cudaStream_t& getStream(int idx) {
		return streams[idx % MAX_STREAMS];
	}
	inline cublasHandle_t& getHandle(int idx) {
		return handles[idx % MAX_STREAMS];
	}

41
	void sync(int=-1);
Rick Ho's avatar
Rick Ho committed
42
43
44
45
46
}; 

// Returns a CudaStreamManager for the given number of experts. Defined out of
// line (not in this header).
// NOTE(review): presumably this returns a shared instance that callers must not
// delete; confirm ownership and lifetime in the implementation.
CudaStreamManager* getCudaStreamManager(const size_t num_expert);

#endif  // CUDA_STREAM_MANAGER_H