Commit 01d888bf authored by TiagoMAntunes

Added CUDA kernel for column reduction

parent c5f73a0f
@@ -37,6 +37,40 @@ void batch_scatter_kernel(size_t wid, const long* pos,
}
}
/*
    Column-wise reduction of a row-major m x n matrix.
    Launch with one block per column; the block's threads cooperatively
    sum that column, striding blockDim.x rows at a time.
*/
template <typename scalar_t>
__global__
void column_reduce(const scalar_t* matrix, scalar_t* result,
                   int m /* rows */, int n /* columns */) {
    // Dynamic shared memory is declared untyped so the same extern array
    // is valid for every scalar_t instantiation (float vs. double).
    extern __shared__ unsigned char sdata_raw[];
    scalar_t* sdata = reinterpret_cast<scalar_t*>(sdata_raw);

    unsigned int tid = threadIdx.x;                 // first row this thread reads
    unsigned int i = blockIdx.x + threadIdx.x * n;  // element (tid, blockIdx.x)
    unsigned int offset = 0;
    unsigned int it = n * blockDim.x;               // stride: blockDim.x rows down

    // Fold the whole column into blockDim.x partial sums.
    sdata[tid] = 0;
    while (i + offset < n * m) {
        sdata[tid] += matrix[i + offset];
        offset += it;
    }
    __syncthreads();

    // Tree reduction in shared memory. The loop bound must be uniform
    // across the block (s < blockDim.x, not tid + s < blockDim.x);
    // otherwise some threads exit early and miss __syncthreads(),
    // which is undefined behavior.
    for (unsigned int s = 1; s < blockDim.x; s *= 2) {
        if (tid % (2 * s) == 0 && tid + s < blockDim.x) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        result[blockIdx.x] = sdata[0];
    }
}
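
/*
    Illustrative usage sketch (an editorial addition, not part of the
    original commit). The host-side names d_matrix, d_result, m, n and
    threads are hypothetical; the launch shape mirrors the call made in
    moe_cuda_backward_impl further down.

        const int m = 4096, n = 512;   // rows, columns (row-major)
        const int threads = 1024;      // threads cooperating on one column
        float *d_matrix, *d_result;
        cudaMalloc(&d_matrix, sizeof(float) * m * n);
        cudaMalloc(&d_result, sizeof(float) * n);
        // ... fill d_matrix ...
        column_reduce<float>
            <<<n, threads, sizeof(float) * threads>>>
            (d_matrix, d_result, m, n);
        // afterwards d_result[j] holds the sum of column j
*/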
void moe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
@@ -168,6 +202,7 @@ void moe_cuda_backward_impl(
        if (expert_count[i] == 0) {
            cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                       sizeof(scalar_t) * in_feat * out_feat);
            cudaMemset(grad_bias + i * out_feat, 0,
                       sizeof(scalar_t) * out_feat);
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C
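        /*
            Why T(B) x T(A) = T(C) works (illustrative note, not part of
            the original commit): cuBLAS assumes column-major storage, so
            a row-major buffer passed to it is implicitly its own
            transpose. To get row-major C (m x n) = A (m x k) * B (k x n),
            compute column-major C^T (n x m) = B^T (n x k) * A^T (k x m)
            by passing the row-major buffers unchanged with the operand
            order swapped, e.g. for float (handle, alpha, beta are
            assumed to exist in scope):

                cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                            n, m, k,
                            &alpha,
                            B, n,   // row-major B, read as B^T
                            A, k,   // row-major A, read as A^T
                            &beta,
                            C, n);  // written as C^T, i.e. row-major C
        */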
@@ -199,7 +234,14 @@ void moe_cuda_backward_impl(
));
        if (has_bias) {
            // Accumulate the bias gradient for expert i: column_reduce
            // sums the expert_count[i] rows of this expert's grad_output
            // slice, one block per output feature.
            column_reduce
                <<<out_feat, 1024, sizeof(scalar_t) * 1024, smgr->stream(0)>>>
                (grad_output_buf + ptr * out_feat,
                 grad_bias + i * out_feat,
                 expert_count[i],
                 out_feat);
        }
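        // Note (added for clarity): this expert's slice of grad_output_buf
        // is an expert_count[i] x out_feat row-major matrix, so
        // grad_bias[i * out_feat + j] receives the sum of column j. The
        // striding loop in column_reduce handles expert_count[i] both
        // smaller and larger than the 1024 threads per block.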
ptr += expert_count[i];