Commit 1e2ef8fa authored by Paul's avatar Paul
Browse files

Formatting

parent 97056d33
......@@ -35,13 +35,13 @@ inline auto launch(hipStream_t stream, std::size_t global, std::size_t local)
};
}
template<class F>
template <class F>
__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index idx) -> decltype(f(i, idx))
{
return f(i, idx);
}
template<class F>
template <class F>
__host__ __device__ auto gs_invoke(F&& f, std::size_t i, index) -> decltype(f(i))
{
return f(i);
......
......@@ -9,14 +9,14 @@ namespace device {
struct sum
{
template<class T>
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y) const
{
return x + y;
}
};
template<std::size_t N, class Op, class T, class F>
template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{
using type = decltype(f(idx.local));
......@@ -32,7 +32,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
for(std::size_t s = 1; s < N; s *= 2)
{
const std::size_t index = 2 * s * idx.local;
if (index < N)
if(index < N)
{
buffer[index] = op(buffer[index], buffer[index + s]);
}
......@@ -44,28 +44,32 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
auto&& input_shape = arg.get_shape();
std::vector<std::size_t> reduce_lens;
std::transform(output_shape.lens().begin(), output_shape.lens().end(), input_shape.lens().begin(), std::back_inserter(reduce_lens), [](auto x, auto y) -> std::size_t {
if (x == y)
return 1;
else
return y;
});
std::transform(output_shape.lens().begin(),
output_shape.lens().end(),
input_shape.lens().begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
shape reduce_slice{output_shape.type(), reduce_lens, input_shape.strides()};
hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements();
const std::size_t block_size = 1024;
gs_launch(stream, nelements*block_size, block_size)([=](auto i, auto idx) __device__ {
auto base_idx = output.get_shape().multi(i/block_size);
auto offset = input.get_shape().index(base_idx);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
auto base_idx = output.get_shape().multi(i / block_size);
auto offset = input.get_shape().index(base_idx);
auto r = block_reduce<block_size>(idx, sum{}, 0, relements, [&](auto j) __device__ {
return input.data()[reduce_shape.index(j) + offset];
return input.data()[reduce_shape.index(j) + offset];
});
if (idx.local == 0)
output.data()[i/block_size] = r;
if(idx.local == 0)
output.data()[i / block_size] = r;
});
});
}
......
......@@ -3451,7 +3451,7 @@ struct test_reduce_sum : verify_program<test_reduce_sum>
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 8, 8}};
auto x = p.add_parameter("x", s);
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment