"vscode:/vscode.git/clone" did not exist on "cd6e312c2d300625b2ec4c7b63f29c160bb4ed56"
Commit 7f728f6b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix test examples with more pack information added in onnx operator implementation.

parent af00eea8
......@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional,
};
std::ostream& operator << (std::ostream& os, rnn_direction v);
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -19,6 +19,13 @@ namespace op {
struct gather
{
int axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
......
......@@ -27,6 +27,16 @@ struct gru
float clip = 0.0f;
int linear_before_reset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden size"),
f(self.actv_funcs, "activation functions"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.linear_before_reset, "linear_before_reset"));
}
std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -19,6 +19,13 @@ namespace op {
struct logsoftmax
{
int axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -25,6 +25,15 @@ struct lstm
float clip = 0.0f;
int input_forget = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden size"),
f(self.actv_funcs, "activation functions"),
f(self.direction, "direction"),
f(self.input_forget, "input_forget"));
}
std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -25,6 +25,15 @@ struct rnn
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden size"),
f(self.actv_funcs, "activation functions"),
f(self.direction, "direction"),
f(self.clip, "clip"));
}
std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -1166,5 +1167,13 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
namespace op {
std::ostream& operator << (std::ostream& os, rnn_direction v)
{
os << static_cast<std::underlying_type<rnn_direction>::type>(v);
return os;
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -154,7 +154,7 @@ TEST_CASE(rnn_test_one_direction)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -339,7 +339,7 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
......@@ -373,7 +373,8 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::relu{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -415,7 +416,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
migraphx::op::gru{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
......@@ -447,7 +449,8 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
hs, {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
......@@ -479,7 +482,8 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -513,7 +517,8 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
{migraphx::op::tanh{}, migraphx::op::sigmoid{},
migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
......@@ -546,7 +551,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
p.add_instruction(migraphx::op::gru{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
......@@ -578,7 +584,7 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
hs, {migraphx::op::relu{}, migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
......@@ -826,7 +832,8 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget},
migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, clip, input_forget},
seq,
w,
r,
......@@ -852,7 +859,7 @@ TEST_CASE(lstm_forward_actv_func)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}},
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
......@@ -883,7 +890,7 @@ TEST_CASE(lstm_forward_actv_func)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
input_forget},
......@@ -993,7 +1000,8 @@ TEST_CASE(lstm_reverse)
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::lstm{hs, {}, migraphx::op::rnn_direction::forward, clip, input_forget},
migraphx::op::lstm{hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse, clip, input_forget},
seq,
w,
r,
......@@ -1040,7 +1048,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1070,7 +1079,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1101,7 +1111,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1133,7 +1144,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1166,7 +1178,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1200,7 +1213,8 @@ TEST_CASE(lstm_bidirectional)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1246,7 +1260,9 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs = p.add_instruction(
migraphx::op::lstm{
hs, {}, migraphx::op::rnn_direction::bidirectional, clip, input_forget},
hs, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional, clip, input_forget},
seq,
w,
r,
......@@ -1273,7 +1289,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}},
{migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{},
migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1304,7 +1321,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
input_forget},
......@@ -1337,6 +1355,8 @@ TEST_CASE(lstm_bi_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::lstm{hs,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{}},
......@@ -1376,6 +1396,7 @@ TEST_CASE(lstm_bi_actv_funcs)
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment