Unverified Commit 1fe84f2a authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #324 from ROCmSoftwarePlatform/slice_op

Slice op
parents a960abad 752d3239
...@@ -181,6 +181,7 @@ struct tf_parser ...@@ -181,6 +181,7 @@ struct tf_parser
add_mem_op("Pack", &tf_parser::parse_pack, false); add_mem_op("Pack", &tf_parser::parse_pack, false);
add_mem_op("Pad", &tf_parser::parse_pad); add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape, false); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Slice", &tf_parser::parse_slice, false);
add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>); add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
...@@ -733,6 +734,29 @@ struct tf_parser ...@@ -733,6 +734,29 @@ struct tf_parser
} }
} }
instruction_ref
parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto size = args[2]->eval().get<int32_t>().to_vector();
auto axes = args[0]->get_shape().lens();
size_t num_axes = axes.size();
op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(num_axes);
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
for(size_t i = 0; i < num_axes; i++)
{
if(size[i] == -1)
op.ends[i] = axes[i];
else
op.ends[i] = starts[i] + size[i];
}
return prog.add_instruction(op, make_contiguous(args[0]));
}
// template to facilitate the logsoftmax later // template to facilitate the logsoftmax later
template <class Op> template <class Op>
instruction_ref parse_softmax(const std::string&, instruction_ref parse_softmax(const std::string&,
......
...@@ -412,6 +412,26 @@ TEST_CASE(rsqrt_test) ...@@ -412,6 +412,26 @@ TEST_CASE(rsqrt_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(slice_test)
{
migraphx::program p;
std::size_t num_axes = 2;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});
migraphx::shape s0{migraphx::shape::int32_type, {num_axes}};
p.add_literal(migraphx::literal{s0, {1, 0}});
p.add_literal(migraphx::literal{s0, {2, -1}});
migraphx::op::slice op;
op.starts = {1, 0};
op.ends = {3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
p.add_instruction(op, l0);
auto prog = optimize_tf("slice_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(softmax_test) TEST_CASE(softmax_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment