"magic_pdf/vscode:/vscode.git/clone" did not exist on "ad0d06b6a0bea18c267c8a3cd34d51e2a681e1dc"
Commit 2679a9b6 authored by Khalique's avatar Khalique
Browse files

add cast and test

parent 03f5c679
......@@ -164,6 +164,7 @@ struct tf_parser
add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("Cast", &tf_parser::parse_cast, false);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv);
......@@ -304,6 +305,13 @@ struct tf_parser
return prog.add_instruction(op::add{}, args[0], l0);
}
instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
shape::type_t type = parse_type(attributes.at("DstT").type());
return prog.add_instruction(op::convert{type}, std::move(args));
}
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......
......@@ -103,6 +103,16 @@ TEST_CASE(concat_test)
EXPECT(p == prog);
}
TEST_CASE(cast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
auto prog = optimize_tf("cast_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(const_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment