cuda_stream_manager.h 1.41 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h> 

8
9
#include <cstdio>

Rick Ho's avatar
Rick Ho committed
10
11
12

class CudaStreamManager {
public:
Jiezhong Qiu's avatar
Jiezhong Qiu committed
13
    CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) {
14
15
16
17
18
        /* 
        Actually, we will see current_device == device,  
        which means pytorch always sets the correct device for us.
        But for safety, we still manually set device to the desired one.
        */
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19
        /*
20
21
22
        int current_device;
        checkCudaErrors(cudaGetDevice(&current_device));
        printf("CudaStreamManager construnctor called, get device %d, set device %d\n", current_device, device);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
23
        */
Jiezhong Qiu's avatar
Jiezhong Qiu committed
24
        checkCudaErrors(cudaSetDevice(device));
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
25
        
Rick Ho's avatar
Rick Ho committed
26
27
28
29
30
31
32
33
34
35
36
37
38
        streams = new cudaStream_t[num_expert];
        checkCudaErrors(cublasCreate(&handle));
        for (size_t i=0; i<num_expert; ++i) {
            checkCudaErrors(cudaStreamCreate(streams+i));
        }
    }
    ~CudaStreamManager() {
        for (size_t i=0; i<num_expert; ++i) {
            checkCudaErrors(cudaStreamDestroy(*(streams+i)));
        }
        checkCudaErrors(cublasDestroy(handle));
    }
    const size_t num_expert;
Jiezhong Qiu's avatar
Jiezhong Qiu committed
39
    const int device;
Rick Ho's avatar
Rick Ho committed
40
41
42
43
    cublasHandle_t handle;
    cudaStream_t* streams;
}; 

Jiezhong Qiu's avatar
Jiezhong Qiu committed
44
// Declaration only; the definition is not in this header.
// Takes the same (num_expert, device) pair as the CudaStreamManager
// constructor and returns a pointer to a manager.
// NOTE(review): ownership and lifetime of the returned pointer (cached vs.
// freshly allocated, who deletes it) cannot be determined from this header --
// confirm against the definition.
CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
Rick Ho's avatar
Rick Ho committed
45
46

#endif  // CUDA_STREAM_MANAGER_H