cuda_stream_manager.h 1.51 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h> 

8
9
#include <cstdio>

Rick Ho's avatar
Rick Ho committed
10
11

class CudaStreamManager {
Rick Ho's avatar
Rick Ho committed
12
13
14
15
16
17
public:
    size_t num_expert;
    int device;
    cublasHandle_t* handles;
    cudaStream_t* streams;

Rick Ho's avatar
Rick Ho committed
18
public:
19
    CudaStreamManager() : num_expert(0), device(0), streams(NULL) {
20
21
        int current_device;
        checkCudaErrors(cudaGetDevice(&current_device));
22
23
24
25
26
27
28
29
30
31
32
33
#ifdef MOE_DEBUG
        printf("constructor at device %d\n", current_device);
#endif
    }

    void setup(const size_t num_expert, const int device) {
#ifdef MOE_DEBUG
        printf("setup at device %d\n", device);
#endif
        this->num_expert = num_expert;
        this->device = device;
        checkCudaErrors(cudaSetDevice(device));        
Rick Ho's avatar
Rick Ho committed
34
        streams = new cudaStream_t[num_expert];
Rick Ho's avatar
Rick Ho committed
35
        handles = new cublasHandle_t[num_expert];
Rick Ho's avatar
Rick Ho committed
36
37
        for (size_t i=0; i<num_expert; ++i) {
            checkCudaErrors(cudaStreamCreate(streams+i));
Rick Ho's avatar
Rick Ho committed
38
39
40
			checkCudaErrors(cublasCreate(handles + i));
			cublasSetStream(handles[i], streams[i]);
		}
Rick Ho's avatar
Rick Ho committed
41
    }
42

Rick Ho's avatar
Rick Ho committed
43
    ~CudaStreamManager() {
44
45
46
#ifdef MOE_DEBUG
        printf("destructor at device %d\n", device);
#endif
Rick Ho's avatar
Rick Ho committed
47
48
        for (size_t i=0; i<num_expert; ++i) {
            checkCudaErrors(cudaStreamDestroy(*(streams+i)));
Rick Ho's avatar
Rick Ho committed
49
50
			checkCudaErrors(cublasDestroy(handles[i]));
		}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
51
        delete[] streams;
Rick Ho's avatar
Rick Ho committed
52
    }
Rick Ho's avatar
Rick Ho committed
53
54

	void sync(int=-1);
Rick Ho's avatar
Rick Ho committed
55
56
}; 

57
// CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
Rick Ho's avatar
Rick Ho committed
58
59

#endif  // CUDA_STREAM_MANAGER_H