"vscode:/vscode.git/clone" did not exist on "589de109e06de0cf91fdbf6cc62c92298ffc4cee"
Commit 2565f2fa authored by Rick Ho
Browse files

stream manager fixed

parent b8a212ef
/* TODO: make it ke xue
#include <cuda_runtime.h>
#include <cassert>
#include <thread>
#include "cuda_stream_manager.h"
thread_local CudaStreamManager smgr;
CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
if (!smgr) {
smgr = new CudaStreamManager(num_expert, device);
}
<<<<<<< HEAD
return smgr;
}
void CudaStreamManager::sync(int i) {
if (i > -1) {
cudaStreamSynchronize(streams[i]);
......@@ -25,5 +13,3 @@ void CudaStreamManager::sync(int i) {
cudaStreamSynchronize(streams[i]);
}
}
}
*/
......@@ -9,6 +9,12 @@
class CudaStreamManager {
public:
size_t num_expert;
int device;
cublasHandle_t* handles;
cudaStream_t* streams;
public:
CudaStreamManager() : num_expert(0), device(0), streams(NULL) {
int current_device;
......@@ -26,9 +32,11 @@ public:
this->device = device;
checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[num_expert];
checkCudaErrors(cublasCreate(&handle));
handles = new cublasHandle_t[num_expert];
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamCreate(streams+i));
checkCudaErrors(cublasCreate(handles + i));
cublasSetStream(handles[i], streams[i]);
}
}
......@@ -38,14 +46,12 @@ public:
#endif
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamDestroy(*(streams+i)));
checkCudaErrors(cublasDestroy(handles[i]));
}
checkCudaErrors(cublasDestroy(handle));
delete[] streams;
}
size_t num_expert;
int device;
cublasHandle_t handle;
cudaStream_t* streams;
void sync(int=-1);
};
// CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
......
......@@ -147,7 +147,7 @@ void moe_cuda_forward_impl(
in_feat);
#endif
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(smgr.handle, // h->getHandle(i),
checkCudaErrors(cublasXgemm(smgr.handles[0], // h->getHandle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count[i], in_feat,
......@@ -204,7 +204,7 @@ void moe_cuda_grad_weight(
checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
for (size_t i=0; i<batch_size; ++i) {
// checkCudaErrors(cublasSetStream);
checkCudaErrors(cublasXgemm(smgr.handle,
checkCudaErrors(cublasXgemm(smgr.handles[0],
CUBLAS_OP_N,
CUBLAS_OP_T,
out_feat,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment