Commit f12243cf authored by Khalique's avatar Khalique
Browse files

initial testing

parent a713a6d3
......@@ -544,7 +544,8 @@ struct tf_parser
op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size();
std::vector<size_t> axes = args[0]->get_shape().lens();
size_t num_axes = axes.size();
if(num_axes >= 4)
{
reorder_data(starts);
......@@ -555,25 +556,75 @@ struct tf_parser
op.ends = std::vector<int64_t>(ends.begin(), ends.end());
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
uint32_t begin_mask = 0;
uint32_t end_mask = 0;
uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1;
std::vector<int64_t> begin_axes;
std::vector<int64_t> end_axes;
std::vector<int64_t> squeeze_axes;
if(contains(attributes, "begin_mask"))
begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());
if(contains(attributes, "end_mask"))
end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());
if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if(((begin_mask >> i) & bitwise_compare) == 1)
begin_axes.push_back(1);
else
begin_axes.push_back(0);
}
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to end
if(((end_mask >> i) & bitwise_compare) == 1)
end_axes.push_back(1);
else
end_axes.push_back(0);
}
if (num_axes >= 4)
{
reorder_data(begin_axes);
reorder_data(end_axes);
}
for(size_t i = 0; i < num_axes; i++)
{
if(begin_axes.at(i) == 1)
{
op.starts.at(i) = 0;
}
if(end_axes.at(i) == 1)
{
op.ends.at(i) = axes.at(i);
}
}
auto l0 = prog.add_instruction(op, args[0]);
if(shrink_axis_mask == 0)
return l0;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to squeeze
if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
squeeze_axes.push_back(i);
}
if(num_axes >= 4)
{
squeeze_axes = parse_axes(squeeze_axes);
}
auto l0 = prog.add_instruction(op, args[0]);
return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
}
......
......@@ -288,4 +288,25 @@ TEST_CASE(stridedslice_test)
EXPECT(p == prog);
}
TEST_CASE(stridedslice_masks_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 1, 1};
op.ends = {1, 10, 3, 3};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 1, 1, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
p.add_instruction(op, l0);
auto prog = migraphx::parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment