Commit d690c7b2 authored by Rick Ho
Browse files

manual batched scatter and gather kernels

parent 254ad118
......@@ -11,7 +11,11 @@ CudaStreamManager* getCudaStreamManager(const size_t num_expert) {
return smgr;
}
void CudaStreamManager::sync() {
void CudaStreamManager::sync(int i) {
if (i > -1) {
cudaStreamSynchronize(streams[i]);
return;
}
for (size_t i=0; i<MAX_STREAMS; ++i) {
cudaStreamSynchronize(streams[i]);
}
......
......@@ -38,7 +38,7 @@ struct CudaStreamManager {
return handles[idx % MAX_STREAMS];
}
void sync();
void sync(int=-1);
};
CudaStreamManager* getCudaStreamManager(const size_t num_expert);
......
......@@ -18,7 +18,7 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
// #define MOE_BREAKDOWN
#define MOE_DEBUG
// #define MOE_DEBUG
template <typename scalar_t>
__global__
......@@ -31,6 +31,29 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
}
/**
 * Row-wise scatter: copies row `blockIdx.x` of `inbuf` into row
 * `pos[blockIdx.x]` of `oubuf`. Every row holds `wid` elements.
 *
 * Launch layout: blockIdx.x selects the row handled by a block, and the
 * threads of that block step through the row's elements with a stride of
 * blockDim.x. No shared memory is used.
 *
 * @param wid    number of elements per row
 * @param pos    destination row index for each source row, read as
 *               pos[blockIdx.x]
 * @param inbuf  source rows
 * @param oubuf  destination rows
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(int wid, int* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    // Widen before multiplying: `wid * blockIdx.x` would otherwise be
    // evaluated in 32-bit unsigned arithmetic (and `wid * pos[...]` in 32-bit
    // signed arithmetic) and wrap for buffers with more than 2^31 elements.
    const long long row_len = static_cast<long long>(wid);
    const scalar_t* src = inbuf + row_len * static_cast<long long>(blockIdx.x);
    scalar_t* dst = oubuf + row_len * static_cast<long long>(pos[blockIdx.x]);
    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
        dst[i] = src[i];
    }
}
/**
 * Row-wise gather: copies row `pos[blockIdx.x]` of `inbuf` into row
 * `blockIdx.x` of `oubuf`. Every row holds `wid` elements.
 *
 * Launch layout: blockIdx.x selects the output row handled by a block, and
 * the threads of that block step through the row's elements with a stride of
 * blockDim.x. No shared memory is used.
 *
 * @param wid    number of elements per row
 * @param pos    source row index for each output row, read as
 *               pos[blockIdx.x]
 * @param inbuf  source rows
 * @param oubuf  destination rows
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(int wid, int* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    // Widen before multiplying: `wid * pos[...]` would otherwise be evaluated
    // in 32-bit signed arithmetic (and `wid * blockIdx.x` in 32-bit unsigned
    // arithmetic) and wrap for buffers with more than 2^31 elements.
    const long long row_len = static_cast<long long>(wid);
    const scalar_t* src = inbuf + row_len * static_cast<long long>(pos[blockIdx.x]);
    scalar_t* dst = oubuf + row_len * static_cast<long long>(blockIdx.x);
    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
        dst[i] = src[i];
    }
}
template <typename scalar_t>
void moe_cuda_forward_impl(
const scalar_t* input,
......@@ -86,22 +109,26 @@ void moe_cuda_forward_impl(
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
}
int *pos = new int[batch_size];
int *d_pos;
checkCudaErrors(cudaMalloc(&d_pos, sizeof(int) * batch_size));
for (int i = 0; i < batch_size; ++i) {
pos[i] = expert_ptr[gate[i]]++;
}
checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
cudaMemcpyHostToDevice));
#ifdef MOE_BREAKDOWN
timestamp(t_expert);
fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
1e6);
#endif
for (int i = 0; i < batch_size; ++i) {
int target_idx = expert_ptr[gate[i]]++;
#ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "aln idx %d gate %d tgt %d\n", i, gate[i], target_idx);
#endif
checkCudaErrors(cudaMemcpyAsync(input_buf + target_idx * in_feat,
input + i * in_feat, sizeof(scalar_t) * in_feat,
cudaMemcpyDeviceToDevice,
h->getStream(gate[i])));
}
batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(in_feat, d_pos, input,
input_buf);
h->sync(0);
#ifdef MOE_BREAKDOWN
h->sync();
......@@ -148,25 +175,16 @@ void moe_cuda_forward_impl(
}
#ifdef MOE_BREAKDOWN
h->sync();
timestamp(t_mm);
fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
1e6);
#endif
for (int i = batch_size - 1; i >= 0; --i) {
int target_idx = --expert_ptr[gate[i]];
#ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "cb idx %d gate %d tgt %d\n", i, gate[i], target_idx);
#endif
checkCudaErrors(cudaMemcpyAsync(output + i * out_feat,
output_buf + target_idx * out_feat,
sizeof(scalar_t) * out_feat,
cudaMemcpyDeviceToDevice,
h->getStream(gate[i])));
}
h->sync();
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos, output_buf,
output);
h->sync(0);
#ifdef MOE_BREAKDOWN
timestamp(t_gather);
......@@ -177,7 +195,11 @@ void moe_cuda_forward_impl(
#endif
cudaFree(input_buf);
cudaFree(hidden_buf);
cudaFree(output_buf);
cudaFree(d_pos);
delete [] pos;
delete [] gate;
}
template <typename scalar_t>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment