"docs/en_US/Tutorial/InstallationLinux.md" did not exist on "a1f926661689b6db3bba72f7cd381aac118e5c0a"
Unverified Commit 66da8dd1 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge branch 'develop' into slice_op

parents fe08345c 7f28f161
...@@ -160,12 +160,14 @@ struct tf_parser ...@@ -160,12 +160,14 @@ struct tf_parser
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
add_binary_op("Pow", op::pow{});
add_binary_op("SquaredDifference", op::sqdiff{}); add_binary_op("SquaredDifference", op::sqdiff{});
add_binary_op("Sub", op::sub{}); add_binary_op("Sub", op::sub{});
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false); add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("Cast", &tf_parser::parse_cast, false);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false); add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
...@@ -308,6 +310,13 @@ struct tf_parser ...@@ -308,6 +310,13 @@ struct tf_parser
return prog.add_instruction(op::add{}, args[0], l0); return prog.add_instruction(op::add{}, args[0], l0);
} }
instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
shape::type_t type = parse_type(attributes.at("DstT").type());
return prog.add_instruction(op::convert{type}, std::move(args));
}
instruction_ref instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
......
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:

pow1Pow01*
T0"
\ No newline at end of file
...@@ -118,6 +118,16 @@ TEST_CASE(concat_test) ...@@ -118,6 +118,16 @@ TEST_CASE(concat_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(cast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
auto prog = optimize_tf("cast_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(const_test) TEST_CASE(const_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -327,12 +337,22 @@ TEST_CASE(pooling_test) ...@@ -327,12 +337,22 @@ TEST_CASE(pooling_test)
avg_pool_op.lengths = {2, 2}; avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2}; max_pool_op.lengths = {2, 2};
p.add_instruction(max_pool_op, l0); p.add_instruction(max_pool_op, l0);
// p.add_instruction(avg_pool_op, l0);
auto prog = optimize_tf("pooling_test.pb", true); auto prog = optimize_tf("pooling_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(pow_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::pow{}, l0, l1);
auto prog = optimize_tf("pow_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(relu_test) TEST_CASE(relu_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment