Commit 7f728f6b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix test examples with more pack information added in onnx operator implementation.

parent af00eea8
......@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional,
};
std::ostream& operator << (std::ostream& os, rnn_direction v);
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -19,6 +19,13 @@ namespace op {
struct gather
{
int axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
......
......@@ -27,6 +27,16 @@ struct gru
float clip = 0.0f;
int linear_before_reset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden size"),
f(self.actv_funcs, "activation functions"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.linear_before_reset, "linear_before_reset"));
}
std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -19,6 +19,13 @@ namespace op {
struct logsoftmax
{
int axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -25,6 +25,15 @@ struct lstm
float clip = 0.0f;
int input_forget = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden size"),
f(self.actv_funcs, "activation functions"),
f(self.direction, "direction"),
f(self.input_forget, "input_forget"));
}
std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -25,6 +25,15 @@ struct rnn
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden size"),
f(self.actv_funcs, "activation functions"),
f(self.direction, "direction"),
f(self.clip, "clip"));
}
std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -1166,5 +1167,13 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
namespace op {
std::ostream& operator << (std::ostream& os, rnn_direction v)
{
os << static_cast<std::underlying_type<rnn_direction>::type>(v);
return os;
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
......@@ -373,7 +373,8 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -415,7 +416,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
migraphx::op::gru{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
......@@ -447,7 +449,8 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
hs, {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
......@@ -479,7 +482,8 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -513,7 +517,8 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
{migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -546,7 +551,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
p.add_instruction(migraphx::op::gru{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
......@@ -578,7 +584,7 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
hs, {migraphx::op::relu{}, migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
......@@ -826,7 +832,8 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget},
migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, clip, input_forget},
seq,
w,
r,
......@@ -852,7 +859,7 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}},
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
......@@ -883,7 +890,7 @@ TEST_CASE(lstm_forward_actv_func)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
......@@ -993,7 +1000,8 @@ TEST_CASE(lstm_reverse)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget},
migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, clip, input_forget},
seq,
w,
r,
......@@ -1040,7 +1048,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1070,7 +1079,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1101,7 +1111,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1133,7 +1144,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1166,7 +1178,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1200,7 +1213,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1246,7 +1260,9 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs, {}, migraphx::op::rnn_direction::bidirectional, clip, input_forget},
hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, clip, input_forget},
seq,
w,
r,
......@@ -1273,7 +1289,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}},
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{},
migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1304,7 +1321,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1337,6 +1355,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
......@@ -1376,6 +1396,7 @@ TEST_CASE(lstm_bi_actv_funcs)
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment