stream_manager.cpp 2.5 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
Rick Ho's avatar
Rick Ho committed
6
7
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
Rick Ho's avatar
Rick Ho committed
8

Rick Ho's avatar
Rick Ho committed
9
#include "fastermoe/status.h"
Rick Ho's avatar
Rick Ho committed
10
11
12
13
#include "stream_manager.h"

#define SMGR_N_STREAMS 16

Rick Ho's avatar
Rick Ho committed
14

Rick Ho's avatar
Rick Ho committed
15
// Return the CUDA stream associated with slot `idx`.
//
// When `use_default` is set, `idx` is ignored and the stream reported by
// c10::cuda::getCurrentCUDAStream() is returned instead. Otherwise `idx` is
// wrapped modulo SMGR_N_STREAMS, so any index maps onto one of the streams
// held in `streams`.
cudaStream_t CudaStreamManager::stream(size_t idx) {
    if (!this->use_default) {
        const size_t slot = idx % SMGR_N_STREAMS;
        return this->streams[slot];
    }
    return c10::cuda::getCurrentCUDAStream().stream();
}

Rick Ho's avatar
Rick Ho committed
22
23
24
25
// Return the raw cudaStream_t of the stream that
// c10::cuda::getCurrentCUDAStream() reports at the time of the call.
cudaStream_t CudaStreamManager::torchStream() {
    const auto current = c10::cuda::getCurrentCUDAStream();
    return current.stream();
}

Rick Ho's avatar
Rick Ho committed
26
// Return the cuBLAS handle associated with slot `idx`.
//
// When `use_default` is set, `idx` is ignored and the handle from
// at::cuda::getCurrentCUDABlasHandle() is returned. Otherwise `idx` is
// wrapped modulo SMGR_N_STREAMS and used to index `handles`.
cublasHandle_t CudaStreamManager::handle(size_t idx) {
    if (!this->use_default) {
        const size_t slot = idx % SMGR_N_STREAMS;
        return this->handles[slot];
    }
    return at::cuda::getCurrentCUDABlasHandle();
}


Rick Ho's avatar
Rick Ho committed
34
35
36
37
void CudaStreamManager::syncTorch() {
    cudaStreamSynchronize(this->torchStream());
}

Rick Ho's avatar
Rick Ho committed
38
void CudaStreamManager::sync(int idx) {
Rick Ho's avatar
Rick Ho committed
39
40
41
    if (this->use_default) {
        return;
    }
Rick Ho's avatar
Rick Ho committed
42
43
44
45
46
47
    for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
        cudaStreamSynchronize(streams[i]);
    }
}

void CudaStreamManager::setup(const int device) {
Rick Ho's avatar
Rick Ho committed
48
#ifdef FMOE_USE_NCCL
Rick Ho's avatar
Rick Ho committed
49
50
51
52
53
54
55
    this->ncclgood = 0;
#endif
    this->device = device;
    checkCudaErrors(cudaSetDevice(device));
    streams = new cudaStream_t[SMGR_N_STREAMS];
    handles = new cublasHandle_t[SMGR_N_STREAMS];
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
56
57
58
59
60
        // SHOULD NOT USE: cudaStreamCreate(...)
        // more details in
        // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
        checkCudaErrors(cudaStreamCreateWithFlags(streams + i,
                        cudaStreamNonBlocking));
Rick Ho's avatar
Rick Ho committed
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
        checkCudaErrors(cublasCreate(handles + i));
        cublasSetStream(handles[i], streams[i]);
    }
}

void CudaStreamManager::destroy() {
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
        checkCudaErrors(cudaStreamDestroy(streams[i]));
        checkCudaErrors(cublasDestroy(handles[i]));
    }
    delete[] streams;
    delete[] handles;
}

// Registry of stream managers keyed by the device index passed to
// getCudaStreamManager(). Entries are raw pointers created with `new`;
// nothing in this file removes or deletes them.
std::unordered_map<int, CudaStreamManager*> smgrs;
// Mutex held by getCudaStreamManager() while it inserts into `smgrs`.
std::mutex smgr_mtx;

// Return the CudaStreamManager registered for `device`, creating and
// registering one on first use.
//
// The whole lookup-or-create sequence runs under `smgr_mtx`. A
// std::unordered_map must not be read while another thread may be inserting
// into it, so the lookup cannot be done outside the lock. The lock_guard
// also releases the mutex if constructing the manager throws.
//
// The returned pointer is owned by the `smgrs` registry; callers must not
// delete it.
CudaStreamManager* getCudaStreamManager(const int device) {
    std::lock_guard<std::mutex> guard(smgr_mtx);
    auto it = smgrs.find(device);
    if (it == smgrs.end()) {
        auto smgr = new CudaStreamManager(device);
        it = smgrs.emplace(device, smgr).first;
    }
    return it->second;
}