Commit 3809fcb4 authored by Paul's avatar Paul
Browse files

Reduce block size for reductions

parent 1e2ef8fa
......@@ -50,6 +50,14 @@ struct hip_array
result[i] = x[i] * y[i];
return result;
}
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator+(const hip_array& x, const hip_array& y)
{
hip_array result;
for(std::size_t i = 0; i < N; i++)
result[i] = x[i] + y[i];
return result;
}
};
} // namespace device
......
......@@ -14,6 +14,36 @@ struct index
std::size_t global;
std::size_t local;
std::size_t group;
__device__ std::size_t nglobal() const
{
return blockDim.x * gridDim.x;
}
__device__ std::size_t nlocal() const
{
return blockDim.x;
}
template<class F>
__device__ void global_stride(std::size_t n, F f) const
{
const auto stride = nglobal();
for(std::size_t i = global; i < n; i += stride)
{
f(i);
}
}
template<class F>
__device__ void local_stride(std::size_t n, F f) const
{
const auto stride = nlocal();
for(std::size_t i = local; i < n; i += stride)
{
f(i);
}
}
};
template <class F>
......@@ -54,10 +84,9 @@ inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 102
return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) {
for(size_t i = idx.global; i < n; i += nglobal)
{
idx.global_stride(n, [&](auto i) {
gs_invoke(f, i, idx);
}
});
});
};
}
......
......@@ -22,17 +22,16 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N];
type x = init;
for(size_t i = idx.local; i < n; i += N)
{
x = op(x, f(i));
}
idx.local_stride(n, [&](auto i) {
x = op(x, f(i));
});
buffer[idx.local] = x;
__syncthreads();
for(std::size_t s = 1; s < N; s *= 2)
for(std::size_t s = 1; s < idx.nlocal(); s *= 2)
{
const std::size_t index = 2 * s * idx.local;
if(index < N)
if(index < idx.nlocal())
{
buffer[index] = op(buffer[index], buffer[index + s]);
}
......@@ -41,6 +40,14 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
return buffer[0];
}
constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
size_t block_size = 1;
while(block_size < max_block_size and block_size < n)
block_size *= 2;
return block_size;
}
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{
auto&& output_shape = result.get_shape();
......@@ -61,11 +68,12 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements();
const std::size_t block_size = 1024;
const std::size_t max_block_size = 1024;
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
auto base_idx = output.get_shape().multi(i / block_size);
auto offset = input.get_shape().index(base_idx);
auto r = block_reduce<block_size>(idx, sum{}, 0, relements, [&](auto j) __device__ {
auto r = block_reduce<max_block_size>(idx, sum{}, 0, relements, [&](auto j) __device__ {
return input.data()[reduce_shape.index(j) + offset];
});
if(idx.local == 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment