#ifndef MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP #define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { struct unsqueeze { std::vector axes; template static auto reflect(Self& self, F f) { return pack(f(self.axes, "axes")); } std::string name() const { return "unsqueeze"; } shape compute_shape(std::vector inputs) const { auto input_shape = inputs[0]; auto type = input_shape.type(); auto old_lens = input_shape.lens(); std::size_t new_size = old_lens.size() + axes.size(); std::vector new_lens(new_size); std::size_t p = 0; for(std::size_t i = 0; i < new_size; i++) { if(std::find(axes.begin(), axes.end(), i) != axes.end()) { new_lens[i] = 1; } else { new_lens[i] = old_lens[p++]; } } return shape{type, new_lens}; } argument compute(shape output_shape, std::vector args) const { return {std::move(output_shape), std::move(args.front().data)}; } int output_alias(const std::vector&) const { return 0; } }; } // namespace op } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif