Unverified Commit ee257d99 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Use double buffer for block_scan (#2436)

parent 19bd9c49
...@@ -43,24 +43,32 @@ template <index_int N, ...@@ -43,24 +43,32 @@ template <index_int N,
__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output) __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
{ {
using type = decltype(input(deduce_for_stride(fs))); using type = decltype(input(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N]; MIGRAPHX_DEVICE_SHARED type buffer[2][N];
type x = init; type x = init;
fs([&](auto i) { fs([&](auto i) {
index_int iout = 0;
index_int iin = 1;
if(idx.local == 0) if(idx.local == 0)
buffer[idx.local] = op(input(i), x); buffer[iout][idx.local] = op(input(i), x);
else else
buffer[idx.local] = input(i); buffer[iout][idx.local] = input(i);
__syncthreads(); __syncthreads();
for(index_int s = 1; s < idx.nlocal(); s *= 2) for(index_int s = 1; s < idx.nlocal(); s *= 2)
{ {
if(idx.local + s < idx.nlocal()) iout = 1 - iout;
iin = 1 - iin;
if(idx.local >= s)
{ {
buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]); buffer[iout][idx.local] = op(buffer[iin][idx.local], buffer[iin][idx.local - s]);
}
else
{
buffer[iout][idx.local] = buffer[iin][idx.local];
} }
__syncthreads(); __syncthreads();
} }
x = buffer[idx.nlocal() - 1]; x = buffer[iout][idx.nlocal() - 1];
output(i, buffer[idx.local]); output(i, buffer[iout][idx.local]);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment