"awq_cuda/vscode:/vscode.git/clone" did not exist on "63e3fd83fd7aec85ae836a75070dbb3208cf7bca"
Commit 3acbd087 authored by Khalique

made additional functions to reorder axes

parent 38bde548
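For context, the new helpers centralize the NHWC-to-NCHW handling: parse_axes reads an integer-list attribute and remaps each axis, reorder_data permutes a per-axis value list (strides, dilations, ksize) into NCHW order, and parse_axis remaps a single axis index. The switch body of parse_axis is elided in the hunk below, so the following is only a minimal standalone sketch: the case mapping is inferred from how stride[2]/stride[3] are consumed afterward and from the transpose{{0, 3, 1, 2}} expected in the tests, and is not the committed code verbatim.

    // Minimal standalone sketch (not the committed code) of the NHWC -> NCHW
    // remapping that parse_axis/reorder_data implement inside tf_parser.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // NHWC axes (N=0, H=1, W=2, C=3) land at NCHW positions (N=0, C=1, H=2, W=3):
    // H -> 2, W -> 3, C -> 1, and the batch axis stays put. The switch body is
    // inferred; the commit elides it in the hunk below.
    void parse_axis(std::size_t& dim)
    {
        switch(dim)
        {
        case 1: dim = 2; break; // H
        case 2: dim = 3; break; // W
        case 3: dim = 1; break; // C
        default: break;         // N stays at 0
        }
    }

    // Move each value to the slot its axis occupies after the permutation,
    // mirroring the parser's reorder_data.
    std::vector<std::size_t> reorder_data(const std::vector<std::size_t>& prev)
    {
        std::vector<std::size_t> out(prev.size());
        for(std::size_t i = 0; i < prev.size(); ++i)
        {
            std::size_t new_idx = i;
            parse_axis(new_idx);
            out.at(new_idx) = prev.at(i);
        }
        return out;
    }

    int main()
    {
        // TF "strides" arrive as {N, H, W, C}; reordered they read {N, C, H, W},
        // so indices 2 and 3 are always the H and W strides, as the parser assumes.
        std::vector<std::size_t> strides{1, 2, 3, 1};
        for(auto s : reorder_data(strides))
            std::cout << s << ' '; // prints: 1 1 2 3
        std::cout << '\n';
    }

Applied to an iota-initialized axes vector {0, 1, 2, 3}, the same reorder yields the permutation {0, 3, 1, 2} that parse_graph feeds to op::transpose in the diff below.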
@@ -36,7 +36,38 @@ struct tf_parser
     std::unordered_map<std::string, op_func> ops;

-    void nhwc_to_nchw(std::size_t& dim)
+    std::vector<size_t> parse_axes(attribute_map& attributes, const std::string& s)
+    {
+        auto attrs = attributes.at(s).list().i();
+        std::vector<size_t> axes;
+        copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
+        if(is_nhwc)
+        {
+            for(size_t i = 0; i < axes.size(); ++i)
+            {
+                parse_axis(axes.at(i));
+            }
+        }
+        return axes;
+    }
+
+    template <class T>
+    void reorder_data(std::vector<T>& prev_data)
+    {
+        std::vector<T> new_data(prev_data.size());
+        for(auto i = 0; i < new_data.size(); i++)
+        {
+            auto new_idx = i;
+            parse_axis(new_idx);
+            new_data.at(new_idx) = prev_data.at(i);
+        }
+        prev_data = new_data;
+    }
+
+    template <class T>
+    void parse_axis(T& dim)
+    {
+        if(is_nhwc)
     {
         switch(dim)
         {
@@ -48,6 +79,8 @@ struct tf_parser
         }
     }
+    }

     tf_parser()
     {
         add_generic_op("Identity", op::identity{});
@@ -125,14 +158,14 @@ struct tf_parser
         // output_lens = (3,2,7,5)
         //
         // Get lengths for both arguments
-        const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
-        const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
+        const std::vector<size_t>* s0 = &arg0->get_shape().lens();
+        const std::vector<size_t>* s1 = &arg1->get_shape().lens();
         // Make sure s0 is the smaller size
         if(s0->size() > s1->size())
             std::swap(s0, s1);
-        std::vector<std::size_t> output_lens(*s1);
+        std::vector<size_t> output_lens(*s1);
         auto offset = s1->size() - s0->size();
         std::transform(s0->begin(),
                        s0->end(),
@@ -168,14 +201,6 @@ struct tf_parser
         {
             epsilon = attributes.at("epsilon").f();
         }
-        auto l0 = args[0];
-        if(l0->name() == "@param")
-        {
-            if(is_nhwc)
-                l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
-        };
-        args[0] = l0;
         op::batch_norm_inference op{epsilon, momentum, bn_mode};
         return prog.add_instruction(op, std::move(args));
     }
@@ -184,27 +209,17 @@ struct tf_parser
     parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
     {
         uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
-        auto l0 = args[0];
-        // otherwise, if the input is a parameter to the graph, then first insert transpose
-        if(l0->name() == "@param")
-        {
-            if(is_nhwc)
-                l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
-        };
-        auto l1 = prog.add_instruction(op::broadcast{axis, l0->get_shape()}, args[1]);
-        return prog.add_instruction(op::add{}, l0, l1);
+        auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
+        return prog.add_instruction(op::add{}, args[0], l0);
     }

     instruction_ref
     parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         // get index for axis within args
-        std::size_t axis_idx = attributes.at("N").i();
-        std::size_t axis     = args[axis_idx]->eval().at<int64_t>();
-        if(is_nhwc and axis < 4)
-        {
-            nhwc_to_nchw(axis);
-        }
+        size_t axis_idx = attributes.at("N").i();
+        size_t axis     = args[axis_idx]->eval().at<int64_t>();
+        parse_axis(axis);
         op::concat op{axis};
         // return only first N arguments (assuming last index is the axis value)
         return prog.add_instruction(
@@ -232,7 +247,7 @@ struct tf_parser
             }
             else if(pad_mode.find("EXPLICIT") != std::string::npos)
            {
-                std::vector<std::size_t> padding;
+                std::vector<size_t> padding;
                 copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
                 if(padding.size() != 4)
                 {
@@ -248,52 +263,37 @@ struct tf_parser
         }
         if(contains(attributes, "strides"))
         {
-            std::vector<std::size_t> stride;
+            std::vector<size_t> stride;
             copy(attributes.at("strides").list().i(), std::back_inserter(stride));
+            reorder_data(stride);
             if(stride.size() != 4)
             {
                 MIGRAPHX_THROW("strides should have 4 values");
             }
-            if(is_nhwc)
-            {
-                op.stride[0] = stride[1];
-                op.stride[1] = stride[2];
-            }
-            else
-            {
-                op.stride[0] = stride[2];
-                op.stride[1] = stride[3];
-            }
+            op.stride[0] = stride[2];
+            op.stride[1] = stride[3];
         }
         if(contains(attributes, "dilations"))
         {
-            std::vector<std::size_t> dilation;
+            std::vector<size_t> dilation;
             copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
+            reorder_data(dilation);
             if(dilation.size() != 4)
             {
                 MIGRAPHX_THROW("dilation should have 4 values");
             }
-            if(is_nhwc)
-            {
-                op.dilation[0] = dilation[1];
-                op.dilation[1] = dilation[2];
-            }
-            else
-            {
-                op.dilation[0] = dilation[2];
-                op.dilation[1] = dilation[3];
-            }
+            op.dilation[0] = dilation[2];
+            op.dilation[1] = dilation[3];
         }
-        auto l0 = args[0];
-        if(l0->name() == "@param")
+        auto l0 = args[1];
+        if(l0->get_operator().name() == "transpose" and is_nhwc)
         {
-            if(is_nhwc)
-                l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
+            l0 = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
         }
-        auto l1 = args[1];
-        if(is_nhwc)
-            l1 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
-        return prog.add_instruction(op, {l0, l1});
+        else
+            l0 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
+        return prog.add_instruction(op, {args[0], l0});
     }

     instruction_ref parse_pooling(const std::string& name,
@@ -316,49 +316,29 @@ struct tf_parser
         }
         if(contains(attributes, "strides"))
         {
-            std::vector<std::size_t> stride;
+            std::vector<size_t> stride;
             copy(attributes.at("strides").list().i(), std::back_inserter(stride));
+            reorder_data(stride);
             if(stride.size() != 4)
             {
                 MIGRAPHX_THROW("strides should have 4 values");
             }
-            if(is_nhwc)
-            {
-                op.stride[0] = stride[1];
-                op.stride[1] = stride[2];
-            }
-            else
-            {
-                op.stride[0] = stride[2];
-                op.stride[1] = stride[3];
-            }
+            op.stride[0] = stride[2];
+            op.stride[1] = stride[3];
         }
         if(contains(attributes, "ksize"))
        {
-            std::vector<std::size_t> ksize;
+            std::vector<size_t> ksize;
             copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
+            reorder_data(ksize);
             if(ksize.size() != 4)
             {
                 MIGRAPHX_THROW("ksize should have 4 values");
             }
-            if(is_nhwc)
-            {
-                op.lengths[0] = ksize[1];
-                op.lengths[1] = ksize[2];
-            }
-            else
-            {
-                op.lengths[0] = ksize[2];
-                op.lengths[1] = ksize[3];
-            }
+            op.lengths[0] = ksize[2];
+            op.lengths[1] = ksize[3];
         }
-        auto l0 = args[0];
-        if(l0->name() == "@param")
-        {
-            if(is_nhwc)
-                l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
-        }
-        return prog.add_instruction(op, l0);
+        return prog.add_instruction(op, args[0]);
     }

     instruction_ref
@@ -399,28 +379,21 @@ struct tf_parser
     parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         op::squeeze op;
-        auto axes = attributes.at("squeeze_dims").list().i();
+        auto temp = args[0]->get_shape().lens();
+        auto axes = parse_axes(attributes, "squeeze_dims");
         copy(axes, std::back_inserter(op.axes));
-        auto l0 = args[0];
-        if(is_nhwc)
-        {
-            if(l0->name() != "@param")
-                // squeeze dims are represented for nhwc,
-                // but intermediate shapes are in nchw
-                l0 = prog.add_instruction(op::transpose{{0, 2, 3, 1}}, args[0]);
-        }
-        auto l0_dims = l0->get_shape().lens();
+        auto args0_dims = args[0]->get_shape().lens();
         if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
         {
-            for(size_t i = 0; i < l0_dims.size(); i++)
+            for(auto i = 0; i < args0_dims.size(); i++)
             {
-                if(l0_dims.at(i) == 1)
+                if(args0_dims.at(i) == 1)
                 {
                     op.axes.push_back(i);
                 }
             }
         }
-        return prog.add_instruction(op, l0);
+        return prog.add_instruction(op, args[0]);
     }

     void parse_graph(const tensorflow::GraphDef& graph)
@@ -433,7 +406,15 @@ struct tf_parser
             shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
             std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
             shape s = shape{shape_type, dims};
-            instructions[name] = prog.add_parameter(name, s);
+            auto in_param = prog.add_parameter(name, s);
+            if(is_nhwc and dims.size() >= 4) // only transpose for NHWC tensors and larger
+            {
+                std::vector<int64_t> axes(dims.size());
+                std::iota(axes.begin(), axes.end(), 0);
+                reorder_data(axes);
+                in_param = prog.add_instruction(op::transpose{axes}, in_param);
+            }
+            instructions[name] = in_param;
         }
         for(auto&& p : nodes)
         {
@@ -726,7 +707,7 @@ struct tf_parser
     template <class T>
     static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
-                                        const std::size_t& shape_size)
+                                        const size_t& shape_size)
     {
         std::vector<T> data_vals(shape_size);
         // check if shape has enough data values given existing fields
...
@@ -110,8 +110,9 @@ TEST_CASE(conv_test)
     op.stride   = {1, 1};
     op.dilation = {1, 1};
     auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
-    auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
-    p.add_instruction(op, l2, l3);
+    auto l3 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
+    auto l4 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l3);
+    p.add_instruction(op, l2, l4);
     auto prog = migraphx::parse_tf("conv_test.pb", true);
     EXPECT(p == prog);
@@ -141,8 +142,7 @@ TEST_CASE(pooling_test)
     max_pool_op.lengths = {2, 2};
     auto l1 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
     p.add_instruction(max_pool_op, l1);
-    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
-    p.add_instruction(avg_pool_op, l2);
+    p.add_instruction(avg_pool_op, l1);
     auto prog = migraphx::parse_tf("pooling_test.pb", true);
     EXPECT(p == prog);
...