#ifndef MIGRAPHX_GUARD_OPERATORS_POOLING_HPP #define MIGRAPHX_GUARD_OPERATORS_POOLING_HPP #include #include #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { struct pooling { std::string mode = "average"; std::array padding = {{0, 0}}; std::array stride = {{1, 1}}; std::array lengths = {{1, 1}}; padding_mode_t padding_mode = default_; template static auto reflect(Self& self, F f) { return pack(f(self.mode, "mode"), f(self.padding, "padding"), f(self.padding, "padding_mode"), f(self.stride, "stride"), f(self.lengths, "lengths")); } std::string name() const { return "pooling"; } shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.has(1).only_dims(4); const shape& input = inputs.at(0); auto t = input.type(); assert(lengths[0] <= (input.lens()[2] + 2 * padding[0])); assert(lengths[1] <= (input.lens()[3] + 2 * padding[1])); if(padding_mode == default_) { return {t, { input.lens()[0], input.lens()[1], std::size_t(std::max( 1, floor_divide( input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) + 1)), std::size_t(std::max( 1, floor_divide( input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) + 1)), }}; } else if(padding_mode == same) { return {t, {input.lens()[0], input.lens()[1], ceil_divide(input.lens()[2], stride[0]), ceil_divide(input.lens()[3], stride[1])}}; } else if(padding_mode == valid) { return { t, { input.lens()[0], input.lens()[1], std::size_t(std::max( 1, floor_divide(input.lens()[2] - lengths[0], stride[0]) + 1)), std::size_t(std::max( 1, floor_divide(input.lens()[3] - lengths[1], stride[1]) + 1)), }}; } else { MIGRAPHX_THROW("Invalid padding mode"); } } }; } // namespace op } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif