"vscode:/vscode.git/clone" did not exist on "6ba2231d7227825dfb76e9df161824d04234e69f"
Commit b0704f1d authored by TiagoMAntunes's avatar TiagoMAntunes
Browse files

New data placement structure, speedup

parent f165928a
......@@ -47,45 +47,48 @@ void column_reduce(const scalar_t * matrix, scalar_t * result,
int m /* lines */, int n /* columns*/) {
// https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
extern __shared__ __align__(sizeof(scalar_t)) unsigned char my_smem[];
extern __shared__ unsigned char my_smem[];
scalar_t *sdata = reinterpret_cast<scalar_t *>(my_smem);
unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x;
unsigned int real_x = threadIdx.x + blockDim.x * blockIdx.x;
unsigned int real_y = n * threadIdx.y;
// normal tid
int tid = threadIdx.x + threadIdx.y * blockDim.x;
unsigned int i = real_x + real_y;
unsigned int it = n*blockDim.y;
unsigned int offset = it;
// transposed tid for shared memory
int new_tid = threadIdx.y + threadIdx.x * blockDim.y;
// true x value in the matrix
int real_x = threadIdx.x + blockDim.x * blockIdx.x;
int i = real_x + n * threadIdx.y;
const int it = n*blockDim.y;
int offset = it;
float accumulator = 0;
sdata[tid] = 0;
if (threadIdx.y < m && real_x < n) {
// can load memory
// printf("tid=%d loading %d\n", tid, i);
sdata[tid] = matrix[i];
// store all the values from this column in a warped way
accumulator = matrix[i];
while (i + offset < n*m) {
// printf("tid=%d loading %d\n", tid, i+offset);
sdata[tid] += matrix[i + offset];
accumulator += matrix[i + offset];
offset += it;
}
}
// save column reduction data in a transposed way
sdata[new_tid] = accumulator;
__syncthreads();
for (unsigned int s = blockDim.y / 2; s > 0; s>>=1) {
if (threadIdx.y < s) {
// printf("tid=%d adding %d\n", tid, tid + blockDim.x *s);
sdata[tid] += sdata[tid + blockDim.x * s];
}
for (size_t t= 16; t > 0; t>>=1) {
if (tid < 32 * 32 - 16)
sdata[tid] += sdata[tid + t];
__syncthreads();
}
if (threadIdx.y == 0 && real_x < n) {
result[real_x] = sdata[tid];
}
if (threadIdx.y == 0 && real_x < n)
result[real_x] = sdata[new_tid];
}
void moe_cuda_expert_count_impl(
const int* d_gate,
int* expert_count,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment