#include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { struct parse_dropout : op_parser { std::vector operators() const { return {{"Dropout"}}; } std::vector parse(const op_desc& /*opd*/, const onnx_parser& /*parser*/, const onnx_parser::node_info& info, std::vector args) const { auto out = info.add_instruction(make_op("identity"), args[0]); auto s = args[0]->get_shape(); std::vector vec(s.elements(), 1); shape mask_s{shape::bool_type, s.lens()}; auto mask = info.add_literal(literal(mask_s, vec)); return {out, mask}; } }; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx