Commit 3d1987d1 authored by TiagoMAntunes's avatar TiagoMAntunes
Browse files

New compute kernel for column reduction

parent f957c299
......@@ -46,29 +46,35 @@ __global__
void column_reduce(const scalar_t * matrix, scalar_t * result,
int m /* lines */, int n /* columns*/) {
extern __shared__ float sdata[];
unsigned int tid = threadIdx.x; // line
unsigned int i = blockIdx.x + threadIdx.x * n; // get to idx th line
unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x; // line
unsigned int i = threadIdx.x * n + threadIdx.y + blockIdx.y * blockDim.y; // get to idx th line
unsigned int offset = 0;
unsigned int it = n * blockDim.x; // advanced blockDim.x threads vertically
unsigned int it = n * blockDim.x; // advance blockDim.x threads vertically
unsigned int real_y = blockIdx.y * blockDim.y + threadIdx.y;
// sum all the values from that column to fit in one single block
sdata[tid] = 0;
while (i + offset < n*m) {
sdata[tid] += matrix[i + offset];
offset += it;
}
if (real_y < n && threadIdx.x < m) // remember we only have one x block
while (i + offset < n*m) {
sdata[tid] += matrix[i + offset];
offset += it;
}
__syncthreads();
for (unsigned int s = 1; tid + s < blockDim.x; s *= 2) {
if (tid % (2*s) == 0) {
sdata[tid] += sdata[tid + s];
unsigned int lowest = blockDim.x > m ? m : blockDim.x;
if (real_y < n && threadIdx.x < m)
for (unsigned int s = 1; threadIdx.x + s < lowest; s *= 2) {
if (threadIdx.x % (2*s) == 0) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
__syncthreads();
if (threadIdx.x == 0 && real_y < n) {
result[real_y] = sdata[tid];
}
if (tid == 0) {result[blockIdx.x] = sdata[0];}
}
void moe_cuda_expert_count_impl(
......@@ -198,6 +204,11 @@ void moe_cuda_backward_impl(
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = 0;
// bias
dim3 block_threads(32, 32);
dim3 grid_threads(1, out_feat / 32 + (out_feat % 32 ? 1 : 0));
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
cudaMemset(grad_weight + i * in_feat * out_feat, 0,
......@@ -235,7 +246,7 @@ void moe_cuda_backward_impl(
if (has_bias) {
column_reduce
<<<out_feat, 1024, sizeof(scalar_t)*1024, smgr->stream(0)>>>
<<<grid_threads, block_threads, sizeof(scalar_t)*1024, smgr->stream(0)>>>
(
grad_output_buf + ptr * out_feat,
grad_bias + i * out_feat,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment