Commit f165928a authored by TiagoMAntunes

New simplified column reduction kernel

parent 69151519
@@ -50,36 +50,40 @@ void column_reduce(const scalar_t * matrix, scalar_t * result,
     extern __shared__ __align__(sizeof(scalar_t)) unsigned char my_smem[];
     scalar_t *sdata = reinterpret_cast<scalar_t *>(my_smem);
 
-    unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x; // line
-    unsigned int i = threadIdx.x * n + threadIdx.y + blockIdx.y * blockDim.y; // get to idx th line
-    unsigned int offset = 0;
-    unsigned int it = n * blockDim.x; // advance blockDim.x threads vertically
-    unsigned int real_y = blockIdx.y * blockDim.y + threadIdx.y;
-    // sum all the values from that column to fit in one single block
+    unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x;
+    unsigned int real_x = threadIdx.x + blockDim.x * blockIdx.x;
+    unsigned int real_y = n * threadIdx.y;
+    unsigned int i = real_x + real_y;
+    unsigned int it = n*blockDim.y;
+    unsigned int offset = it;
 
     sdata[tid] = 0;
-    if (real_y < n && threadIdx.x < m) // remember we only have one x block
+    if (threadIdx.y < m && real_x < n) {
+        // can load memory
+        // printf("tid=%d loading %d\n", tid, i);
+        sdata[tid] = matrix[i];
         while (i + offset < n*m) {
+            // printf("tid=%d loading %d\n", tid, i+offset);
             sdata[tid] += matrix[i + offset];
             offset += it;
         }
-    __syncthreads();
-    unsigned int lowest = blockDim.x > m ? m : blockDim.x;
-    if (real_y < n && threadIdx.x < m)
-        for (unsigned int s = 1; threadIdx.x + s < lowest; s *= 2) {
-            if (threadIdx.x % (2*s) == 0) {
-                sdata[tid] += sdata[tid + s];
-            }
-            __syncthreads();
-        }
-    if (threadIdx.x == 0 && real_y < n) {
-        result[real_y] = sdata[tid];
+    }
+
+    __syncthreads();
+
+    for (unsigned int s = blockDim.y / 2; s > 0; s>>=1) {
+        if (threadIdx.y < s) {
+            // printf("tid=%d adding %d\n", tid, tid + blockDim.x *s);
+            sdata[tid] += sdata[tid + blockDim.x * s];
+        }
+        __syncthreads();
+    }
+
+    if (threadIdx.y == 0 && real_x < n) {
+        result[real_x] = sdata[tid];
     }
 }
 
 void moe_cuda_expert_count_impl(
@@ -211,7 +215,7 @@ void moe_cuda_backward_impl(
     // bias
     dim3 block_threads(32, 32);
-    dim3 grid_threads(1, out_feat / 32 + (out_feat % 32 ? 1 : 0));
+    dim3 grid_threads(out_feat / 32 + (out_feat % 32 ? 1 : 0), 1);
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
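For context (not part of the commit): below is a minimal standalone sketch of the column-reduction pattern the new kernel uses. Each block covers blockDim.x columns of an m x n row-major matrix, each thread strides down its column, and a shared-memory tree reduction over threadIdx.y produces one sum per column, launched with the (ceil(n/32), 1) grid layout from the second hunk. The kernel name column_sum, the float instantiation, and the test harness are illustrative assumptions, not code from this repository.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Illustrative sketch of the simplified column reduction; not the repository's actual kernel.
// Each block handles blockDim.x columns; threads along y stride down the rows,
// then a shared-memory tree reduction over threadIdx.y (assumed a power of two)
// leaves one sum per column in the threadIdx.y == 0 row.
template <typename scalar_t>
__global__ void column_sum(const scalar_t* matrix, scalar_t* result,
                           unsigned int m /* rows */, unsigned int n /* cols */) {
    extern __shared__ __align__(sizeof(scalar_t)) unsigned char smem[];
    scalar_t* sdata = reinterpret_cast<scalar_t*>(smem);

    unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x;
    unsigned int col = threadIdx.x + blockDim.x * blockIdx.x;  // column owned by this thread
    unsigned int i   = col + n * threadIdx.y;                  // first element for this thread
    unsigned int it  = n * blockDim.y;                         // row stride per iteration

    // Strided per-thread accumulation down the column.
    scalar_t acc = 0;
    if (col < n) {
        for (unsigned int idx = i; idx < n * m; idx += it)
            acc += matrix[idx];
    }
    sdata[tid] = acc;
    __syncthreads();

    // Tree reduction over threadIdx.y.
    for (unsigned int s = blockDim.y / 2; s > 0; s >>= 1) {
        if (threadIdx.y < s)
            sdata[tid] += sdata[tid + blockDim.x * s];
        __syncthreads();
    }

    if (threadIdx.y == 0 && col < n)
        result[col] = sdata[tid];
}

int main() {
    const unsigned int m = 100, n = 70;          // rows, columns
    std::vector<float> h_mat(m * n, 1.0f);       // every column sums to m
    float *d_mat, *d_res;
    cudaMalloc(&d_mat, m * n * sizeof(float));
    cudaMalloc(&d_res, n * sizeof(float));
    cudaMemcpy(d_mat, h_mat.data(), m * n * sizeof(float), cudaMemcpyHostToDevice);

    dim3 block(32, 32);
    dim3 grid(n / 32 + (n % 32 ? 1 : 0), 1);     // mirrors the new grid_threads layout
    size_t smem = block.x * block.y * sizeof(float);
    column_sum<float><<<grid, block, smem>>>(d_mat, d_res, m, n);

    std::vector<float> h_res(n);
    cudaMemcpy(h_res.data(), d_res, n * sizeof(float), cudaMemcpyDeviceToHost);
    printf("result[0] = %f (expected %u)\n", h_res[0], m);
    cudaFree(d_mat);
    cudaFree(d_res);
    return 0;
}

Compiling with nvcc and running should print the row count for every column, which is a quick way to sanity-check the one-block-per-column-group layout before wiring it into the bias-gradient path touched by the second hunk.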