"git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "d62aebfed7667e138f4421c97bcee5ecbd4758c4"
Commit fb4b631d authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into scalar_parsing

parents 9cdeab11 1966b2c1
...@@ -110,6 +110,7 @@ struct tf_parser ...@@ -110,6 +110,7 @@ struct tf_parser
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{});
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
...@@ -117,6 +118,7 @@ struct tf_parser ...@@ -117,6 +118,7 @@ struct tf_parser
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean); add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pack", &tf_parser::parse_pack); add_mem_op("Pack", &tf_parser::parse_pack);
...@@ -337,6 +339,32 @@ struct tf_parser ...@@ -337,6 +339,32 @@ struct tf_parser
return prog.add_instruction(op, {args[0], weights}); return prog.add_instruction(op, {args[0], weights});
} }
instruction_ref
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
bool transa = false;
bool transb = false;
if(contains(attributes, "transpose_a"))
{
transa = attributes.at("transpose_a").b();
}
if(contains(attributes, "transpose_b"))
{
transb = attributes.at("transpose_a").b();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
return prog.add_instruction(op::dot{}, l1, l2);
}
instruction_ref instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
......
2
0 Placeholder*
dtype0*
shape
:
2
1 Placeholder*
dtype0*
shape
:
F
matmul1MatMul01*
T0*
transpose_a(*
transpose_b("
\ No newline at end of file
:
0 Placeholder*
shape:*
dtype0
:
1 Placeholder*
dtype0*
shape:

mul1Mul01*
T0"
\ No newline at end of file
...@@ -129,6 +129,33 @@ TEST_CASE(identity_test) ...@@ -129,6 +129,33 @@ TEST_CASE(identity_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(matmul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
auto trans_l0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto prog = migraphx::parse_tf("matmul_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(mul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::mul{}, l0, l1);
auto prog = migraphx::parse_tf("mul_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pack_test) TEST_CASE(pack_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment