"vscode:/vscode.git/clone" did not exist on "101ac9a760237aa3aa541f278e43d94b7faf7dd9"
cuda_stream_manager.cpp 2.33 KB
Newer Older
1
2
#include <unordered_map>
#include <mutex>
Rick Ho's avatar
Rick Ho committed
3
#include <cassert>
4
#include <thread>
Rick Ho's avatar
Rick Ho committed
5
6
7
8
9
#include <iostream>

#ifdef MOE_USE_NCCL
#include <c10d/ProcessGroupNCCL.hpp>
#endif  // MOE_USE_NCCL
Rick Ho's avatar
Rick Ho committed
10
11

#include "cuda_stream_manager.h"
12
13
#include <helper_cuda.h> 

Rick Ho's avatar
Rick Ho committed
14
#define SMGR_N_STREAMS 16
Rick Ho's avatar
Rick Ho committed
15

Rick Ho's avatar
Rick Ho committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#ifdef MOE_USE_NCCL
/**
 * Thin subclass of c10d::ProcessGroupNCCL whose only addition is getcomm(),
 * a public accessor that returns a raw ncclComm_t for a device.
 */
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
	/**
	 * Look up the NCCL communicator for `dev`.
	 *
	 * The device index, rendered as a decimal string, is the key passed to
	 * getNCCLComm() together with a one-element device list.
	 *
	 * @param dev  Device whose communicator is requested.
	 * @return The raw communicator of the first entry returned by
	 *         getNCCLComm(), or 0 (after logging to stderr) when that call
	 *         yields no entries.
	 */
	ncclComm_t getcomm(at::Device dev) {
		const auto devkey = std::to_string(dev.index());
		auto comms = getNCCLComm(devkey, {dev});
		if (!comms.empty()) {
			return comms[0]->getNcclComm();
		}
		std::cerr << "PyTorch has nothing\n";
		return 0;
	}
};

void CudaStreamManager::ensure(void* torchp, at::Device dev) {
	if (this->ncclgood) {
		return;
	}
	HackNCCLGroup* h = (HackNCCLGroup*)torchp;
	this->ncclcomm = h->getcomm(dev);
	if (this->ncclcomm != 0) {
		this->ncclgood = 1;
	} else {
		std::cerr << "Nccl initialization failed\n";
	}
}
#endif  // MOE_USE_NCCL

Rick Ho's avatar
Rick Ho committed
44
/**
 * Return one of the manager's CUDA streams.
 *
 * @param idx  Any index; it is reduced modulo SMGR_N_STREAMS, so values past
 *             the end of the pool wrap around to earlier slots.
 * @return The stream stored at the wrapped slot.
 */
cudaStream_t CudaStreamManager::stream(size_t idx) {
	const size_t slot = idx % SMGR_N_STREAMS;
	return streams[slot];
}

48
49
50
51
52
53
54
/**
 * Return one of the manager's cuBLAS handles.
 *
 * @param idx  Any index; it is reduced modulo SMGR_N_STREAMS, so values past
 *             the end of the pool wrap around to earlier slots.
 * @return The handle stored at the wrapped slot.
 */
cublasHandle_t CudaStreamManager::handle(size_t idx) {
	const size_t slot = idx % SMGR_N_STREAMS;
	return handles[slot];
}


void CudaStreamManager::sync(int idx) {
	for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
Rick Ho's avatar
Rick Ho committed
55
56
57
		cudaStreamSynchronize(streams[i]);
	}
}
Rick Ho's avatar
Rick Ho committed
58

59
void CudaStreamManager::setup(const int device) {
Rick Ho's avatar
Rick Ho committed
60
61
62
63
#ifdef MOE_USE_NCCL
	this->ncclgood = 0;
#endif
	this->device = device;
64
65
66
67
	checkCudaErrors(cudaSetDevice(device));
	streams = new cudaStream_t[SMGR_N_STREAMS];
	handles = new cublasHandle_t[SMGR_N_STREAMS];
	for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
Rick Ho's avatar
Rick Ho committed
68
		checkCudaErrors(cudaStreamCreate(streams + i));
Rick Ho's avatar
Rick Ho committed
69
70
71
72
73
		checkCudaErrors(cublasCreate(handles + i));
		cublasSetStream(handles[i], streams[i]);
	}
}

74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/**
 * Release everything setup() created: each stream and cuBLAS handle in the
 * pool is destroyed (stream first, then the handle of the same slot), after
 * which the two arrays themselves are freed.
 *
 * The member pointers are not reset afterwards.
 */
void CudaStreamManager::destroy() {
	size_t slot = 0;
	while (slot < SMGR_N_STREAMS) {
		checkCudaErrors(cudaStreamDestroy(streams[slot]));
		checkCudaErrors(cublasDestroy(handles[slot]));
		++slot;
	}
	delete[] streams;
	delete[] handles;
}

// Registry of stream managers keyed by the device number passed to
// getCudaStreamManager(); entries are added there and never removed in
// this file.
std::unordered_map<int, CudaStreamManager*> smgrs;
// Mutex held by getCudaStreamManager() while it inserts into smgrs.
std::mutex smgr_mtx;

/**
 * Return the CudaStreamManager registered for `device`, creating and
 * registering it on first use.
 *
 * @param device  Device number used as the registry key and passed to the
 *                CudaStreamManager constructor.
 * @return Pointer to the manager stored in smgrs for this device. The
 *         registry keeps the pointer; nothing in this function frees it.
 *
 * The whole lookup-and-insert sequence runs under smgr_mtx. Reading the
 * map without the lock while another thread inserts is a data race on
 * std::unordered_map (an insert may rehash), so no unlocked fast path is
 * used. std::lock_guard also releases the mutex if constructing the
 * manager throws.
 */
CudaStreamManager* getCudaStreamManager(const int device) {
	std::lock_guard<std::mutex> guard(smgr_mtx);
	auto it = smgrs.find(device);
	if (it == smgrs.end()) {
		auto smgr = new CudaStreamManager(device);
		it = smgrs.emplace(device, smgr).first;
	}
	return it->second;
}