Commit bf60a846 authored by Rick Ho
Browse files

limit max streams

parent ab153b37
......@@ -6,15 +6,18 @@
#include <helper_cuda.h>
#define MAX_STREAMS 16
struct CudaStreamManager {
const size_t num_expert;
cublasHandle_t* handles;
cudaStream_t* streams;
CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) {
streams = new cudaStream_t[num_expert];
handles = new cublasHandle_t[num_expert];
for (size_t i=0; i<num_expert; ++i) {
streams = new cudaStream_t[MAX_STREAMS];
handles = new cublasHandle_t[MAX_STREAMS];
for (size_t i=0; i<MAX_STREAMS; ++i) {
checkCudaErrors(cublasCreate(handles + i));
checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasSetStream(handles[i], streams[i]));
......@@ -22,11 +25,20 @@ struct CudaStreamManager {
}
~CudaStreamManager() {
for (size_t i=0; i<num_expert; ++i) {
for (size_t i=0; i<MAX_STREAMS; ++i) {
checkCudaErrors(cudaStreamDestroy(streams[i]));
checkCudaErrors(cublasDestroy(handles[i]));
}
}
inline cudaStream_t& getStream(int idx) {
return streams[idx % MAX_STREAMS];
}
inline cublasHandle_t& getHandle(int idx) {
return handles[idx % MAX_STREAMS];
}
void sync();
};
CudaStreamManager* getCudaStreamManager(const size_t num_expert);
......
......@@ -70,7 +70,7 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMemcpyAsync(input_buf + target_idx * in_feat,
input + i * in_feat, sizeof(scalar_t) * in_feat,
cudaMemcpyDeviceToDevice,
h->streams[gate[i]]));
h->getStream(gate[i])));
}
scalar_t alpha = 1, beta = 0;
......@@ -85,7 +85,7 @@ void moe_cuda_forward_impl(
in_feat);
#endif
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(h->handles[i],
checkCudaErrors(cublasXgemm(h->getHandle(i),
(transb == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count[i], in_feat,
......@@ -108,12 +108,11 @@ void moe_cuda_forward_impl(
output_buf + target_idx * out_feat,
sizeof(scalar_t) * out_feat,
cudaMemcpyDeviceToDevice,
h->streams[gate[i]]));
h->getStream(gate[i])));
}
for (int i = 0; i < num_expert; ++i) {
cudaStreamSynchronize(h->streams[i]);
}
h->sync();
cudaFree(input_buf);
cudaFree(output_buf);
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment