Commit 7f728f6b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix test examples with more pack information added in onnx operator implementation.

parent af00eea8
...@@ -31,6 +31,8 @@ enum class rnn_direction ...@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator << (std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -19,6 +19,13 @@ namespace op { ...@@ -19,6 +19,13 @@ namespace op {
struct gather struct gather
{ {
int axis = 0; int axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -27,6 +27,16 @@ struct gru ...@@ -27,6 +27,16 @@ struct gru
float clip = 0.0f; float clip = 0.0f;
int linear_before_reset = 0; int linear_before_reset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden size"),
f(self.actv_funcs, "activation functions"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.linear_before_reset, "linear_before_reset"));
}
std::string name() const { return "gru"; } std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -19,6 +19,13 @@ namespace op { ...@@ -19,6 +19,13 @@ namespace op {
struct logsoftmax struct logsoftmax
{ {
int axis = 1; int axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "logsoftmax"; } std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -25,6 +25,15 @@ struct lstm ...@@ -25,6 +25,15 @@ struct lstm
float clip = 0.0f; float clip = 0.0f;
int input_forget = 0; int input_forget = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden size"),
f(self.actv_funcs, "activation functions"),
f(self.direction, "direction"),
f(self.input_forget, "input_forget"));
}
std::string name() const { return "lstm"; } std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -25,6 +25,15 @@ struct rnn ...@@ -25,6 +25,15 @@ struct rnn
rnn_direction direction = rnn_direction::forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden size"),
f(self.actv_funcs, "activation functions"),
f(self.direction, "direction"),
f(self.clip, "clip"));
}
std::string name() const { return "rnn"; } std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -1166,5 +1167,13 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1166,5 +1167,13 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
} }
} }
namespace op {
std::ostream& operator << (std::ostream& os, rnn_direction v)
{
os << static_cast<std::underlying_type<rnn_direction>::type>(v);
return os;
}
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction) ...@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args) ...@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
...@@ -373,7 +373,8 @@ TEST_CASE(gru_test_args) ...@@ -373,7 +373,8 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -415,7 +416,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -415,7 +416,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::gru{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
...@@ -447,7 +449,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -447,7 +449,8 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{ migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, hs, {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
...@@ -479,7 +482,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -479,7 +482,8 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -513,7 +517,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -513,7 +517,8 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
...@@ -546,7 +551,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -546,7 +551,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip}, p.add_instruction(migraphx::op::gru{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r, r,
...@@ -578,7 +584,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -578,7 +584,7 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{ migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip}, hs, {migraphx::op::relu{}, migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq, seq,
w, w,
r, r,
...@@ -826,7 +832,8 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -826,7 +832,8 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget}, migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, clip, input_forget},
seq, seq,
w, w,
r, r,
...@@ -852,7 +859,7 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -852,7 +859,7 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(migraphx::op::lstm{hs, auto out_hs = p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}}, {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip, clip,
input_forget}, input_forget},
...@@ -883,7 +890,7 @@ TEST_CASE(lstm_forward_actv_func) ...@@ -883,7 +890,7 @@ TEST_CASE(lstm_forward_actv_func)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip, clip,
input_forget}, input_forget},
...@@ -993,7 +1000,8 @@ TEST_CASE(lstm_reverse) ...@@ -993,7 +1000,8 @@ TEST_CASE(lstm_reverse)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget}, migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, clip, input_forget},
seq, seq,
w, w,
r, r,
...@@ -1040,7 +1048,8 @@ TEST_CASE(lstm_bidirectional) ...@@ -1040,7 +1048,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{
hs, hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1070,7 +1079,8 @@ TEST_CASE(lstm_bidirectional) ...@@ -1070,7 +1079,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{
hs, hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1101,7 +1111,8 @@ TEST_CASE(lstm_bidirectional) ...@@ -1101,7 +1111,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{
hs, hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1133,7 +1144,8 @@ TEST_CASE(lstm_bidirectional) ...@@ -1133,7 +1144,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{
hs, hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1166,7 +1178,8 @@ TEST_CASE(lstm_bidirectional) ...@@ -1166,7 +1178,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{
hs, hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1200,7 +1213,8 @@ TEST_CASE(lstm_bidirectional) ...@@ -1200,7 +1213,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{
hs, hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1246,7 +1260,9 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1246,7 +1260,9 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{
hs, {}, migraphx::op::rnn_direction::bidirectional, clip, input_forget}, hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, clip, input_forget},
seq, seq,
w, w,
r, r,
...@@ -1273,7 +1289,8 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1273,7 +1289,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}}, {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{},
migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1304,7 +1321,8 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1304,7 +1321,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
input_forget}, input_forget},
...@@ -1337,6 +1355,8 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1337,6 +1355,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::lstm{hs, p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, {migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}}, migraphx::op::tanh{}},
...@@ -1376,6 +1396,7 @@ TEST_CASE(lstm_bi_actv_funcs) ...@@ -1376,6 +1396,7 @@ TEST_CASE(lstm_bi_actv_funcs)
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment