Commit 9cdeab11 authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into scalar_parsing

parents 433f2cdb a27d2dc5
......@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional,
};
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -19,6 +19,13 @@ namespace op {
struct gather
{
int axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
......
......@@ -27,6 +27,16 @@ struct gru
float clip = 0.0f;
int linear_before_reset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.linear_before_reset, "linear_before_reset"));
}
std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -19,6 +19,13 @@ namespace op {
struct logsoftmax
{
int axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -25,6 +25,15 @@ struct lstm
float clip = 0.0f;
int input_forget = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.input_forget, "input_forget"));
}
std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -25,6 +25,15 @@ struct rnn
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"));
}
std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -1166,5 +1167,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -124,6 +124,7 @@ struct tf_parser
add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
}
template <class F>
......@@ -508,6 +509,46 @@ struct tf_parser
return prog.add_instruction(op, args[0]);
}
instruction_ref parse_stridedslice(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size();
if(num_axes >= 4)
{
reorder_data(starts);
reorder_data(ends);
}
op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end());
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1;
std::vector<int64_t> squeeze_axes;
if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to squeeze
if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
squeeze_axes.push_back(i);
}
if(num_axes >= 4)
{
squeeze_axes = parse_axes(squeeze_axes);
}
auto l0 = prog.add_instruction(op, args[0]);
return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
}
void parse_graph(const tensorflow::GraphDef& graph)
{
nodes = get_nodes(graph, input_nodes);
......
......@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
......@@ -373,7 +373,10 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -414,14 +417,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
......@@ -445,15 +454,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
......@@ -479,7 +493,10 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -511,17 +528,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
......@@ -546,7 +566,10 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
......@@ -576,15 +599,17 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
seq_len,
ih);
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::relu{}, migraphx::op::relu{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
......@@ -826,7 +851,12 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget},
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
......@@ -851,19 +881,21 @@ TEST_CASE(lstm_forward_actv_func)
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx");
......@@ -881,20 +913,21 @@ TEST_CASE(lstm_forward_actv_func)
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx");
......@@ -993,7 +1026,12 @@ TEST_CASE(lstm_reverse)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget},
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
input_forget},
seq,
w,
r,
......@@ -1037,21 +1075,25 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx");
......@@ -1067,21 +1109,25 @@ TEST_CASE(lstm_bidirectional)
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx");
......@@ -1098,21 +1144,25 @@ TEST_CASE(lstm_bidirectional)
auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx");
......@@ -1130,21 +1180,25 @@ TEST_CASE(lstm_bidirectional)
auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx");
......@@ -1163,21 +1217,25 @@ TEST_CASE(lstm_bidirectional)
auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx");
......@@ -1197,21 +1255,25 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
bias,
seq_len,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx");
......@@ -1244,17 +1306,25 @@ TEST_CASE(lstm_bi_actv_funcs)
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs, {}, migraphx::op::rnn_direction::bidirectional, clip, input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq,
w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx");
......@@ -1273,7 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}},
{migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1304,7 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1337,6 +1417,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
......@@ -1376,6 +1458,7 @@ TEST_CASE(lstm_bi_actv_funcs)
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
......
......@@ -237,4 +237,28 @@ TEST_CASE(squeeze_test)
EXPECT(p == prog);
}
TEST_CASE(stridedslice_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
op.ends = {1, 5, 1, 1};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 5});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(op, l0);
auto shrink_axis = 2;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
auto prog = migraphx::parse_tf("stridedslice_test.pb", true);
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment