Commit 2565f2fa authored by Rick Ho
Browse files

stream manager fixed

parent b8a212ef
/* TODO: make it ke xue
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cassert> #include <cassert>
#include <thread> #include <thread>
#include "cuda_stream_manager.h" #include "cuda_stream_manager.h"
thread_local CudaStreamManager smgr;
CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
if (!smgr) {
smgr = new CudaStreamManager(num_expert, device);
}
<<<<<<< HEAD
return smgr;
}
void CudaStreamManager::sync(int i) { void CudaStreamManager::sync(int i) {
if (i > -1) { if (i > -1) {
cudaStreamSynchronize(streams[i]); cudaStreamSynchronize(streams[i]);
...@@ -25,5 +13,3 @@ void CudaStreamManager::sync(int i) { ...@@ -25,5 +13,3 @@ void CudaStreamManager::sync(int i) {
cudaStreamSynchronize(streams[i]); cudaStreamSynchronize(streams[i]);
} }
} }
}
*/
...@@ -9,6 +9,12 @@ ...@@ -9,6 +9,12 @@
class CudaStreamManager { class CudaStreamManager {
public:
size_t num_expert;
int device;
cublasHandle_t* handles;
cudaStream_t* streams;
public: public:
CudaStreamManager() : num_expert(0), device(0), streams(NULL) { CudaStreamManager() : num_expert(0), device(0), streams(NULL) {
int current_device; int current_device;
...@@ -26,10 +32,12 @@ public: ...@@ -26,10 +32,12 @@ public:
this->device = device; this->device = device;
checkCudaErrors(cudaSetDevice(device)); checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[num_expert]; streams = new cudaStream_t[num_expert];
checkCudaErrors(cublasCreate(&handle)); handles = new cublasHandle_t[num_expert];
for (size_t i=0; i<num_expert; ++i) { for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamCreate(streams+i)); checkCudaErrors(cudaStreamCreate(streams+i));
} checkCudaErrors(cublasCreate(handles + i));
cublasSetStream(handles[i], streams[i]);
}
} }
~CudaStreamManager() { ~CudaStreamManager() {
...@@ -38,14 +46,12 @@ public: ...@@ -38,14 +46,12 @@ public:
#endif #endif
for (size_t i=0; i<num_expert; ++i) { for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamDestroy(*(streams+i))); checkCudaErrors(cudaStreamDestroy(*(streams+i)));
} checkCudaErrors(cublasDestroy(handles[i]));
checkCudaErrors(cublasDestroy(handle)); }
delete[] streams; delete[] streams;
} }
size_t num_expert;
int device; void sync(int=-1);
cublasHandle_t handle;
cudaStream_t* streams;
}; };
// CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device); // CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
......
...@@ -147,7 +147,7 @@ void moe_cuda_forward_impl( ...@@ -147,7 +147,7 @@ void moe_cuda_forward_impl(
in_feat); in_feat);
#endif #endif
// Use T(B) x T(A) = T(C) to produce row-major C // Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(smgr.handle, // h->getHandle(i), checkCudaErrors(cublasXgemm(smgr.handles[0], // h->getHandle(i),
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
out_feat, expert_count[i], in_feat, out_feat, expert_count[i], in_feat,
...@@ -204,7 +204,7 @@ void moe_cuda_grad_weight( ...@@ -204,7 +204,7 @@ void moe_cuda_grad_weight(
checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost)); checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
for (size_t i=0; i<batch_size; ++i) { for (size_t i=0; i<batch_size; ++i) {
// checkCudaErrors(cublasSetStream); // checkCudaErrors(cublasSetStream);
checkCudaErrors(cublasXgemm(smgr.handle, checkCudaErrors(cublasXgemm(smgr.handles[0],
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
out_feat, out_feat,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment