cuda_stream_manager.cpp 892 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
2
#include <cassert>
3
#include <thread>
Rick Ho's avatar
Rick Ho committed
4
5
6

#include "cuda_stream_manager.h"

7
8
9
10
11
void CudaStreamManager::sync(int i) {
	if (i > -1) {
		cudaStreamSynchronize(streams[i]);
		return;
	}
Rick Ho's avatar
Rick Ho committed
12
	for (size_t i = 0; i < this->num_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
13
14
15
		cudaStreamSynchronize(streams[i]);
	}
}
Rick Ho's avatar
Rick Ho committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

/*
 * Create one CUDA stream and one cuBLAS handle per expert.
 *
 * num_expert: number of stream/handle pairs to create. It is stored in
 *             this->num_expert, which sync() uses as the stream count.
 * device:     CUDA device ordinal to use, or -1 to keep whichever device is
 *             current for the calling thread (queried via cudaGetDevice).
 *
 * Side effects:
 * - Makes this->device the current device for the calling thread
 *   (cudaSetDevice).
 * - Allocates the `streams` and `handles` arrays with new[].
 * - Binds handles[k] to streams[k] for every k.
 *
 * NOTE(review): `streams` and `handles` are overwritten without releasing
 * any previous allocation. Calling setup() more than once would leak the
 * earlier streams and handles — confirm callers never do that.
 */
void CudaStreamManager::setup(const size_t num_expert, const int device) {
#ifdef MOE_DEBUG
	printf("setup at device %d\n", device);
#endif
	this->num_expert = num_expert;
	if (device == -1) {
		checkCudaErrors(cudaGetDevice(&this->device));
	} else {
		this->device = device;
	}
	checkCudaErrors(cudaSetDevice(this->device));
	streams = new cudaStream_t[num_expert];
	handles = new cublasHandle_t[num_expert];
	for (size_t i = 0; i < num_expert; ++i) {
		checkCudaErrors(cudaStreamCreate(streams + i));
		checkCudaErrors(cublasCreate(handles + i));
		// Bind this handle's cuBLAS work to its own stream. The status is now
		// checked the same way as cublasCreate above instead of being dropped.
		checkCudaErrors(cublasSetStream(handles[i], streams[i]));
	}
}