#ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP #define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP #include #include #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { struct convolution { std::vector padding = {0, 0}; std::vector stride = {1, 1}; std::vector dilation = {1, 1}; int group = 1; padding_mode_t padding_mode = default_; template static auto reflect(Self& self, F f) { return pack(f(self.padding, "padding"), f(self.stride, "stride"), f(self.dilation, "dilation"), f(self.group, "group"), f(self.padding_mode, "padding_mode")); } std::string name() const { return "convolution"; } void check_attribute_size() const { if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and stride.size() == dilation.size())) { MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes"); } } value attributes() const { return {{"normalize_padding", "padding"}}; } shape normalize_compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3); check_attribute_size(); // dim num of input and attribute should match auto in_lens = inputs[0].lens(); auto input_size = in_lens.size(); auto padding_size = padding.size(); if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2)) { MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!"); } const shape& input = inputs.at(0); const shape& weights = inputs.at(1); size_t kdims = input_size - 2; if(kdims != this->kdims()) { MIGRAPHX_THROW("convolution: input k-dims does not match attribute size"); } if(input.lens().at(1) != (weights.lens().at(1) * group)) MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers"); std::vector output_lens{in_lens[0], weights.lens()[0]}; for(size_t i = 0; i < kdims; i++) { auto padding_factor = 2 * padding[i]; if(padding_size == 2 * kdims) padding_factor = padding[i] + padding[i + kdims]; output_lens.push_back(std::size_t(std::max( 1, (in_lens[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) + padding_factor) / stride[i] + 1))); } return inputs[0].with_lens(output_lens); } size_t kdims() const { check_attribute_size(); return stride.size(); } }; } // namespace op } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif