Commit be7467a8 authored by Paul
Browse files

Formatting

parent c60998b3
...@@ -48,7 +48,7 @@ struct min ...@@ -48,7 +48,7 @@ struct min
struct lowest struct lowest
{ {
template<class T> template <class T>
operator T() const operator T() const
{ {
return device_cast(std::numeric_limits<host_type<T>>::lowest()); return device_cast(std::numeric_limits<host_type<T>>::lowest());
...@@ -57,7 +57,7 @@ struct lowest ...@@ -57,7 +57,7 @@ struct lowest
struct highest struct highest
{ {
template<class T> template <class T>
operator T() const operator T() const
{ {
return device_cast(std::numeric_limits<host_type<T>>::max()); return device_cast(std::numeric_limits<host_type<T>>::max());
...@@ -164,7 +164,7 @@ __device__ void dpp_reduce(float& x, sum) ...@@ -164,7 +164,7 @@ __device__ void dpp_reduce(float& x, sum)
template <std::size_t N, class Op, class T, class F> template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{ {
using type = decltype(f(idx.local)); using type = decltype(f(idx.local));
MIGRAPHX_DEVICE_SHARED type buffer[N / 64]; MIGRAPHX_DEVICE_SHARED type buffer[N / 64];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
...@@ -193,8 +193,14 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_si ...@@ -193,8 +193,14 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_si
return block_size; return block_size;
} }
template<class Op, class T, class Input, class Output> template <class Op, class T, class Input, class Output>
void reduce(hipStream_t stream, const argument& result, const argument& arg, Op op, T init, Input read_input, Output read_output) void reduce(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output)
{ {
auto&& output_shape = result.get_shape(); auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape(); auto&& input_shape = arg.get_shape();
......
...@@ -6,7 +6,6 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,7 +6,6 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg) void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{ {
reduce(stream, result, arg, sum{}, 0, id{}, id{}); reduce(stream, result, arg, sum{}, 0, id{}, id{});
......
...@@ -12,7 +12,8 @@ shape hip_reduce_sum::compute_shape(std::vector<shape> inputs) const ...@@ -12,7 +12,8 @@ shape hip_reduce_sum::compute_shape(std::vector<shape> inputs) const
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
argument hip_reduce_sum::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument
hip_reduce_sum::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::reduce_sum(ctx.get_stream().get(), args.back(), args.front()); device::reduce_sum(ctx.get_stream().get(), args.back(), args.front());
return args.back(); return args.back();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment