// stream_manager.h (1.13 KB), last committed by zhanggzh.
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H

#include "../hip/utils/helper_cuda.h"

#ifdef FMOE_USE_NCCL
#include <rccl/rccl.h>

// NCCL_SAFE_CALL(call): evaluate an NCCL/RCCL call exactly once and, if its
// result is not ncclSuccess, print the call site, the raw result value and its
// ncclGetErrorString() description to stderr, then terminate with exit(-1).
//
// The body is wrapped in do { ... } while (0) so that `NCCL_SAFE_CALL(f);`
// expands to a single statement and works inside an unbraced if/else. A plain
// { ... } block followed by the caller's ';' would end the if-statement and
// orphan the else.
// Local names avoid the double-underscore prefix, which is reserved for the
// implementation.
#define NCCL_SAFE_CALL(call) do { \
    auto fmoe_nccl_res_ = (call); \
    if (fmoe_nccl_res_ != ncclSuccess) { \
        fprintf(stderr, "NCCL Error at %s:%d value %d (%s)\n", \
                __FILE__, __LINE__, static_cast<int>(fmoe_nccl_res_), \
                ncclGetErrorString(fmoe_nccl_res_)); \
        exit(-1); \
    } \
} while (0)

#endif

class CudaStreamManager {
public:
    int device;
    hipblasHandle_t* handles;
    hipStream_t* streams;
    bool use_default;
#ifdef FMOE_USE_NCCL
    char ncclgood;
    ncclComm_t ncclcomm;
#endif

public:
    CudaStreamManager(int device_): device(device_), use_default(false) {
        this->setup(device);
    }

    void setup(int);
    void sync(int=0);
    void syncTorch();
    void destroy();

    hipStream_t torchStream();
    hipStream_t stream(size_t=0);
    hipblasHandle_t handle(size_t=0);

    ~CudaStreamManager() {
        this->destroy();
    }
}; 

// getCudaStreamManager: declared here and defined out of line. Takes a device
// index and returns a pointer to a CudaStreamManager.
// NOTE(review): whether the returned object is cached per device and who owns
// (and may delete) it is decided by the definition, which is not visible here.
// Confirm before taking ownership of the pointer.
CudaStreamManager *getCudaStreamManager(const int device);

#endif  // CUDA_STREAM_MANAGER_H