Commit 461a68cc authored by Khalique

transpose literals, adjust shapes for params instead of adding transpose

parent 26df5406
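In NHWC mode the parser used to append an op::transpose to every 4-D input to bring it into MIGraphX's native NCHW layout. This commit drops those instructions: parameter shapes are reordered to NCHW before the parameter is created, and constant literals get a single transpose at creation time instead. Both rely on the parse_axis/reorder_data pair, which move an NHWC axis index to its NCHW position. A minimal standalone sketch of that mapping, for reference while reading the diff (the switch body is an assumption inferred from the permutations visible below; the real functions are tf_parser members, and its reorder_data mutates its argument in place):

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for tf_parser::parse_axis: maps an NHWC axis index to its NCHW
// position (N: 0 -> 0, H: 1 -> 2, W: 2 -> 3, C: 3 -> 1). The switch body is
// an assumption inferred from the permutations in this diff.
void parse_axis(std::size_t& dim)
{
    switch(dim)
    {
    case 1: dim = 2; break; // H
    case 2: dim = 3; break; // W
    case 3: dim = 1; break; // C
    default: break;         // N stays at 0
    }
}

// Stand-in for tf_parser::reorder_data: element i of the NHWC-ordered input
// lands at its NCHW position in the output.
template <class T>
std::vector<T> reorder_data(const std::vector<T>& prev_data)
{
    std::vector<T> new_data(prev_data.size());
    for(std::size_t i = 0; i < new_data.size(); i++)
    {
        auto new_idx = i;
        parse_axis(new_idx);
        new_data[new_idx] = prev_data[i];
    }
    return new_data;
}

int main()
{
    std::vector<std::size_t> dims = {1, 16, 16, 3}; // an NHWC shape
    for(auto d : reorder_data(dims))
        std::cout << d << ' ';                      // 1 3 16 16 (NCHW)
    std::cout << '\n';

    std::vector<std::size_t> axes = {0, 1, 2, 3};   // get_axes(4)
    for(auto a : reorder_data(axes))
        std::cout << a << ' ';                      // 0 3 1 2
    std::cout << '\n';
}

The second printout, {0, 3, 1, 2}, is exactly the transpose the old code attached to parameters and the new code attaches to literals.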
@@ -36,26 +36,26 @@ struct tf_parser
     std::unordered_map<std::string, op_func> ops;
-    std::vector<size_t> parse_axes(attribute_map& attributes, const std::string& s)
+    std::vector<size_t> parse_axes(attribute_map& attributes, const std::string& s) const
     {
         auto attrs = attributes.at(s).list().i();
         std::vector<size_t> axes;
         copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
         if(is_nhwc)
         {
-            for(size_t i = 0; i < axes.size(); ++i)
+            for(size_t& axis : axes)
             {
-                parse_axis(axes.at(i));
+                parse_axis(axis);
             }
         }
         return axes;
     }
     template <class T>
-    void reorder_data(std::vector<T>& prev_data)
+    void reorder_data(std::vector<T>& prev_data) const
     {
         std::vector<T> new_data(prev_data.size());
-        for(auto i = 0; i < new_data.size(); i++)
+        for(size_t i = 0; i < new_data.size(); i++)
         {
             auto new_idx = i;
             parse_axis(new_idx);
@@ -65,7 +65,7 @@ struct tf_parser
     }
     template <class T>
-    void parse_axis(T& dim)
+    void parse_axis(T& dim) const
     {
         if(is_nhwc)
         {
@@ -80,6 +80,13 @@ struct tf_parser
         }
     }
+    std::vector<int64_t> get_axes(size_t num_axes) const
+    {
+        std::vector<int64_t> axes(num_axes);
+        std::iota(axes.begin(), axes.end(), 0);
+        return axes;
+    }
     tf_parser()
     {
         add_generic_op("Identity", op::identity{});
@@ -230,7 +237,15 @@ struct tf_parser
                            const std::vector<instruction_ref>&)
     {
         literal v = parse_tensor(attributes.at("value").tensor());
-        return prog.add_literal(v);
+        auto l0 = prog.add_literal(v);
+        size_t num_axes = l0->get_shape().lens().size();
+        if(num_axes >= 4)
+        {
+            std::vector<int64_t> transpose_axes = get_axes(num_axes);
+            reorder_data(transpose_axes);
+            l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
+        }
+        return l0;
     }
     instruction_ref
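parse_constant now leaves the literal's data buffer in TensorFlow's layout and appends a single op::transpose to any literal of rank 4 or higher. Following the code with the sketch above: get_axes(4) yields {0, 1, 2, 3}, and reorder_data turns it into {0, 3, 1, 2}, the NHWC -> NCHW permutation. Note that when is_nhwc is false, parse_axis leaves every axis untouched, so the added transpose is the identity {0, 1, 2, 3}.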
@@ -248,6 +263,7 @@ struct tf_parser
         {
             std::vector<size_t> padding;
             copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
+            reorder_data(padding);
             if(padding.size() != 4)
             {
                 MIGRAPHX_THROW("padding should have 4 values");
@@ -285,12 +301,13 @@ struct tf_parser
                 op.dilation[1] = dilation[3];
             }
             auto l0 = args[1];
-            if(l0->get_operator().name() == "transpose" and is_nhwc)
+            // check if weights are from a constant
+            if(l0->inputs().at(0)->name() == "@literal" and is_nhwc)
             {
                 l0 = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
             }
-            else
-                l0 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
+            else if(l0->name() != "@param")
+                MIGRAPHX_THROW("cannot infer data format for weights");
             return prog.add_instruction(op, {args[0], l0});
         }
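This branch leans on the constant path above: a TF conv weight arrives in HWIO order, and if it comes from a literal it has already been transposed by {0, 3, 1, 2} in parse_constant, so composing that with this {1, 3, 0, 2} reproduces the old direct {3, 2, 0, 1} (HWIO -> OIHW) permutation. A quick sanity check (compose is a hypothetical helper, not parser code):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// The single permutation equivalent to transposing by p1 and then by p2 is
// q[i] = p1[p2[i]], since out[i] = mid[p2[i]] = in[p1[p2[i]]].
std::vector<int64_t> compose(const std::vector<int64_t>& p1,
                             const std::vector<int64_t>& p2)
{
    std::vector<int64_t> q(p2.size());
    for(std::size_t i = 0; i < p2.size(); i++)
        q[i] = p1[p2[i]];
    return q;
}

int main()
{
    // parse_constant's transpose, then the conv weights transpose:
    for(auto x : compose({0, 3, 1, 2}, {1, 3, 0, 2}))
        std::cout << x << ' '; // 3 2 0 1 -- the old HWIO -> OIHW transpose
    std::cout << '\n';
}

Weights that are plain parameters now pass through untouched, and anything else throws rather than guessing the layout.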
@@ -378,13 +395,12 @@ struct tf_parser
     parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         op::squeeze op;
-        auto temp = args[0]->get_shape().lens();
         auto axes = parse_axes(attributes, "squeeze_dims");
         copy(axes, std::back_inserter(op.axes));
         auto args0_dims = args[0]->get_shape().lens();
         if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
         {
-            for(auto i = 0; i < args0_dims.size(); i++)
+            for(size_t i = 0; i < args0_dims.size(); i++)
             {
                 if(args0_dims.at(i) == 1)
                 {
@@ -404,16 +420,12 @@ struct tf_parser
             attribute_map input_attrs = get_attributes(input);
             shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
             std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
-            shape s = shape{shape_type, dims};
-            auto in_param = prog.add_parameter(name, s);
-            if(is_nhwc and dims.size() >= 4) // only transpose for NHWC tensors and larger
+            if(is_nhwc and dims.size() >= 4)
             {
-                std::vector<int64_t> axes(dims.size());
-                std::iota(axes.begin(), axes.end(), 0);
-                reorder_data(axes);
-                in_param = prog.add_instruction(op::transpose{axes}, in_param);
+                reorder_data(dims);
             }
-            instructions[name] = in_param;
+            shape s = shape{shape_type, dims};
+            instructions[name] = prog.add_parameter(name, s);
         }
         for(auto&& p : nodes)
         {
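Graph inputs no longer get a trailing transpose in NHWC mode: the dims vector itself is reordered before prog.add_parameter runs, so every parameter is declared NCHW and downstream ops consume it directly. With the sketch from the top, an NHWC shape like {1, 16, 16, 32} is declared as {1, 32, 16, 16} (compare the batchnorm test below).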
@@ -427,7 +439,6 @@ struct tf_parser
         {
             auto&& node = nodes.at(name);
             std::vector<instruction_ref> args;
-            // std::cout << name << std::endl;
             for(auto&& input : node.input())
             {
...
No preview for this file type (binary file in this commit)
@@ -42,8 +42,7 @@ TEST_CASE(batchnorm_test)
     migraphx::op::batch_norm_inference op{
         epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
     migraphx::shape s0{migraphx::shape::float_type, {32}};
-    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 32}});
-    l0 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
     std::vector<float> const_vals(32);
     std::fill(const_vals.begin(), const_vals.end(), 1.0f);
@@ -60,13 +59,12 @@ TEST_CASE(batchnorm_test)
 TEST_CASE(biasadd_test)
 {
     migraphx::program p;
-    migraphx::shape s0{migraphx::shape::float_type, {1, 1, 1, 500}};
+    migraphx::shape s0{migraphx::shape::float_type, {1, 500, 1, 1}};
     uint64_t axis = 1;
     auto l0 = p.add_parameter("0", s0);
-    auto l1 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
-    auto l2 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
-    auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2);
-    p.add_instruction(migraphx::op::add{}, l1, l3);
+    auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
+    auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape()}, l1);
+    p.add_instruction(migraphx::op::add{}, l0, l2);
     auto prog = migraphx::parse_tf("biasadd_test.pb", true);
     EXPECT(p == prog);
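With the input declared NCHW up front, the channel dimension really is axis 1, so the expected graph now broadcasts the {500} bias over l0's shape directly; the {0, 3, 1, 2} transpose that used to sit between the parameter and the broadcast is gone.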
@@ -102,17 +100,18 @@ TEST_CASE(conv_test)
 {
     migraphx::program p;
-    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 3}});
-    auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}});
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
+    std::vector<float> weight_data(3 * 3 * 3 * 32);
+    std::fill(weight_data.begin(), weight_data.end(), 1.0f);
+    auto l1 = p.add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data);
     migraphx::op::convolution op;
     op.padding_mode = migraphx::op::padding_mode_t::same;
     op.stride = {1, 1};
     op.dilation = {1, 1};
-    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
-    auto l3 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
-    auto l4 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l3);
-    p.add_instruction(op, l2, l4);
+    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
+    auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
+    p.add_instruction(op, l0, l3);
     auto prog = migraphx::parse_tf("conv_test.pb", true);
     EXPECT(p == prog);
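The expected graph mirrors the two parser paths above: the HWIO weights are now a literal, so parse_constant's {0, 3, 1, 2} transpose comes first (l2), followed by the conv handler's {1, 3, 0, 2} (l3); together they equal the old single {3, 2, 0, 1}.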
@@ -131,7 +130,7 @@ TEST_CASE(identity_test)
 TEST_CASE(pooling_test)
 {
     migraphx::program p;
-    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 3}});
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     migraphx::op::pooling avg_pool_op{"average"};
     migraphx::op::pooling max_pool_op{"max"};
     avg_pool_op.padding_mode = migraphx::op::padding_mode_t::valid;
@@ -140,9 +139,8 @@ TEST_CASE(pooling_test)
     max_pool_op.stride = {2, 2};
     avg_pool_op.lengths = {2, 2};
     max_pool_op.lengths = {2, 2};
-    auto l1 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
-    p.add_instruction(max_pool_op, l1);
-    p.add_instruction(avg_pool_op, l1);
+    p.add_instruction(max_pool_op, l0);
+    p.add_instruction(avg_pool_op, l0);
     auto prog = migraphx::parse_tf("pooling_test.pb", true);
     EXPECT(p == prog);
...