Commit cdf96caf authored by Paul
Browse files

Transpose every instruction

parent 0d796941
...@@ -36,6 +36,49 @@ struct tf_parser ...@@ -36,6 +36,49 @@ struct tf_parser
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
// Decides whether a layout transpose applies to `ins`.
// Returns true only when the parser's `is_nhwc` flag is set and the
// instruction's shape has exactly four dimensions.
bool should_transpose(instruction_ref ins)
{
    // Check the flag first so the shape is not queried when the parser is not in NHWC mode.
    if(not is_nhwc)
        return false;
    return ins->get_shape().lens().size() == 4;
}
// Appends a transpose with permutation {0, 2, 3, 1} after `ins` when
// should_transpose(ins) holds; otherwise hands `ins` back untouched.
// NOTE(review): per the name, the input is presumably NCHW-ordered and the
// result NHWC-ordered -- confirm against callers.
instruction_ref to_nhwc(instruction_ref ins)
{
    if(not should_transpose(ins))
        return ins;
    return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
}
// Appends a transpose with permutation {0, 3, 1, 2} after `ins` when
// should_transpose(ins) holds; otherwise hands `ins` back untouched.
// NOTE(review): per the name, the input is presumably NHWC-ordered and the
// result NCHW-ordered -- confirm against callers.
instruction_ref to_nchw(instruction_ref ins)
{
    if(not should_transpose(ins))
        return ins;
    return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
}
// Appends a transpose with permutation {3, 2, 0, 1} after `ins` when
// should_transpose(ins) holds; otherwise hands `ins` back untouched.
// NOTE(review): the name suggests this reorders convolution weights into a
// K,C,X,Y layout; the transpose is skipped whenever `is_nhwc` is false --
// confirm that is intended for weight tensors.
instruction_ref to_kcxy(instruction_ref ins)
{
    if(not should_transpose(ins))
        return ins;
    return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
}
// Returns `ins` itself when its shape already reports standard(); otherwise
// adds an op::contiguous instruction fed by `ins` and returns that instead.
instruction_ref make_contiguous(instruction_ref ins)
{
    const bool already_standard = ins->get_shape().standard();
    if(not already_standard)
        return prog.add_instruction(op::contiguous{}, ins);
    return ins;
}
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
{
std::vector<instruction_ref> result(args.size());
std::transform(args.begin(), args.end(), result.begin(), [&](auto ins) {
return to_nchw(ins);
});
return result;
}
std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
{ {
auto attrs = attributes.at(s).list().i(); auto attrs = attributes.at(s).list().i();
...@@ -116,41 +159,43 @@ struct tf_parser ...@@ -116,41 +159,43 @@ struct tf_parser
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
add_mem_op("ConcatV2", &tf_parser::parse_concat); add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
add_mem_op("Const", &tf_parser::parse_constant); add_mem_op("Const", &tf_parser::parse_constant);
add_mem_op("Conv2D", &tf_parser::parse_conv); add_mem_op("Conv2D", &tf_parser::parse_conv, false);
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv); add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv, false);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MatMul", &tf_parser::parse_matmul); add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean); add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pack", &tf_parser::parse_pack); add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad, false);
add_mem_op("Reshape", &tf_parser::parse_reshape); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
} }
template <class F> template <class F>
void add_op(std::string name, F f) void add_op(std::string name, F f, bool transpose=true)
{
ops.emplace(name, f);
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{ {
ops.emplace(name, f); if (transpose)
{
ops.emplace(name, op_func{[=](const attribute_map& attributes, std::vector<instruction_ref> args) -> instruction_ref {
return to_nhwc(f(attributes, to_nchw(args)));
}});
}
else
{
ops.emplace(name, f);
}
} }
template <class F> template <class F>
void add_mem_op(std::string name, F f) void add_mem_op(std::string name, F f, bool transpose=true)
{ {
add_op(name, [=](auto&&... xs) { add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...); return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
}); }, transpose);
} }
template <class T> template <class T>
...@@ -159,15 +204,15 @@ struct tf_parser ...@@ -159,15 +204,15 @@ struct tf_parser
add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) { add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) {
if(args.size() != 2) if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands"); MIGRAPHX_THROW("binary operators should have 2 operands");
auto l0 = args[1];
if(contains(attributes, "data_format")) if(contains(attributes, "data_format"))
{ {
if(is_nhwc) // TODO
{ // if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]); // {
} // l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
// }
} }
return add_broadcastable_binary_op(args[0], l0, x); return add_broadcastable_binary_op(args[0], args[1], x);
}); });
} }
...@@ -207,20 +252,20 @@ struct tf_parser ...@@ -207,20 +252,20 @@ struct tf_parser
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0); auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1); auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
return prog.add_instruction(x, l0, l1); return to_nhwc(prog.add_instruction(x, to_nchw(l0), to_nchw(l1)));
} }
else else
{ {
return prog.add_instruction(x, {arg0, arg1}); return to_nhwc(prog.add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
} }
} }
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name, T x, bool transpose=true)
{ {
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) { add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
}); }, transpose);
} }
instruction_ref instruction_ref
...@@ -250,11 +295,11 @@ struct tf_parser ...@@ -250,11 +295,11 @@ struct tf_parser
{ {
// get index for axis within args // get index for axis within args
size_t axis_idx = attributes.at("N").i(); size_t axis_idx = attributes.at("N").i();
size_t axis = parse_axis(args[axis_idx]->eval().at<int64_t>()); size_t axis = args[axis_idx]->eval().at<int64_t>();
op::concat op{axis}; op::concat op{axis};
// return only first N arguments (assuming last index is the axis value) // return only first N arguments (assuming last index is the axis value)
return prog.add_instruction( return to_nhwc(prog.add_instruction(
op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1)); op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1)));
} }
instruction_ref parse_constant(const std::string&, instruction_ref parse_constant(const std::string&,
...@@ -324,22 +369,7 @@ struct tf_parser ...@@ -324,22 +369,7 @@ struct tf_parser
op.dilation[0] = dilation[2]; op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3]; op.dilation[1] = dilation[3];
} }
auto weights = args[1]; return prog.add_instruction(op, {to_nchw(args[0]), to_kcxy(to_nchw(args[1]))});
// check if weights are from a constant
if(weights->name() != "@param")
{
if(is_nhwc)
{
weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
}
else
{
weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
}
}
return prog.add_instruction(op, {args[0], weights});
} }
instruction_ref parse_depthwiseconv(const std::string&, instruction_ref parse_depthwiseconv(const std::string&,
...@@ -369,19 +399,7 @@ struct tf_parser ...@@ -369,19 +399,7 @@ struct tf_parser
op.stride[0] = stride[2]; op.stride[0] = stride[2];
op.stride[1] = stride[3]; op.stride[1] = stride[3];
} }
auto weights = args[1]; auto weights = to_kcxy(to_nchw(args[1]));
// check if weights are from a constant
if(weights->name() != "@param")
{
if(is_nhwc)
{
weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
}
else
{
weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
}
}
std::vector<int64_t> new_weights_shape; std::vector<int64_t> new_weights_shape;
copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape)); copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));
...@@ -429,7 +447,7 @@ struct tf_parser ...@@ -429,7 +447,7 @@ struct tf_parser
instruction_ref instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector()); auto axes = args[1]->eval().get<int32_t>().to_vector();
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3}; std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met // check if conditions for GlobalAvgPool are met
...@@ -463,16 +481,13 @@ struct tf_parser ...@@ -463,16 +481,13 @@ struct tf_parser
MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) + MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
" must be smaller than input size " + to_string(input_size)); " must be smaller than input size " + to_string(input_size));
} }
// check if input arg needs axis to be converted to NCHW
if(input_size >= 4)
axis = parse_axis(axis);
std::transform( std::transform(
args.begin(), args.begin(),
args.end(), args.end(),
std::back_inserter(unsqueezed_args), std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); }); [&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args); return to_nhwc(prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
} }
instruction_ref instruction_ref
...@@ -555,7 +570,7 @@ struct tf_parser ...@@ -555,7 +570,7 @@ struct tf_parser
MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)"); MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
auto s = args[1]->eval(); auto s = args[1]->eval();
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, make_contiguous(args[0]));
} }
void parse_from(std::istream& is) void parse_from(std::istream& is)
...@@ -586,7 +601,7 @@ struct tf_parser ...@@ -586,7 +601,7 @@ struct tf_parser
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
op::squeeze op; op::squeeze op;
auto axes = parse_axes(attributes, "squeeze_dims"); auto axes = attributes.at("squeeze_dims").list().i();
copy(axes, std::back_inserter(op.axes)); copy(axes, std::back_inserter(op.axes));
auto args0_dims = args[0]->get_shape().lens(); auto args0_dims = args[0]->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1 if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
...@@ -599,7 +614,7 @@ struct tf_parser ...@@ -599,7 +614,7 @@ struct tf_parser
} }
} }
} }
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, make_contiguous(args[0]));
} }
instruction_ref parse_stridedslice(const std::string&, instruction_ref parse_stridedslice(const std::string&,
...@@ -610,11 +625,6 @@ struct tf_parser ...@@ -610,11 +625,6 @@ struct tf_parser
auto starts = args[1]->eval().get<int32_t>().to_vector(); auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector(); auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size(); size_t num_axes = args[0]->get_shape().lens().size();
if(num_axes >= 4)
{
reorder_data(starts);
reorder_data(ends);
}
op.starts = std::vector<int64_t>(starts.begin(), starts.end()); op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end()); op.ends = std::vector<int64_t>(ends.begin(), ends.end());
...@@ -633,13 +643,9 @@ struct tf_parser ...@@ -633,13 +643,9 @@ struct tf_parser
if(((shrink_axis_mask >> i) & bitwise_compare) == 1) if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
squeeze_axes.push_back(i); squeeze_axes.push_back(i);
} }
if(num_axes >= 4)
{
squeeze_axes = parse_axes(squeeze_axes);
}
auto l0 = prog.add_instruction(op, args[0]); auto l0 = prog.add_instruction(op, make_contiguous(args[0]));
return prog.add_instruction(op::squeeze{squeeze_axes}, l0); return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
} }
void parse_graph(const tensorflow::GraphDef& graph) void parse_graph(const tensorflow::GraphDef& graph)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment