Commit fb5e2817 authored by Khalique's avatar Khalique
Browse files

added parsing test

parent 635788d1
......@@ -504,27 +504,35 @@ struct tf_parser
std::vector<instruction_ref> args)
{
op::slice op;
auto begin = args[1]->eval().get<int64_t>().to_vector();
;
auto end = args[2]->eval().get<int64_t>().to_vector();
;
op.starts = begin;
op.ends = end;
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size();
if(num_axes >= 4)
{
reorder_data(starts);
reorder_data(ends);
}
op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end());
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
int shrink_axis_mask = 0;
std::vector<int64_t> squeeze_axes;
if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = attributes.at("shrink_axis_mask").i();
size_t num_axes = args[0]->get_shape().lens().size();
for(size_t i = 0; i < num_axes; i++)
{
if((shrink_axis_mask >> i) & 1)
squeeze_axes.push_back(i);
}
if(num_axes >= 4)
{
squeeze_axes = parse_axes(squeeze_axes);
}
auto l0 = prog.add_instruction(op, args[0]);
return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
}
......
......@@ -215,4 +215,28 @@ TEST_CASE(squeeze_test)
EXPECT(p == prog);
}
TEST_CASE(stridedslice_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0,0,0,0};
op.ends = {1,5,1,1};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0,0,0,0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1,1,1,5});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1,1,1,1});
auto l1 = p.add_instruction(op, l0);
auto shrink_axis = 2;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
auto prog = migraphx::parse_tf("stridedslice_test.pb", true);
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment