Commit b0586006 authored by Paul's avatar Paul
Browse files

Formatting

parent 9b32ae0c
...@@ -59,7 +59,9 @@ struct reshape ...@@ -59,7 +59,9 @@ struct reshape
shape s{inputs.front().type(), rdims}; shape s{inputs.front().type(), rdims};
if(s.elements() != inputs.front().elements()) if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " + std::to_string(s.elements()) + " elements whereas the input has " + std::to_string(inputs.front().elements())); MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " +
std::to_string(s.elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));
return s; return s;
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -31,7 +31,8 @@ void rewrite_pooling::apply(program& prog) const ...@@ -31,7 +31,8 @@ void rewrite_pooling::apply(program& prog) const
continue; continue;
std::int64_t n = s.lens()[0]; std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1]; std::int64_t c = s.lens()[1];
auto reshape = prog.insert_instruction(ins, op::reshape{{n*c, -1}}, ins->inputs().front()); auto reshape =
prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape); auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling); prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling);
} }
......
...@@ -210,13 +210,13 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_si ...@@ -210,13 +210,13 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_si
template <class Op, class T, class Input, class Output> template <class Op, class T, class Input, class Output>
void reduce_multi_impl(hipStream_t stream, void reduce_multi_impl(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg, const argument& arg,
Op op, Op op,
T init, T init,
Input read_input, Input read_input,
Output read_output, Output read_output,
const shape& reduce_slice) const shape& reduce_slice)
{ {
hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) { hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
auto nelements = result.get_shape().elements(); auto nelements = result.get_shape().elements();
...@@ -239,14 +239,14 @@ void reduce_multi_impl(hipStream_t stream, ...@@ -239,14 +239,14 @@ void reduce_multi_impl(hipStream_t stream,
template <class Op, class T, class Input, class Output> template <class Op, class T, class Input, class Output>
void reduce_standard_impl(hipStream_t stream, void reduce_standard_impl(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg, const argument& arg,
Op op, Op op,
T init, T init,
Input read_input, Input read_input,
Output read_output, Output read_output,
std::size_t relements, std::size_t relements,
std::size_t stride) std::size_t stride)
{ {
hip_visit_all(result, arg)([&](auto output, auto input) { hip_visit_all(result, arg)([&](auto output, auto input) {
auto nelements = result.get_shape().elements(); auto nelements = result.get_shape().elements();
...@@ -254,8 +254,8 @@ void reduce_standard_impl(hipStream_t stream, ...@@ -254,8 +254,8 @@ void reduce_standard_impl(hipStream_t stream,
const std::size_t max_block_size = 256; const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size); const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size; const auto out_idx = i / block_size;
const auto base_idx = out_idx * stride; const auto base_idx = out_idx * stride;
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ { auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
return read_input(input.data()[base_idx + j]); return read_input(input.data()[base_idx + j]);
}); });
...@@ -276,10 +276,25 @@ void reduce(hipStream_t stream, ...@@ -276,10 +276,25 @@ void reduce(hipStream_t stream,
{ {
auto&& output_shape = result.get_shape(); auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape(); auto&& input_shape = arg.get_shape();
if (input_shape.standard() and output_shape.standard() and output_shape.lens().back() != input_shape.lens().back() and std::equal(output_shape.lens().begin(), std::prev(output_shape.lens().end()), input_shape.lens().begin())) if(input_shape.standard() and output_shape.standard() and
output_shape.lens().back() != input_shape.lens().back() and
std::equal(output_shape.lens().begin(),
std::prev(output_shape.lens().end()),
input_shape.lens().begin()))
{ {
std::size_t stride = std::accumulate(input_shape.strides().begin(), input_shape.strides().end(), 1, std::multiplies<size_t>()); std::size_t stride = std::accumulate(input_shape.strides().begin(),
reduce_standard_impl(stream, result, arg, op, init, read_input, read_output, input_shape.lens().back(), stride); input_shape.strides().end(),
1,
std::multiplies<size_t>());
reduce_standard_impl(stream,
result,
arg,
op,
init,
read_input,
read_output,
input_shape.lens().back(),
stride);
} }
else else
{ {
...@@ -296,7 +311,6 @@ void reduce(hipStream_t stream, ...@@ -296,7 +311,6 @@ void reduce(hipStream_t stream,
}); });
shape reduce_slice{output_shape.type(), reduce_lens}; shape reduce_slice{output_shape.type(), reduce_lens};
reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice); reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment