Commit 461a68cc authored by Khalique's avatar Khalique
Browse files

transpose literals, adjust shapes for params instead of adding transpose

parent 26df5406
......@@ -36,26 +36,26 @@ struct tf_parser
std::unordered_map<std::string, op_func> ops;
std::vector<size_t> parse_axes(attribute_map& attributes, const std::string& s)
std::vector<size_t> parse_axes(attribute_map& attributes, const std::string& s) const
{
auto attrs = attributes.at(s).list().i();
std::vector<size_t> axes;
copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
if(is_nhwc)
{
for(size_t i = 0; i < axes.size(); ++i)
for(size_t& axis: axes)
{
parse_axis(axes.at(i));
parse_axis(axis);
}
}
return axes;
}
template <class T>
void reorder_data(std::vector<T>& prev_data)
void reorder_data(std::vector<T>& prev_data) const
{
std::vector<T> new_data(prev_data.size());
for(auto i = 0; i < new_data.size(); i++)
for(size_t i = 0; i < new_data.size(); i++)
{
auto new_idx = i;
parse_axis(new_idx);
......@@ -65,7 +65,7 @@ struct tf_parser
}
template <class T>
void parse_axis(T& dim)
void parse_axis(T& dim) const
{
if(is_nhwc)
{
......@@ -80,6 +80,13 @@ struct tf_parser
}
}
std::vector<int64_t> get_axes(size_t num_axes) const
{
std::vector<int64_t> axes(num_axes);
std::iota(axes.begin(), axes.end(), 0);
return axes;
}
tf_parser()
{
add_generic_op("Identity", op::identity{});
......@@ -230,7 +237,15 @@ struct tf_parser
const std::vector<instruction_ref>&)
{
literal v = parse_tensor(attributes.at("value").tensor());
return prog.add_literal(v);
auto l0 = prog.add_literal(v);
size_t num_axes = l0->get_shape().lens().size();
if(num_axes >= 4)
{
std::vector<int64_t> transpose_axes = get_axes(num_axes);
reorder_data(transpose_axes);
l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
}
return l0;
}
instruction_ref
......@@ -248,6 +263,7 @@ struct tf_parser
{
std::vector<size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
reorder_data(padding);
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
......@@ -285,12 +301,13 @@ struct tf_parser
op.dilation[1] = dilation[3];
}
auto l0 = args[1];
if(l0->get_operator().name() == "transpose" and is_nhwc)
// check if weights are from a constant
if(l0->inputs().at(0)->name() == "@literal" and is_nhwc)
{
l0 = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
}
else
l0 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
else if (l0->name() != "@param")
MIGRAPHX_THROW("cannot infer data format for weights");
return prog.add_instruction(op, {args[0], l0});
}
......@@ -378,13 +395,12 @@ struct tf_parser
parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::squeeze op;
auto temp = args[0]->get_shape().lens();
auto axes = parse_axes(attributes, "squeeze_dims");
copy(axes, std::back_inserter(op.axes));
auto args0_dims = args[0]->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{
for(auto i = 0; i < args0_dims.size(); i++)
for(size_t i = 0; i < args0_dims.size(); i++)
{
if(args0_dims.at(i) == 1)
{
......@@ -404,16 +420,12 @@ struct tf_parser
attribute_map input_attrs = get_attributes(input);
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
shape s = shape{shape_type, dims};
auto in_param = prog.add_parameter(name, s);
if(is_nhwc and dims.size() >= 4) // only transpose for NHWC tensors and larger
if(is_nhwc and dims.size() >= 4)
{
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
reorder_data(axes);
in_param = prog.add_instruction(op::transpose{axes}, in_param);
reorder_data(dims);
}
instructions[name] = in_param;
shape s = shape{shape_type, dims};
instructions[name] = prog.add_parameter(name, s);
}
for(auto&& p : nodes)
{
......@@ -427,7 +439,6 @@ struct tf_parser
{
auto&& node = nodes.at(name);
std::vector<instruction_ref> args;
// std::cout << name << std::endl;
for(auto&& input : node.input())
{
......
No preview for this file type
......@@ -42,8 +42,7 @@ TEST_CASE(batchnorm_test)
migraphx::op::batch_norm_inference op{
epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
migraphx::shape s0{migraphx::shape::float_type, {32}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 32}});
l0 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
std::vector<float> const_vals(32);
std::fill(const_vals.begin(), const_vals.end(), 1.0f);
......@@ -60,13 +59,12 @@ TEST_CASE(batchnorm_test)
TEST_CASE(biasadd_test)
{
migraphx::program p;
migraphx::shape s0{migraphx::shape::float_type, {1, 1, 1, 500}};
migraphx::shape s0{migraphx::shape::float_type, {1, 500, 1, 1}};
uint64_t axis = 1;
auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
auto l2 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l3 = p.add_instruction(migraphx::op::broadcast{axis, l1->get_shape()}, l2);
p.add_instruction(migraphx::op::add{}, l1, l3);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_tf("biasadd_test.pb", true);
EXPECT(p == prog);
......@@ -102,17 +100,18 @@ TEST_CASE(conv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
std::vector<float> weight_data(3*3*3*32);
std::fill(weight_data.begin(), weight_data.end(), 1.0f);
auto l1 = p.add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data);
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
auto l3 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l4 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l3);
p.add_instruction(op, l2, l4);
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
p.add_instruction(op, l0, l3);
auto prog = migraphx::parse_tf("conv_test.pb", true);
EXPECT(p == prog);
......@@ -131,7 +130,7 @@ TEST_CASE(identity_test)
TEST_CASE(pooling_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 16, 16, 3}});
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::pooling avg_pool_op{"average"};
migraphx::op::pooling max_pool_op{"max"};
avg_pool_op.padding_mode = migraphx::op::padding_mode_t::valid;
......@@ -140,9 +139,8 @@ TEST_CASE(pooling_test)
max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2};
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
p.add_instruction(max_pool_op, l1);
p.add_instruction(avg_pool_op, l1);
p.add_instruction(max_pool_op, l0);
p.add_instruction(avg_pool_op, l0);
auto prog = migraphx::parse_tf("pooling_test.pb", true);
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment