stream_manager.cpp 2.31 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
Rick Ho's avatar
Rick Ho committed
6
7
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
Rick Ho's avatar
Rick Ho committed
8

Rick Ho's avatar
Rick Ho committed
9
#include "fastermoe/status.h"
Rick Ho's avatar
Rick Ho committed
10
11
12
13
#include "stream_manager.h"

#define SMGR_N_STREAMS 16

Rick Ho's avatar
Rick Ho committed
14

Rick Ho's avatar
Rick Ho committed
15
// Returns the CUDA stream to use for slot `idx`.
//
// When `use_default` is set, the stream currently selected in PyTorch
// (c10) is returned and `idx` is ignored. Otherwise `idx` is wrapped
// modulo SMGR_N_STREAMS, so any index maps onto one of the pooled streams.
cudaStream_t CudaStreamManager::stream(size_t idx) {
    return this->use_default
        ? c10::cuda::getCurrentCUDAStream().stream()
        : this->streams[idx % SMGR_N_STREAMS];
}

// Returns the cuBLAS handle to use for slot `idx`.
//
// When `use_default` is set, ATen's current cuBLAS handle is returned and
// `idx` is ignored. Otherwise `idx` is wrapped modulo SMGR_N_STREAMS to
// pick one of the pooled handles.
cublasHandle_t CudaStreamManager::handle(size_t idx) {
    if (!this->use_default) {
        const size_t slot = idx % SMGR_N_STREAMS;
        return this->handles[slot];
    }
    return at::cuda::getCurrentCUDABlasHandle();
}


void CudaStreamManager::sync(int idx) {
Rick Ho's avatar
Rick Ho committed
31
32
33
    if (this->use_default) {
        return;
    }
Rick Ho's avatar
Rick Ho committed
34
35
36
37
38
39
    for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
        cudaStreamSynchronize(streams[i]);
    }
}

void CudaStreamManager::setup(const int device) {
Rick Ho's avatar
Rick Ho committed
40
#ifdef FMOE_USE_NCCL
Rick Ho's avatar
Rick Ho committed
41
42
43
44
45
46
47
    this->ncclgood = 0;
#endif
    this->device = device;
    checkCudaErrors(cudaSetDevice(device));
    streams = new cudaStream_t[SMGR_N_STREAMS];
    handles = new cublasHandle_t[SMGR_N_STREAMS];
    for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
48
49
50
51
52
        // SHOULD NOT USE: cudaStreamCreate(...)
        // more details in
        // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
        checkCudaErrors(cudaStreamCreateWithFlags(streams + i,
                        cudaStreamNonBlocking));
Rick Ho's avatar
Rick Ho committed
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
        checkCudaErrors(cublasCreate(handles + i));
        cublasSetStream(handles[i], streams[i]);
    }
}

// Tears down the pool: destroys every stream and its cuBLAS handle, slot
// by slot, then frees the two arrays that held them.
void CudaStreamManager::destroy() {
    size_t i = 0;
    while (i < SMGR_N_STREAMS) {
        // Stream first, then the handle for the same slot.
        checkCudaErrors(cudaStreamDestroy(streams[i]));
        checkCudaErrors(cublasDestroy(handles[i]));
        ++i;
    }

    delete[] streams;
    delete[] handles;
}

// Registry of stream managers, keyed by the device number passed to
// getCudaStreamManager(). Entries are created on first request.
std::unordered_map<int, CudaStreamManager*> smgrs;
// Mutex that getCudaStreamManager() holds while inserting into `smgrs`.
std::mutex smgr_mtx;

// Returns the stream manager registered for `device`, creating and
// registering it on the first request.
//
// The returned pointer is owned by the `smgrs` registry; this function
// never frees it.
CudaStreamManager* getCudaStreamManager(const int device) {
    // The lock must cover the lookup too: searching an unordered_map while
    // another thread inserts into it is a data race. The guard also
    // releases the mutex if constructing the manager throws.
    std::lock_guard<std::mutex> guard(smgr_mtx);

    auto it = smgrs.find(device);
    if (it == smgrs.end()) {
        auto* smgr = new CudaStreamManager(device);
        it = smgrs.emplace(device, smgr).first;
    }
    return it->second;
}