#ifndef MIGRAPHX_GUARD_OPERATORS_RESHAPE_HPP #define MIGRAPHX_GUARD_OPERATORS_RESHAPE_HPP #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { struct reshape { std::vector dims; template static auto reflect(Self& self, F f) { return pack(f(self.dims, "dims")); } std::string name() const { return "reshape"; } shape compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.has(1).standard(); auto&& idims = inputs.front().lens(); std::vector rdims(dims.begin(), dims.end()); auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); if(n_neg_dims > 1) MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim"); for(std::size_t i = 0; i < dims.size(); i++) { if(dims[i] == 0) rdims[i] = idims[i]; // since rdims using size_t type, -1 is the max value // is size_t that cause later compuation incorrect if(dims[i] == -1) rdims[i] = 1; } if(n_neg_dims > 0) { size_t missing_dim = inputs.front().elements() / std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies()); for(std::size_t i = 0; i < rdims.size(); i++) { if(dims[i] == -1) rdims[i] = missing_dim; } } shape s{inputs.front().type(), rdims}; if(s.elements() != inputs.front().elements()) MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " + std::to_string(s.elements()) + " elements whereas the input has " + std::to_string(inputs.front().elements())); return s; } argument compute(shape output_shape, std::vector args) const { return {std::move(output_shape), std::move(args.front().data)}; } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; } // namespace op } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif