Commit 4a244975 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 7f728f6b
...@@ -31,7 +31,7 @@ enum class rnn_direction ...@@ -31,7 +31,7 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator << (std::ostream& os, rnn_direction v); std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -1168,12 +1168,12 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1168,12 +1168,12 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
} }
namespace op { namespace op {
std::ostream& operator << (std::ostream& os, rnn_direction v) std::ostream& operator<<(std::ostream& os, rnn_direction v)
{ {
os << static_cast<std::underlying_type<rnn_direction>::type>(v); os << static_cast<std::underlying_type<rnn_direction>::type>(v);
return os; return os;
} }
} } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -373,8 +373,10 @@ TEST_CASE(gru_test_args) ...@@ -373,8 +373,10 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, {migraphx::op::tanh{},
migraphx::op::relu{}, migraphx::op::tanh{}}, migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -415,15 +417,20 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -415,15 +417,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, p.add_instruction(migraphx::op::gru{hs,
migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, {migraphx::op::sigmoid{},
seq, migraphx::op::tanh{},
w, migraphx::op::sigmoid{},
r, migraphx::op::tanh{}},
bias, migraphx::op::rnn_direction::bidirectional,
seq_len, clip},
ih); seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
...@@ -447,16 +454,20 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -447,16 +454,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{ p.add_instruction(migraphx::op::gru{hs,
hs, {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}}, {migraphx::op::sigmoid{},
migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::sigmoid{},
seq, migraphx::op::sigmoid{},
w, migraphx::op::sigmoid{}},
r, migraphx::op::rnn_direction::bidirectional,
bias, clip},
seq_len, seq,
ih); w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
...@@ -482,8 +493,10 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -482,8 +493,10 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, {migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -515,18 +528,20 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -515,18 +528,20 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, {migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::sigmoid{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip}, migraphx::op::tanh{}},
seq, migraphx::op::rnn_direction::bidirectional,
w, clip},
r, seq,
bias, w,
seq_len, r,
ih); bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
...@@ -551,8 +566,10 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -551,8 +566,10 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, p.add_instruction(migraphx::op::gru{hs,
migraphx::op::rnn_direction::forward, clip}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq, seq,
w, w,
r, r,
...@@ -582,15 +599,17 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -582,15 +599,17 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::gru{ p.add_instruction(migraphx::op::gru{hs,
hs, {migraphx::op::relu{}, migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip}, {migraphx::op::relu{}, migraphx::op::relu{}},
seq, migraphx::op::rnn_direction::reverse,
w, clip},
r, seq,
bias, w,
seq_len, r,
ih); bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
...@@ -832,8 +851,12 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -832,8 +851,12 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::forward, clip, input_forget}, hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
seq, seq,
w, w,
r, r,
...@@ -858,19 +881,21 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -858,19 +881,21 @@ TEST_CASE(lstm_forward_actv_func)
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(migraphx::op::lstm{hs, auto out_hs = p.add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::forward, hs,
clip, {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
input_forget}, migraphx::op::rnn_direction::forward,
seq, clip,
w, input_forget},
r, seq,
bias, w,
und, r,
und, bias,
und, und,
und); und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_f1af.onnx");
...@@ -888,20 +913,21 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -888,20 +913,21 @@ TEST_CASE(lstm_forward_actv_func)
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::lstm{hs, migraphx::op::lstm{
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}}, hs,
migraphx::op::rnn_direction::forward, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
clip, migraphx::op::rnn_direction::forward,
input_forget}, clip,
seq, input_forget},
w, seq,
r, w,
bias, r,
seq_len, bias,
und, seq_len,
und, und,
und); und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_f2af.onnx");
...@@ -1000,8 +1026,12 @@ TEST_CASE(lstm_reverse) ...@@ -1000,8 +1026,12 @@ TEST_CASE(lstm_reverse)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::reverse, clip, input_forget}, hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
input_forget},
seq, seq,
w, w,
r, r,
...@@ -1045,22 +1075,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1045,22 +1075,25 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape); auto ic = p.add_parameter("c0", ih_shape);
auto pph = p.add_parameter("pph", pph_shape); auto pph = p.add_parameter("pph", pph_shape);
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::sigmoid{},
clip, migraphx::op::tanh{},
input_forget}, migraphx::op::tanh{}},
seq, migraphx::op::rnn_direction::bidirectional,
w, clip,
r, input_forget},
bias, seq,
seq_len, w,
ih, r,
ic, bias,
pph); seq_len,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi.onnx");
...@@ -1076,22 +1109,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1076,22 +1109,25 @@ TEST_CASE(lstm_bidirectional)
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::sigmoid{},
clip, migraphx::op::tanh{},
input_forget}, migraphx::op::tanh{}},
seq, migraphx::op::rnn_direction::bidirectional,
w, clip,
r, input_forget},
und, seq,
und, w,
und, r,
und, und,
und); und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi3args.onnx");
...@@ -1108,22 +1144,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1108,22 +1144,25 @@ TEST_CASE(lstm_bidirectional)
auto bias = p.add_parameter("bias", bias_shape); auto bias = p.add_parameter("bias", bias_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::sigmoid{},
clip, migraphx::op::tanh{},
input_forget}, migraphx::op::tanh{}},
seq, migraphx::op::rnn_direction::bidirectional,
w, clip,
r, input_forget},
bias, seq,
und, w,
und, r,
und, bias,
und); und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi4args.onnx");
...@@ -1141,22 +1180,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1141,22 +1180,25 @@ TEST_CASE(lstm_bidirectional)
auto seq_len = p.add_parameter("seq_len", sl_shape); auto seq_len = p.add_parameter("seq_len", sl_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::sigmoid{},
clip, migraphx::op::tanh{},
input_forget}, migraphx::op::tanh{}},
seq, migraphx::op::rnn_direction::bidirectional,
w, clip,
r, input_forget},
bias, seq,
seq_len, w,
und, r,
und, bias,
und); seq_len,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi5args.onnx");
...@@ -1175,22 +1217,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1175,22 +1217,25 @@ TEST_CASE(lstm_bidirectional)
auto ih = p.add_parameter("h0", ih_shape); auto ih = p.add_parameter("h0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::sigmoid{},
clip, migraphx::op::tanh{},
input_forget}, migraphx::op::tanh{}},
seq, migraphx::op::rnn_direction::bidirectional,
w, clip,
r, input_forget},
bias, seq,
seq_len, w,
ih, r,
und, bias,
und); seq_len,
ih,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi6args.onnx");
...@@ -1210,22 +1255,25 @@ TEST_CASE(lstm_bidirectional) ...@@ -1210,22 +1255,25 @@ TEST_CASE(lstm_bidirectional)
auto ic = p.add_parameter("c0", ih_shape); auto ic = p.add_parameter("c0", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::sigmoid{},
clip, migraphx::op::tanh{},
input_forget}, migraphx::op::tanh{}},
seq, migraphx::op::rnn_direction::bidirectional,
w, clip,
r, input_forget},
bias, seq,
seq_len, w,
ih, r,
ic, bias,
und); seq_len,
ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi7args.onnx");
...@@ -1258,19 +1306,25 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1258,19 +1306,25 @@ TEST_CASE(lstm_bi_actv_funcs)
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hs,
hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, clip, input_forget}, migraphx::op::tanh{},
seq, migraphx::op::sigmoid{},
w, migraphx::op::tanh{},
r, migraphx::op::tanh{}},
und, migraphx::op::rnn_direction::bidirectional,
und, clip,
und, input_forget},
und, seq,
und); w,
r,
und,
und,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx"); auto prog = migraphx::parse_onnx("onnx_lstm_bi0af.onnx");
...@@ -1289,8 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1289,8 +1343,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}}, migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1321,8 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1321,8 +1379,12 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment