Commit 9b32ae0c authored by Paul's avatar Paul
Browse files

Reduce dims in pooling

parent d1e0225d
......@@ -59,7 +59,7 @@ struct reshape
shape s{inputs.front().type(), rdims};
if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("Wrong number of elements for reshape");
MIGRAPHX_THROW("Wrong number of elements for reshape: reshape has " + std::to_string(s.elements()) + " elements whereas the input has " + std::to_string(inputs.front().elements()));
return s;
}
argument compute(shape output_shape, std::vector<argument> args) const
......
......@@ -2,6 +2,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/program.hpp>
......@@ -28,7 +29,11 @@ void rewrite_pooling::apply(program& prog) const
continue;
if(s.lens()[2] != op.lengths[0] and s.lens()[3] != op.lengths[1])
continue;
prog.replace_instruction(ins, op::reduce_mean{{2, 3}}, ins->inputs().front());
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
auto reshape = prog.insert_instruction(ins, op::reshape{{n*c, -1}}, ins->inputs().front());
auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling);
}
}
......
......@@ -16,6 +16,12 @@ struct hip_array
MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR T& front() { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& front() const { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR T& back() { return d[N - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& back() const { return d[N - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
......
......@@ -209,28 +209,15 @@ constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_si
}
template <class Op, class T, class Input, class Output>
void reduce(hipStream_t stream,
void reduce_multi_impl(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output)
Output read_output,
const shape& reduce_slice)
{
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
std::vector<std::size_t> reduce_lens;
std::transform(output_shape.lens().begin(),
output_shape.lens().end(),
input_shape.lens().begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
shape reduce_slice{output_shape.type(), reduce_lens};
hip_visit_all(result, arg, reduce_slice)([&](auto output, auto input, auto reduce_shape) {
auto nelements = result.get_shape().elements();
auto relements = reduce_slice.elements();
......@@ -250,6 +237,69 @@ void reduce(hipStream_t stream,
});
}
template <class Op, class T, class Input, class Output>
void reduce_standard_impl(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output,
std::size_t relements,
std::size_t stride)
{
hip_visit_all(result, arg)([&](auto output, auto input) {
auto nelements = result.get_shape().elements();
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size;
const auto base_idx = out_idx * stride;
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
return read_input(input.data()[base_idx + j]);
});
if(idx.local == 0)
output.data()[out_idx] = read_output(r);
});
});
}
template <class Op, class T, class Input, class Output>
void reduce(hipStream_t stream,
const argument& result,
const argument& arg,
Op op,
T init,
Input read_input,
Output read_output)
{
auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape();
if (input_shape.standard() and output_shape.standard() and output_shape.lens().back() != input_shape.lens().back() and std::equal(output_shape.lens().begin(), std::prev(output_shape.lens().end()), input_shape.lens().begin()))
{
std::size_t stride = std::accumulate(input_shape.strides().begin(), input_shape.strides().end(), 1, std::multiplies<size_t>());
reduce_standard_impl(stream, result, arg, op, init, read_input, read_output, input_shape.lens().back(), stride);
}
else
{
std::vector<std::size_t> reduce_lens;
std::transform(output_shape.lens().begin(),
output_shape.lens().end(),
input_shape.lens().begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
shape reduce_slice{output_shape.type(), reduce_lens};
reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice);
}
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment