Commit bf60a846 authored by Rick Ho's avatar Rick Ho
Browse files

limit max streams

parent ab153b37
...@@ -6,15 +6,18 @@ ...@@ -6,15 +6,18 @@
#include <helper_cuda.h> #include <helper_cuda.h>
#define MAX_STREAMS 16
struct CudaStreamManager { struct CudaStreamManager {
const size_t num_expert; const size_t num_expert;
cublasHandle_t* handles; cublasHandle_t* handles;
cudaStream_t* streams; cudaStream_t* streams;
CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) { CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) {
streams = new cudaStream_t[num_expert]; streams = new cudaStream_t[MAX_STREAMS];
handles = new cublasHandle_t[num_expert]; handles = new cublasHandle_t[MAX_STREAMS];
for (size_t i=0; i<num_expert; ++i) { for (size_t i=0; i<MAX_STREAMS; ++i) {
checkCudaErrors(cublasCreate(handles + i)); checkCudaErrors(cublasCreate(handles + i));
checkCudaErrors(cudaStreamCreate(streams + i)); checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasSetStream(handles[i], streams[i])); checkCudaErrors(cublasSetStream(handles[i], streams[i]));
...@@ -22,11 +25,20 @@ struct CudaStreamManager { ...@@ -22,11 +25,20 @@ struct CudaStreamManager {
} }
~CudaStreamManager() { ~CudaStreamManager() {
for (size_t i=0; i<num_expert; ++i) { for (size_t i=0; i<MAX_STREAMS; ++i) {
checkCudaErrors(cudaStreamDestroy(streams[i])); checkCudaErrors(cudaStreamDestroy(streams[i]));
checkCudaErrors(cublasDestroy(handles[i])); checkCudaErrors(cublasDestroy(handles[i]));
} }
} }
inline cudaStream_t& getStream(int idx) {
return streams[idx % MAX_STREAMS];
}
inline cublasHandle_t& getHandle(int idx) {
return handles[idx % MAX_STREAMS];
}
void sync();
}; };
CudaStreamManager* getCudaStreamManager(const size_t num_expert); CudaStreamManager* getCudaStreamManager(const size_t num_expert);
......
...@@ -70,7 +70,7 @@ void moe_cuda_forward_impl( ...@@ -70,7 +70,7 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMemcpyAsync(input_buf + target_idx * in_feat, checkCudaErrors(cudaMemcpyAsync(input_buf + target_idx * in_feat,
input + i * in_feat, sizeof(scalar_t) * in_feat, input + i * in_feat, sizeof(scalar_t) * in_feat,
cudaMemcpyDeviceToDevice, cudaMemcpyDeviceToDevice,
h->streams[gate[i]])); h->getStream(gate[i])));
} }
scalar_t alpha = 1, beta = 0; scalar_t alpha = 1, beta = 0;
...@@ -85,7 +85,7 @@ void moe_cuda_forward_impl( ...@@ -85,7 +85,7 @@ void moe_cuda_forward_impl(
in_feat); in_feat);
#endif #endif
// Use T(B) x T(A) = T(C) to produce row-major C // Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(h->handles[i], checkCudaErrors(cublasXgemm(h->getHandle(i),
(transb == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T, (transb == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
out_feat, expert_count[i], in_feat, out_feat, expert_count[i], in_feat,
...@@ -108,12 +108,11 @@ void moe_cuda_forward_impl( ...@@ -108,12 +108,11 @@ void moe_cuda_forward_impl(
output_buf + target_idx * out_feat, output_buf + target_idx * out_feat,
sizeof(scalar_t) * out_feat, sizeof(scalar_t) * out_feat,
cudaMemcpyDeviceToDevice, cudaMemcpyDeviceToDevice,
h->streams[gate[i]])); h->getStream(gate[i])));
} }
for (int i = 0; i < num_expert; ++i) { h->sync();
cudaStreamSynchronize(h->streams[i]);
}
cudaFree(input_buf); cudaFree(input_buf);
cudaFree(output_buf); cudaFree(output_buf);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment