Unverified Commit a671ef89 authored by mvermeulen, committed by GitHub

Merge branch 'develop' into parse_softmax_using_new_impl

parents 63a0b166 5944d0c7
@@ -166,6 +166,7 @@ struct tf_parser
         add_mem_op("AvgPool", &tf_parser::parse_pooling);
         add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
         add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
+        add_mem_op("Cast", &tf_parser::parse_cast, false);
         add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
         add_mem_op("Const", &tf_parser::parse_constant);
         add_mem_op("Conv2D", &tf_parser::parse_conv);
@@ -307,6 +308,13 @@ struct tf_parser
         return prog.add_instruction(op::add{}, args[0], l0);
     }
 
+    instruction_ref
+    parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        shape::type_t type = parse_type(attributes.at("DstT").type());
+        return prog.add_instruction(op::convert{type}, std::move(args));
+    }
+
     instruction_ref
     parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
......
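For reference, a minimal standalone sketch of the IR the new parse_cast handler builds for a TF graph of the form "float input -> Cast(DstT=DT_INT32)". It reuses only the calls that also appear in the test hunk below; the include paths and printing the program via operator<< are assumptions for illustration, not part of this change.

// Sketch only: mirrors what parse_cast emits for a Cast node with DstT = DT_INT32.
#include <iostream>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>

int main()
{
    migraphx::program p;
    // Same input shape the new cast_test uses.
    auto x = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    // parse_cast reads the node's DstT attribute and lowers the node to op::convert.
    p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, x);
    std::cout << p << std::endl; // print the resulting two-instruction program (assumed streamable)
    return 0;
}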
@@ -118,6 +118,16 @@ TEST_CASE(concat_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(cast_test)
+{
+    migraphx::program p;
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
+    p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
+    auto prog = optimize_tf("cast_test.pb", false);
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(const_test)
 {
     migraphx::program p;
......