"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "35eb3bd29b12169630eb9af6cc4df1d4621ec213"
Commit d0afbcaf authored by Khalique's avatar Khalique
Browse files

continued testing tf operators, changed biasadd parsing

parent e20d1399
...@@ -176,9 +176,16 @@ struct tf_parser ...@@ -176,9 +176,16 @@ struct tf_parser
instruction_ref instruction_ref
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
uint64_t axis = 1; uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); auto l0 = args[0];
return prog.add_instruction(op::add{}, args[0], l0); // otherwise, if the input is a parameter to the graph, then first insert transpose
if(l0->name() == "@param")
{
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
};
auto l1 = prog.add_instruction(op::broadcast{axis, l0->get_shape()}, args[1]);
return prog.add_instruction(op::add{}, l0, l1);
} }
instruction_ref instruction_ref
......
2
0 Placeholder*
shape
:*
dtype0
2
1 Placeholder*
dtype0*
shape
:
add_bcast1Add01*
T0"
\ No newline at end of file
:
0 Placeholder*
shape:*
dtype0
:
1 Placeholder*
dtype0*
shape:

add1Add01*
T0"
\ No newline at end of file
;
0 Placeholder*
shape:*
dtype0
/
1 Placeholder*
dtype0*
shape:
:
bias_add1BiasAdd01*
T0*
data_formatNHWC"
\ No newline at end of file
:
0 Placeholder*
shape:*
dtype0
identityIdentity0*
T0"
\ No newline at end of file
:
0 Placeholder*
dtype0*
shape:
u
avg_poolingAvgPool0*
ksize
*
paddingVALID*
T0*
data_formatNHWC*
strides

u
max_poolingMaxPool0*
data_formatNHWC*
strides
*
ksize
*
paddingVALID*
T0"
\ No newline at end of file
...@@ -7,6 +7,78 @@ ...@@ -7,6 +7,78 @@
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(add_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::add{}, l0, l1);
auto prog = migraphx::parse_tf("add_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(add_bcast_test)
{
migraphx::program p;
migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_tf("add_bcast_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(biasadd_test)
{
migraphx::program p;
migraphx::shape s0{migraphx::shape::float_type, {1, 1, 1, 500}};
uint64_t axis = 1;
auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_instruction(migraphx::op::transpose{{0,3,1,2}}, l0);
auto l2 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2);
p.add_instruction(migraphx::op::add{}, l1, l3);
auto prog = migraphx::parse_tf("biasadd_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(identity_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0);
auto prog = migraphx::parse_tf("identity_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pooling_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 3}});
migraphx::op::pooling avg_pool_op{"average"};
migraphx::op::pooling max_pool_op{"max"};
avg_pool_op.padding_mode = migraphx::op::padding_mode_t::valid;
max_pool_op.padding_mode = migraphx::op::padding_mode_t::valid;
avg_pool_op.stride = {2, 2};
max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2};
auto l1 = p.add_instruction(migraphx::op::transpose{{0,3,1,2}}, l0);
p.add_instruction(max_pool_op, l1);
auto l2 = p.add_instruction(migraphx::op::transpose{{0,3,1,2}}, l0);
p.add_instruction(avg_pool_op, l2);
auto prog = migraphx::parse_tf("pooling_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(relu_test) TEST_CASE(relu_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment