Unverified Commit ee257d99 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Use double buffer for block_scan (#2436)

parent 19bd9c49
......@@ -43,24 +43,32 @@ template <index_int N,
__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
{
using type = decltype(input(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N];
MIGRAPHX_DEVICE_SHARED type buffer[2][N];
type x = init;
fs([&](auto i) {
index_int iout = 0;
index_int iin = 1;
if(idx.local == 0)
buffer[idx.local] = op(input(i), x);
buffer[iout][idx.local] = op(input(i), x);
else
buffer[idx.local] = input(i);
buffer[iout][idx.local] = input(i);
__syncthreads();
for(index_int s = 1; s < idx.nlocal(); s *= 2)
{
if(idx.local + s < idx.nlocal())
iout = 1 - iout;
iin = 1 - iin;
if(idx.local >= s)
{
buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]);
buffer[iout][idx.local] = op(buffer[iin][idx.local], buffer[iin][idx.local - s]);
}
else
{
buffer[iout][idx.local] = buffer[iin][idx.local];
}
__syncthreads();
}
x = buffer[idx.nlocal() - 1];
output(i, buffer[idx.local]);
x = buffer[iout][idx.nlocal() - 1];
output(i, buffer[iout][idx.local]);
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment