Commit 1e2ef8fa authored by Paul
Browse files

Formatting

parent 97056d33
......@@ -35,13 +35,13 @@ inline auto launch(hipStream_t stream, std::size_t global, std::size_t local)
};
}
// Calls f(i, idx) and returns its result. The trailing `decltype(f(i, idx))`
// return type removes this overload from consideration when f cannot be
// called with both the element index and the `index` object.
// NOTE(review): the two consecutive template headers below look like the
// before/after sides of a whitespace-only formatting change shown in a diff
// view; confirm that only one of them exists in the real file.
template<class F>
template <class F>
__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index idx) -> decltype(f(i, idx))
{
return f(i, idx);
}
template<class F>
template <class F>
__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index) -> decltype(f(i))
{
return f(i);
......
......@@ -9,14 +9,14 @@ namespace device {
// Binary function object that adds two values of the same type T.
// It is passed as the reduction operator (`sum{}`) to block_reduce further down.
// NOTE(review): the doubled template header inside the struct appears to be the
// old and new sides of a formatting-only diff line; confirm only one is present
// in the actual source.
struct sum
{
template<class T>
template <class T>
// Returns x + y; both operands and the result share the type T.
MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y) const
{
return x + y;
}
};
template<std::size_t N, class Op, class T, class F>
template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{
using type = decltype(f(idx.local));
......@@ -32,7 +32,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
for(std::size_t s = 1; s < N; s *= 2)
{
const std::size_t index = 2 * s * idx.local;
if (index < N)
if(index < N)
{
buffer[index] = op(buffer[index], buffer[index + s]);
}
......@@ -46,8 +46,12 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
std::vector<std::size_t> reduce_lens;
std::transform(output_shape.lens().begin(), output_shape.lens().end(), input_shape.lens().begin(), std::back_inserter(reduce_lens), [](auto x, auto y) -> std::size_t {
if (x == y)
std::transform(output_shape.lens().begin(),
output_shape.lens().end(),
input_shape.lens().begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
......@@ -58,14 +62,14 @@ void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
auto relements = reduce_slice.elements();
const std::size_t block_size = 1024;
gs_launch(stream, nelements*block_size, block_size)([=](auto i, auto idx) __device__ {
auto base_idx = output.get_shape().multi(i/block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
auto base_idx = output.get_shape().multi(i / block_size);
auto offset = input.get_shape().index(base_idx);
auto r = block_reduce<block_size>(idx, sum{}, 0, relements, [&](auto j) __device__ {
return input.data()[reduce_shape.index(j) + offset];
});
if (idx.local == 0)
output.data()[i/block_size] = r;
if(idx.local == 0)
output.data()[i / block_size] = r;
});
});
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment