"...git@developer.sourcefind.cn:lacacy/qwen_lmdeploy.git" did not exist on "d592fbea9f1fd3ed15b6d7836d217ecfb5711b5a"
Commit db6038e3 authored by Khalique's avatar Khalique
Browse files

add identity onnx parsing

parent 80203608
......@@ -53,6 +53,7 @@ struct onnx_parser
add_generic_op("Relu", op::relu{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{});
add_generic_op("Identity", op::identity{});
add_broadcastable_binary_op("Add", op::add{});
add_broadcastable_binary_op("Div", op::div{});
......
......@@ -1091,4 +1091,18 @@ TEST_CASE(contiguous_test)
EXPECT(migraph::verify_range(results_vector, gold));
}
TEST_CASE(identity_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {2,2}};
std::vector<int> data{1, 2, 3, 4};
auto l = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::op::identity{}, l);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
std::vector<int> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(std::equal(data.begin(), data.end(), results_vector.begin()));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment