#include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace tf { struct parse_expanddims : op_parser { std::vector operators() const { return {{"ExpandDims"}}; } instruction_ref parse(const op_desc& /*opd*/, const tf_parser& /*parser*/, const tf_parser::node_info& info, std::vector args) const { std::vector input_dims = args[0]->get_shape().lens(); std::vector new_dims(input_dims.begin(), input_dims.end()); int num_dims = input_dims.size(); int32_t dim = args[1]->eval().at(); if(dim < 0) { new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1); } else { new_dims.insert(new_dims.begin() + dim, 1); } return info.add_instruction(make_op("reshape", {{"dims", new_dims}}), args[0]); } }; } // namespace tf } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx