Commit d63dbddc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into pass_manager

parents 6a141efa c2527321
...@@ -31,6 +31,8 @@ enum class rnn_direction ...@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -19,6 +19,13 @@ namespace op { ...@@ -19,6 +19,13 @@ namespace op {
struct gather struct gather
{ {
int axis = 0; int axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -27,6 +27,16 @@ struct gru ...@@ -27,6 +27,16 @@ struct gru
float clip = 0.0f; float clip = 0.0f;
int linear_before_reset = 0; int linear_before_reset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.linear_before_reset, "linear_before_reset"));
}
std::string name() const { return "gru"; } std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -19,6 +19,13 @@ namespace op { ...@@ -19,6 +19,13 @@ namespace op {
struct logsoftmax struct logsoftmax
{ {
int axis = 1; int axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "logsoftmax"; } std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -25,6 +25,15 @@ struct lstm ...@@ -25,6 +25,15 @@ struct lstm
float clip = 0.0f; float clip = 0.0f;
int input_forget = 0; int input_forget = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.input_forget, "input_forget"));
}
std::string name() const { return "lstm"; } std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -25,6 +25,15 @@ struct rnn ...@@ -25,6 +25,15 @@ struct rnn
rnn_direction direction = rnn_direction::forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"));
}
std::string name() const { return "rnn"; } std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -1166,5 +1167,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1166,5 +1167,14 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
} }
} }
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -363,6 +363,16 @@ struct tf_parser ...@@ -363,6 +363,16 @@ struct tf_parser
int64_t axis = 0; int64_t axis = 0;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
axis = attributes.at("axis").i(); axis = attributes.at("axis").i();
size_t input_size = args.front()->get_shape().lens().size();
if(axis > input_size)
{
MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
" must be smaller than input size " + to_string(input_size));
}
// check if input arg needs axis to be converted to NCHW
if(input_size >= 4)
axis = parse_axis(axis);
std::transform( std::transform(
args.begin(), args.begin(),
args.end(), args.end(),
......
...@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args) ...@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
...@@ -373,7 +373,10 @@ TEST_CASE(gru_test_args) ...@@ -373,7 +373,10 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -414,8 +417,14 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -414,8 +417,14 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip}, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq, seq,
w, w,
r, r,
...@@ -445,9 +454,14 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -445,9 +454,14 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{ p.add_instruction(migraphx::op::gru{hs,
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq, seq,
w, w,
r, r,
...@@ -479,7 +493,10 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -479,7 +493,10 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -511,9 +528,12 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -511,9 +528,12 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -546,7 +566,10 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -546,7 +566,10 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip}, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq, seq,
w, w,
r, r,
...@@ -576,9 +599,11 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -576,9 +599,11 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{ p.add_instruction(migraphx::op::gru{hs,
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip}, {migraphx::op::relu{}, migraphx::op::relu{}},
migraphx::op::rnn_direction::reverse,
clip},
seq, seq,
w, w,
r, r,
...@@ -826,7 +851,12 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -826,7 +851,12 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget}, migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq, seq,
w, w,
r, r,
...@@ -851,8 +881,10 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -851,8 +881,10 @@ TEST_CASE(lstm_forward_actv_func)
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(migraphx::op::lstm{hs, auto out_hs = p.add_instruction(
{migraphx::op::sigmoid{}}, migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip, clip,
input_forget}, input_forget},
...@@ -881,9 +913,10 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -881,9 +913,10 @@ TEST_CASE(lstm_forward_actv_func)
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::lstm{hs, migraphx::op::lstm{
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip, clip,
input_forget}, input_forget},
...@@ -993,7 +1026,12 @@ TEST_CASE(lstm_reverse) ...@@ -993,7 +1026,12 @@ TEST_CASE(lstm_reverse)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget}, migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
input_forget},
seq, seq,
w, w,
r, r,
...@@ -1037,10 +1075,14 @@ TEST_CASE(lstm_bidirectional) ...@@ -1037,10 +1075,14 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape); auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape); auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1067,10 +1109,14 @@ TEST_CASE(lstm_bidirectional) ...@@ -1067,10 +1109,14 @@ TEST_CASE(lstm_bidirectional)
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1098,10 +1144,14 @@ TEST_CASE(lstm_bidirectional) ...@@ -1098,10 +1144,14 @@ TEST_CASE(lstm_bidirectional)
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1130,10 +1180,14 @@ TEST_CASE(lstm_bidirectional) ...@@ -1130,10 +1180,14 @@ TEST_CASE(lstm_bidirectional)
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1163,10 +1217,14 @@ TEST_CASE(lstm_bidirectional) ...@@ -1163,10 +1217,14 @@ TEST_CASE(lstm_bidirectional)
auto ih = p.add_parameter("h0", ih_shape); auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1197,10 +1255,14 @@ TEST_CASE(lstm_bidirectional) ...@@ -1197,10 +1255,14 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape); auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1244,9 +1306,17 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1244,9 +1306,17 @@ TEST_CASE(lstm_bi_actv_funcs)
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, {migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
seq, seq,
w, w,
r, r,
...@@ -1273,7 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1273,7 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1304,7 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1304,7 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1337,6 +1417,8 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1337,6 +1417,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, {migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}}, migraphx::op::tanh{}},
...@@ -1376,6 +1458,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1376,6 +1458,7 @@ TEST_CASE(lstm_bi_actv_funcs)
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
......
:
0 Placeholder*
dtype0*
shape:
:
1 Placeholder*
dtype0*
shape:
:
2 Placeholder*
dtype0*
shape:
4
pack1Pack012*
T0*
axis*
N"
\ No newline at end of file
...@@ -151,6 +151,28 @@ TEST_CASE(pack_test) ...@@ -151,6 +151,28 @@ TEST_CASE(pack_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(pack_test_nhwc)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 1;
std::transform(args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog);
}
TEST_CASE(pooling_test) TEST_CASE(pooling_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment