Commit 3809fcb4 authored by Paul's avatar Paul
Browse files

Reduce block size for reductions

parent 1e2ef8fa
...@@ -50,6 +50,14 @@ struct hip_array ...@@ -50,6 +50,14 @@ struct hip_array
result[i] = x[i] * y[i]; result[i] = x[i] * y[i];
return result; return result;
} }
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator+(const hip_array& x, const hip_array& y)
{
hip_array result;
for(std::size_t i = 0; i < N; i++)
result[i] = x[i] + y[i];
return result;
}
}; };
} // namespace device } // namespace device
......
...@@ -14,6 +14,36 @@ struct index ...@@ -14,6 +14,36 @@ struct index
std::size_t global; std::size_t global;
std::size_t local; std::size_t local;
std::size_t group; std::size_t group;
__device__ std::size_t nglobal() const
{
return blockDim.x * gridDim.x;
}
__device__ std::size_t nlocal() const
{
return blockDim.x;
}
template<class F>
__device__ void global_stride(std::size_t n, F f) const
{
const auto stride = nglobal();
for(std::size_t i = global; i < n; i += stride)
{
f(i);
}
}
template<class F>
__device__ void local_stride(std::size_t n, F f) const
{
const auto stride = nlocal();
for(std::size_t i = local; i < n; i += stride)
{
f(i);
}
}
}; };
template <class F> template <class F>
...@@ -54,10 +84,9 @@ inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 102 ...@@ -54,10 +84,9 @@ inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 102
return [=](auto f) { return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) { launch(stream, nglobal, local)([=](auto idx) {
for(size_t i = idx.global; i < n; i += nglobal) idx.global_stride(n, [&](auto i) {
{
gs_invoke(f, i, idx); gs_invoke(f, i, idx);
} });
}); });
}; };
} }
......
...@@ -22,17 +22,16 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) ...@@ -22,17 +22,16 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
using type = decltype(f(idx.local)); using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N]; MIGRAPHX_DEVICE_SHARED type buffer[N];
type x = init; type x = init;
for(size_t i = idx.local; i < n; i += N) idx.local_stride(n, [&](auto i) {
{ x = op(x, f(i));
x = op(x, f(i)); });
}
buffer[idx.local] = x; buffer[idx.local] = x;
__syncthreads(); __syncthreads();
for(std::size_t s = 1; s < N; s *= 2) for(std::size_t s = 1; s < idx.nlocal(); s *= 2)
{ {
const std::size_t index = 2 * s * idx.local; const std::size_t index = 2 * s * idx.local;
if(index < N) if(index < idx.nlocal())
{ {
buffer[index] = op(buffer[index], buffer[index + s]); buffer[index] = op(buffer[index], buffer[index + s]);
} }
...@@ -41,6 +40,14 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) ...@@ -41,6 +40,14 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
return buffer[0]; return buffer[0];
} }
constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
size_t block_size = 1;
while(block_size < max_block_size and block_size < n)
block_size *= 2;
return block_size;
}
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg) void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto&& output_shape = result.get_shape(); auto&& output_shape = result.get_shape();
...@@ -61,11 +68,12 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg) ...@@ -61,11 +68,12 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
auto nelements = result.get_shape().elements(); auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements(); auto relements = reduce_slice.elements();
const std::size_t block_size = 1024; const std::size_t max_block_size = 1024;
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
auto base_idx = output.get_shape().multi(i / block_size); auto base_idx = output.get_shape().multi(i / block_size);
auto offset = input.get_shape().index(base_idx); auto offset = input.get_shape().index(base_idx);
auto r = block_reduce<block_size>(idx, sum{}, 0, relements, [&](auto j) __device__ { auto r = block_reduce<max_block_size>(idx, sum{}, 0, relements, [&](auto j) __device__ {
return input.data()[reduce_shape.index(j) + offset]; return input.data()[reduce_shape.index(j) + offset];
}); });
if(idx.local == 0) if(idx.local == 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment