Commit 143e21cc authored by Rick Ho
Browse files

update stream manager

parent 2565f2fa
...@@ -9,7 +9,7 @@ void CudaStreamManager::sync(int i) { ...@@ -9,7 +9,7 @@ void CudaStreamManager::sync(int i) {
cudaStreamSynchronize(streams[i]); cudaStreamSynchronize(streams[i]);
return; return;
} }
for (size_t i=0; i<MAX_STREAMS; ++i) { for (size_t i = 0; i < this->num_expert; ++i) {
cudaStreamSynchronize(streams[i]); cudaStreamSynchronize(streams[i]);
} }
} }
...@@ -69,10 +69,6 @@ void moe_cuda_forward_impl( ...@@ -69,10 +69,6 @@ void moe_cuda_forward_impl(
const size_t num_expert, const size_t num_expert,
cublasOperation_t transb) { cublasOperation_t transb) {
#ifdef MOE_BREAKDOWN
timestamp(t_init);
#endif
scalar_t *input_buf, *output_buf; scalar_t *input_buf, *output_buf;
checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size * checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
...@@ -80,12 +76,6 @@ void moe_cuda_forward_impl( ...@@ -80,12 +76,6 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size * checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
out_feat)); out_feat));
#ifdef MOE_BREAKDOWN
timestamp(t_malloc);
fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_init, t_malloc) *
1e6);
#endif
int *gate = new int[batch_size]; int *gate = new int[batch_size];
int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert]; int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
memset(expert_count, 0, sizeof(int) * num_expert); memset(expert_count, 0, sizeof(int) * num_expert);
...@@ -93,12 +83,6 @@ void moe_cuda_forward_impl( ...@@ -93,12 +83,6 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size, checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
cudaMemcpyDeviceToHost)); cudaMemcpyDeviceToHost));
#ifdef MOE_BREAKDOWN
timestamp(t_cpy);
fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_malloc, t_cpy) *
1e6);
#endif
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
++expert_count[gate[i]]; ++expert_count[gate[i]];
} }
...@@ -117,23 +101,10 @@ void moe_cuda_forward_impl( ...@@ -117,23 +101,10 @@ void moe_cuda_forward_impl(
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size, checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice)); cudaMemcpyHostToDevice));
#ifdef MOE_BREAKDOWN
timestamp(t_expert);
fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
1e6);
#endif
batch_scatter_kernel<scalar_t> batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input, <<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
input_buf); input_buf);
// smgr.sync(0); smgr.sync(0);
#ifdef MOE_BREAKDOWN
// h->sync();
timestamp(t_scatter);
fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
1e6);
#endif
scalar_t alpha = 1, beta = 0; scalar_t alpha = 1, beta = 0;
...@@ -141,13 +112,8 @@ void moe_cuda_forward_impl( ...@@ -141,13 +112,8 @@ void moe_cuda_forward_impl(
if (expert_count[i] == 0) { if (expert_count[i] == 0) {
continue; continue;
} }
#ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "gemm %d sz %d\n", i, expert_count[i]);
fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_count[i],
in_feat);
#endif
// Use T(B) x T(A) = T(C) to produce row-major C // Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(smgr.handles[0], // h->getHandle(i), checkCudaErrors(cublasXgemm(smgr.handles[i],
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
out_feat, expert_count[i], in_feat, out_feat, expert_count[i], in_feat,
...@@ -161,25 +127,10 @@ void moe_cuda_forward_impl( ...@@ -161,25 +127,10 @@ void moe_cuda_forward_impl(
ptr += expert_count[i]; ptr += expert_count[i];
} }
#ifdef MOE_BREAKDOWN
timestamp(t_mm);
fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
1e6);
#endif
// h->sync();
batch_gather_kernel<scalar_t> batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf, <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
output); output);
// h->sync(0); smgr.sync(0);
#ifdef MOE_BREAKDOWN
timestamp(t_gather);
fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
1e6);
fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) *
1e6);
#endif
cudaFree(input_buf); cudaFree(input_buf);
cudaFree(output_buf); cudaFree(output_buf);
......
...@@ -6,8 +6,8 @@ import sys ...@@ -6,8 +6,8 @@ import sys
def perf(): def perf():
batch_size = int(sys.argv[1]) batch_size = int(sys.argv[1])
io_feat = int(sys.argv[2]) in_feat = int(sys.argv[2])
hidden_feat = int(sys.argv[3]) out_feat = int(sys.argv[3])
num_expert = int(sys.argv[4]) num_expert = int(sys.argv[4])
...@@ -36,7 +36,7 @@ def perf(): ...@@ -36,7 +36,7 @@ def perf():
sqtot += (te - ts)**2 sqtot += (te - ts)**2
maxt = max(maxt, te - ts) maxt = max(maxt, te - ts)
gflops = 2e-9 * n_runs * io_feat * hidden_feat * 2 * batch_size / tott gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
print('Time mean/max/stdev {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format( print('Time mean/max/stdev {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs'.format(
tott * 1e3 / n_runs, maxt * 1e3, tott * 1e3 / n_runs, maxt * 1e3,
(sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs, gflops)) (sqtot / n_runs - (tott / n_runs)**2) * 1e3 / n_runs, gflops))
......
...@@ -11,7 +11,7 @@ setup( ...@@ -11,7 +11,7 @@ setup(
name='moe_cuda', name='moe_cuda',
sources=[ sources=[
'moe.cpp', 'moe.cpp',
# 'cuda_stream_manager.cpp', 'cuda_stream_manager.cpp',
'moe_cuda_kernel.cu', 'moe_cuda_kernel.cu',
], ],
extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)], extra_compile_args={'cxx': ['-I{}'.format(CUDA_HELPER)],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment