Commit 51e7bbd7 authored by Khalique's avatar Khalique
Browse files

manual merge

parents c6a36353 8623590a
...@@ -215,12 +215,9 @@ void simplify_reshapes::apply(program& p) const ...@@ -215,12 +215,9 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions // Skip possible dead instructions
if(ins->outputs().empty() and ins != end) if(ins->outputs().empty() and ins != end)
continue; continue;
match::find_matches(p, match::find_matches(p, ins, find_nop_reshapes{}, find_reshaper{}, find_transpose{}
ins, // find_concat_transpose{}
find_nop_reshapes{}, );
find_reshaper{},
find_transpose{},
find_concat_transpose{});
} }
} }
......
...@@ -185,7 +185,7 @@ struct tf_parser ...@@ -185,7 +185,7 @@ struct tf_parser
add_mem_op("Slice", &tf_parser::parse_slice, false); add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
add_mem_op("Transpose", &tf_parser::parse_transpose, false); add_mem_op("Transpose", &tf_parser::parse_transpose, false);
} }
...@@ -823,19 +823,65 @@ struct tf_parser ...@@ -823,19 +823,65 @@ struct tf_parser
op::slice op; op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector(); auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector(); auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size(); auto l0 = args[0];
size_t num_axes = l0->get_shape().lens().size();
std::vector<size_t> axes = l0->get_shape().lens();
op.starts = std::vector<int64_t>(starts.begin(), starts.end()); op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end()); op.ends = std::vector<int64_t>(ends.begin(), ends.end());
op.axes = std::vector<int64_t>(num_axes); op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0); std::iota(op.axes.begin(), op.axes.end(), 0);
uint32_t begin_mask = 0;
uint32_t end_mask = 0;
uint32_t shrink_axis_mask = 0; uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1; uint32_t bitwise_compare = 1;
std::vector<int64_t> begin_axes;
std::vector<int64_t> end_axes;
std::vector<int64_t> squeeze_axes; std::vector<int64_t> squeeze_axes;
if(contains(attributes, "begin_mask"))
begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());
if(contains(attributes, "end_mask"))
end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());
if(contains(attributes, "shrink_axis_mask")) if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i()); shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if(((begin_mask >> i) & bitwise_compare) == 1)
begin_axes.push_back(1);
else
begin_axes.push_back(0);
}
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to end
if(((end_mask >> i) & bitwise_compare) == 1)
end_axes.push_back(1);
else
end_axes.push_back(0);
}
for(size_t i = 0; i < num_axes; i++)
{
if(begin_axes.at(i) == 1)
{
op.starts.at(i) = 0;
}
if(end_axes.at(i) == 1)
{
op.ends.at(i) = axes.at(i);
}
}
auto l1 = prog.add_instruction(op, l0);
if(shrink_axis_mask == 0)
return l1;
for(size_t i = 0; i < num_axes; i++) for(size_t i = 0; i < num_axes; i++)
{ {
// the LSB corresponds to axis 0 when determining which axes to squeeze // the LSB corresponds to axis 0 when determining which axes to squeeze
...@@ -843,8 +889,7 @@ struct tf_parser ...@@ -843,8 +889,7 @@ struct tf_parser
squeeze_axes.push_back(i); squeeze_axes.push_back(i);
} }
auto l0 = prog.add_instruction(op, make_contiguous(args[0])); return prog.add_instruction(op::squeeze{squeeze_axes}, l1);
return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
} }
instruction_ref instruction_ref
......
...@@ -516,6 +516,27 @@ TEST_CASE(tanh_test) ...@@ -516,6 +516,27 @@ TEST_CASE(tanh_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(stridedslice_masks_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 1, 1};
op.ends = {1, 10, 3, 3};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 1, 1, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
p.add_instruction(op, l0);
auto prog = migraphx::parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(transpose_test) TEST_CASE(transpose_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment