Commit dcbc9255 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into hcc26

parents e4d9de80 40b0c973
......@@ -164,6 +164,7 @@ struct tf_parser
add_binary_op("Sub", op::sub{});
add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant);
......@@ -530,6 +531,15 @@ struct tf_parser
transb = attributes.at("transpose_a").b();
}
if(contains(attributes, "adj_x"))
{
transa = attributes.at("adj_x").b();
}
if(contains(attributes, "adj_y"))
{
transb = attributes.at("adj_y").b();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
......
:
0 Placeholder*
shape:*
dtype0
:
1 Placeholder*
dtype0*
shape:
D
batchmatmul1 BatchMatMul01*
adj_x(*
adj_y(*
T0"
\ No newline at end of file
......@@ -48,6 +48,21 @@ TEST_CASE(add_bcast_test)
EXPECT(p == prog);
}
TEST_CASE(batchmatmul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});
auto trans_l0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto trans_l1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto prog = optimize_tf("batchmatmul_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(batchnorm_test)
{
float epsilon = 1.001e-5f;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment