Commit d690c7b2 authored by Rick Ho

manual batched scatter and gather kernels

parent 254ad118
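The diff below swaps the per-row cudaMemcpyAsync scatter/gather loops in moe_cuda_forward_impl for single batched kernels: one thread block copies one row, and a precomputed pos array tells each block which destination row it owns. As a point of reference, here is a minimal standalone sketch of that copy pattern; it is illustrative only, and the names scatter_rows, rows, and width are placeholders rather than code taken from the commit.

#include <cuda_runtime.h>
#include <cstdio>

// One block per row: block b copies row b of `src` into row pos[b] of `dst`.
// Threads in the block stride across the `width` elements of the row.
template <typename scalar_t>
__global__ void scatter_rows(int width, const int* pos,
                             const scalar_t* src, scalar_t* dst) {
    const scalar_t* in  = src + (size_t)width * blockIdx.x;
    scalar_t*       out = dst + (size_t)width * pos[blockIdx.x];
    for (int i = threadIdx.x; i < width; i += blockDim.x) {
        out[i] = in[i];
    }
}

int main() {
    const int rows = 4, width = 8;
    // Example permutation: row i of src lands in row pos[i] of dst.
    int h_pos[rows] = {2, 0, 3, 1};
    float h_src[rows * width], h_dst[rows * width];
    for (int i = 0; i < rows * width; ++i) h_src[i] = (float)i;

    int* d_pos; float *d_src, *d_dst;
    cudaMalloc(&d_pos, sizeof(h_pos));
    cudaMalloc(&d_src, sizeof(h_src));
    cudaMalloc(&d_dst, sizeof(h_dst));
    cudaMemcpy(d_pos, h_pos, sizeof(h_pos), cudaMemcpyHostToDevice);
    cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);

    scatter_rows<float><<<rows, 256>>>(width, d_pos, d_src, d_dst);
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);

    // Row 0 of src (values 0..7) should now sit at row h_pos[0] = 2 of dst.
    printf("dst row %d starts with %.1f\n", h_pos[0], h_dst[h_pos[0] * width]);

    cudaFree(d_pos); cudaFree(d_src); cudaFree(d_dst);
    return 0;
}

A single launch like this stands in for batch_size separate cudaMemcpyAsync calls, which is presumably the point of the change: the per-copy API overhead disappears and all rows are moved by one grid.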
@@ -11,7 +11,11 @@ CudaStreamManager* getCudaStreamManager(const size_t num_expert) {
     return smgr;
 }
-void CudaStreamManager::sync() {
+void CudaStreamManager::sync(int i) {
+    if (i > -1) {
+        cudaStreamSynchronize(streams[i]);
+        return;
+    }
     for (size_t i=0; i<MAX_STREAMS; ++i) {
         cudaStreamSynchronize(streams[i]);
     }
@@ -38,7 +38,7 @@ struct CudaStreamManager {
         return handles[idx % MAX_STREAMS];
     }
-    void sync();
+    void sync(int=-1);
 };
 CudaStreamManager* getCudaStreamManager(const size_t num_expert);
@@ -18,7 +18,7 @@
 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
 // #define MOE_BREAKDOWN
-#define MOE_DEBUG
+// #define MOE_DEBUG
 template <typename scalar_t>
 __global__
@@ -31,6 +31,29 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
 }
+template <typename scalar_t>
+__global__
+void batch_scatter_kernel(int wid, int* pos,
+        const scalar_t* inbuf, scalar_t* oubuf) {
+    inbuf += wid * blockIdx.x;
+    oubuf += wid * pos[blockIdx.x];
+    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
+        oubuf[i] = inbuf[i];
+    }
+}
+
+template <typename scalar_t>
+__global__
+void batch_gather_kernel(int wid, int* pos,
+        const scalar_t* inbuf, scalar_t* oubuf) {
+    inbuf += wid * pos[blockIdx.x];
+    oubuf += wid * blockIdx.x;
+    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
+        oubuf[i] = inbuf[i];
+    }
+}
+
 template <typename scalar_t>
 void moe_cuda_forward_impl(
         const scalar_t* input,
@@ -86,22 +109,26 @@ void moe_cuda_forward_impl(
         expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
     }
+    int *pos = new int[batch_size];
+    int *d_pos;
+    checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));
+    for (int i = 0; i < batch_size; ++i) {
+        pos[i] = expert_ptr[gate[i]]++;
+    }
+    checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
+                cudaMemcpyHostToDevice));
 #ifdef MOE_BREAKDOWN
     timestamp(t_expert);
     fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
             1e6);
 #endif
-    for (int i = 0; i < batch_size; ++i) {
-        int target_idx = expert_ptr[gate[i]]++;
-#ifdef MOE_DEBUG_SCATTER
-        fprintf(stderr, "aln idx %d gate %d tgt %d\n", i, gate[i], target_idx);
-#endif
-        checkCudaErrors(cudaMemcpyAsync(input_buf + target_idx * in_feat,
-                    input + i * in_feat, sizeof(scalar_t) * in_feat,
-                    cudaMemcpyDeviceToDevice,
-                    h->getStream(gate[i])));
-    }
+    batch_scatter_kernel<scalar_t>
+        <<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
+            input_buf);
+    h->sync(0);
 #ifdef MOE_BREAKDOWN
     h->sync();
@@ -148,25 +175,16 @@ void moe_cuda_forward_impl(
     }
 #ifdef MOE_BREAKDOWN
+    h->sync();
     timestamp(t_mm);
     fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
             1e6);
 #endif
-    for (int i = batch_size - 1; i >= 0; --i) {
-        int target_idx = --expert_ptr[gate[i]];
-#ifdef MOE_DEBUG_SCATTER
-        fprintf(stderr, "cb idx %d gate %d tgt %d\n", i, gate[i], target_idx);
-#endif
-        checkCudaErrors(cudaMemcpyAsync(output + i * out_feat,
-                    output_buf + target_idx * out_feat,
-                    sizeof(scalar_t) * out_feat,
-                    cudaMemcpyDeviceToDevice,
-                    h->getStream(gate[i])));
-    }
     h->sync();
+    batch_gather_kernel<scalar_t>
+        <<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, output_buf,
+            output);
+    h->sync(0);
 #ifdef MOE_BREAKDOWN
     timestamp(t_gather);
@@ -177,7 +195,11 @@ void moe_cuda_forward_impl(
 #endif
     cudaFree(input_buf);
+    cudaFree(hidden_buf);
     cudaFree(output_buf);
+    cudaFree(d_pos);
+    delete [] pos;
+    delete [] gate;
 }
 template <typename scalar_t>
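The host-side index construction added above (pos[i] = expert_ptr[gate[i]]++) is a counting-sort style placement: expert_ptr starts as the exclusive prefix sum of expert_count, so tokens routed to the same expert receive consecutive destination rows, and batch_gather_kernel later reads each token back from the same row. A small self-contained illustration of that placement rule, with hypothetical gate values and expert counts:

#include <cstdio>
#include <vector>

int main() {
    // Hypothetical routing: 6 tokens, 3 experts.
    std::vector<int> gate = {1, 0, 2, 1, 0, 1};
    std::vector<int> expert_count = {2, 3, 1};   // tokens per expert

    // Exclusive prefix sum: expert_ptr[e] = first output row owned by expert e.
    std::vector<int> expert_ptr = {0, 2, 5};

    // Same placement rule as the commit: each token takes the next free row
    // in its expert's contiguous block.
    std::vector<int> pos(gate.size());
    for (size_t i = 0; i < gate.size(); ++i) {
        pos[i] = expert_ptr[gate[i]]++;
    }

    // pos = {2, 0, 5, 3, 1, 4}: expert 0 owns rows 0-1, expert 1 rows 2-4,
    // expert 2 row 5. batch_scatter_kernel copies token i to row pos[i];
    // batch_gather_kernel reads it back from the same row.
    for (size_t i = 0; i < pos.size(); ++i) {
        printf("token %zu -> expert %d -> row %d\n", i, gate[i], pos[i]);
    }
    return 0;
}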