"src/targets/gpu/vscode:/vscode.git/clone" did not exist on "0bedc5e8e02fdd72381e9551c8fa7239eaad49c9"
Commit 3acbd087 authored by Khalique's avatar Khalique
Browse files

made additional functions to reorder axes

parent 38bde548
......@@ -36,7 +36,38 @@ struct tf_parser
std::unordered_map<std::string, op_func> ops;
void nhwc_to_nchw(std::size_t& dim)
std::vector<size_t> parse_axes(attribute_map& attributes, const std::string& s)
{
auto attrs = attributes.at(s).list().i();
std::vector<size_t> axes;
copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
if (is_nhwc)
{
for(size_t i = 0; i < axes.size(); ++i)
{
parse_axis(axes.at(i));
}
}
return axes;
}
template <class T>
void reorder_data(std::vector<T>& prev_data)
{
std::vector<T> new_data(prev_data.size());
for(auto i = 0; i < new_data.size(); i++)
{
auto new_idx = i;
parse_axis(new_idx);
new_data.at(new_idx) = prev_data.at(i);
}
prev_data = new_data;
}
template <class T>
void parse_axis(T& dim)
{
if(is_nhwc)
{
switch(dim)
{
......@@ -48,6 +79,8 @@ struct tf_parser
}
}
}
tf_parser()
{
add_generic_op("Identity", op::identity{});
......@@ -125,14 +158,14 @@ struct tf_parser
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
const std::vector<size_t>* s0 = &arg0->get_shape().lens();
const std::vector<size_t>* s1 = &arg1->get_shape().lens();
// Make sure s0 is the smaller size
if(s0->size() > s1->size())
std::swap(s0, s1);
std::vector<std::size_t> output_lens(*s1);
std::vector<size_t> output_lens(*s1);
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
......@@ -168,14 +201,6 @@ struct tf_parser
{
epsilon = attributes.at("epsilon").f();
}
auto l0 = args[0];
if(l0->name() == "@param")
{
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
};
args[0] = l0;
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args));
}
......@@ -184,27 +209,17 @@ struct tf_parser
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = args[0];
// otherwise, if the input is a parameter to the graph, then first insert transpose
if(l0->name() == "@param")
{
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
};
auto l1 = prog.add_instruction(op::broadcast{axis, l0->get_shape()}, args[1]);
return prog.add_instruction(op::add{}, l0, l1);
auto l0 = prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(op::add{}, args[0], l0);
}
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
// get index for axis within args
std::size_t axis_idx = attributes.at("N").i();
std::size_t axis = args[axis_idx]->eval().at<int64_t>();
if(is_nhwc and axis < 4)
{
nhwc_to_nchw(axis);
}
size_t axis_idx = attributes.at("N").i();
size_t axis = args[axis_idx]->eval().at<int64_t>();
parse_axis(axis);
op::concat op{axis};
// return only first N arguments (assuming last index is the axis value)
return prog.add_instruction(
......@@ -232,7 +247,7 @@ struct tf_parser
}
else if(pad_mode.find("EXPLICIT") != std::string::npos)
{
std::vector<std::size_t> padding;
std::vector<size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
if(padding.size() != 4)
{
......@@ -248,52 +263,37 @@ struct tf_parser
}
if(contains(attributes, "strides"))
{
std::vector<std::size_t> stride;
std::vector<size_t> stride;
copy(attributes.at("strides").list().i(), std::back_inserter(stride));
reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
if(is_nhwc)
{
op.stride[0] = stride[1];
op.stride[1] = stride[2];
}
else
{
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
}
if(contains(attributes, "dilations"))
{
std::vector<std::size_t> dilation;
std::vector<size_t> dilation;
copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
reorder_data(dilation);
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
if(is_nhwc)
{
op.dilation[0] = dilation[1];
op.dilation[1] = dilation[2];
}
else
{
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
}
auto l0 = args[0];
if(l0->name() == "@param")
auto l0 = args[1];
if (l0->get_operator().name() == "transpose" and is_nhwc)
{
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
l0 = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
}
auto l1 = args[1];
if(is_nhwc)
l1 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
return prog.add_instruction(op, {l0, l1});
else
l0 = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
return prog.add_instruction(op, {args[0], l0});
}
instruction_ref parse_pooling(const std::string& name,
......@@ -316,49 +316,29 @@ struct tf_parser
}
if(contains(attributes, "strides"))
{
std::vector<std::size_t> stride;
std::vector<size_t> stride;
copy(attributes.at("strides").list().i(), std::back_inserter(stride));
reorder_data(stride);
if(stride.size() != 4)
{
MIGRAPHX_THROW("strides should have 4 values");
}
if(is_nhwc)
{
op.stride[0] = stride[1];
op.stride[1] = stride[2];
}
else
{
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
}
if(contains(attributes, "ksize"))
{
std::vector<std::size_t> ksize;
std::vector<size_t> ksize;
copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
reorder_data(ksize);
if(ksize.size() != 4)
{
MIGRAPHX_THROW("ksize should have 4 values");
}
if(is_nhwc)
{
op.lengths[0] = ksize[1];
op.lengths[1] = ksize[2];
}
else
{
op.lengths[0] = ksize[2];
op.lengths[1] = ksize[3];
}
}
auto l0 = args[0];
if(l0->name() == "@param")
{
if(is_nhwc)
l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, l0);
}
return prog.add_instruction(op, l0);
return prog.add_instruction(op, args[0]);
}
instruction_ref
......@@ -399,28 +379,21 @@ struct tf_parser
parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::squeeze op;
auto axes = attributes.at("squeeze_dims").list().i();
auto temp = args[0]->get_shape().lens();
auto axes = parse_axes(attributes, "squeeze_dims");
copy(axes, std::back_inserter(op.axes));
auto l0 = args[0];
if(is_nhwc)
{
if(l0->name() != "@param")
// squeeze dims are represented for nhwc,
// but intermediate shapes are in nchw
l0 = prog.add_instruction(op::transpose{{0, 2, 3, 1}}, args[0]);
}
auto l0_dims = l0->get_shape().lens();
auto args0_dims = args[0]->get_shape().lens();
if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
{
for(size_t i = 0; i < l0_dims.size(); i++)
for(auto i = 0; i < args0_dims.size(); i++)
{
if(l0_dims.at(i) == 1)
if(args0_dims.at(i) == 1)
{
op.axes.push_back(i);
}
}
}
return prog.add_instruction(op, l0);
return prog.add_instruction(op, args[0]);
}
void parse_graph(const tensorflow::GraphDef& graph)
......@@ -433,7 +406,15 @@ struct tf_parser
shape::type_t shape_type = parse_type(input_attrs.at("dtype").type());
std::vector<size_t> dims = parse_dims(input_attrs.at("shape").shape());
shape s = shape{shape_type, dims};
instructions[name] = prog.add_parameter(name, s);
auto in_param = prog.add_parameter(name, s);
if(is_nhwc and dims.size() >= 4) // only transpose for NHWC tensors and larger
{
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
reorder_data(axes);
in_param = prog.add_instruction(op::transpose{axes}, in_param);
}
instructions[name] = in_param;
}
for(auto&& p : nodes)
{
......@@ -726,7 +707,7 @@ struct tf_parser
template <class T>
static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
const std::size_t& shape_size)
const size_t& shape_size)
{
std::vector<T> data_vals(shape_size);
// check if shape has enough data values given existing fields
......
......@@ -110,8 +110,9 @@ TEST_CASE(conv_test)
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
p.add_instruction(op, l2, l3);
auto l3 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l4 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l3);
p.add_instruction(op, l2, l4);
auto prog = migraphx::parse_tf("conv_test.pb", true);
EXPECT(p == prog);
......@@ -141,8 +142,7 @@ TEST_CASE(pooling_test)
max_pool_op.lengths = {2, 2};
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
p.add_instruction(max_pool_op, l1);
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l0);
p.add_instruction(avg_pool_op, l2);
p.add_instruction(avg_pool_op, l1);
auto prog = migraphx::parse_tf("pooling_test.pb", true);
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment