cuda_stream_manager.cpp 1.01 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
2
#include <cassert>
3
#include <thread>
Rick Ho's avatar
Rick Ho committed
4
5
6

#include "cuda_stream_manager.h"

Rick Ho's avatar
Rick Ho committed
7
8
9
10
11
12
13
/*
 * Return the managed CUDA stream in slot `idx`.
 *
 * If `idx` is not below the current stream count (`num_expert`), the pool is
 * first rebuilt via setup(idx + 1) so that slot `idx` exists.
 */
cudaStream_t CudaStreamManager::stream(size_t idx) {
	const bool needs_growth = !(idx < num_expert);
	if (needs_growth)
		setup(idx + 1);
	cudaStream_t selected = streams[idx];
	return selected;
}

14
15
16
17
18
void CudaStreamManager::sync(int i) {
	if (i > -1) {
		cudaStreamSynchronize(streams[i]);
		return;
	}
Rick Ho's avatar
Rick Ho committed
19
	for (size_t i = 0; i < this->num_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
20
21
22
		cudaStreamSynchronize(streams[i]);
	}
}
Rick Ho's avatar
Rick Ho committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

void CudaStreamManager::setup(const size_t num_expert, const int device) {
#ifdef MOE_DEBUG
	printf("setup at device %d\n", device);
#endif
	this->num_expert = num_expert;
	if (device == -1) {
        checkCudaErrors(cudaGetDevice(&this->device));
	} else {
		this->device = device;
	}
	checkCudaErrors(cudaSetDevice(this->device));
	streams = new cudaStream_t[num_expert];
	handles = new cublasHandle_t[num_expert];
	for (size_t i=0; i<num_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
38
		checkCudaErrors(cudaStreamCreate(streams + i));
Rick Ho's avatar
Rick Ho committed
39
40
41
42
43
		checkCudaErrors(cublasCreate(handles + i));
		cublasSetStream(handles[i], streams[i]);
	}
}