Unverified commit 6972ad26, authored by Paul Fultz II, committed by GitHub

Merge pull request #170 from ROCmSoftwarePlatform/gru_operator

Gru operator
parents fb8fda8f bce629f1
@@ -1164,20 +1164,20 @@ struct outline
argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
};

-struct rnn
-{
-enum rnn_direction_t
-{
-forward,
-reverse,
-bidirectional,
-};
+// indicate rnn computation direction
+enum class rnn_direction
+{
+forward,
+reverse,
+bidirectional,
+};
+
+struct rnn
+{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{tanh{}, tanh{}};
-rnn_direction_t direction = forward;
+rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const
@@ -1190,7 +1190,7 @@ struct rnn
}
std::size_t num_directions = 1;
-if(direction == bidirectional)
+if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
@@ -1222,6 +1222,43 @@ struct rnn_last_output
}
};

struct gru
{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int linear_before_reset = 0;
std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[2].lens();
if(hidden_size != hidden_dims[2])
{
MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
}
std::size_t num_directions = 1;
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
if(num_directions != hidden_dims[0])
{
MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
}
std::vector<std::size_t> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size;
return {inputs[0].type(), out_dims};
}
};
struct undefined
{
std::string name() const { return "undefined"; }
...
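For reference, the gru::compute_shape logic above turns an input of shape {seq_len, batch_size, input_size} into an output of shape {seq_len, num_directions, batch_size, hidden_size}. A minimal standalone sketch of that dimension arithmetic (gru_out_dims is a hypothetical helper name, not part of this patch):

#include <cstddef>
#include <iostream>
#include <vector>

// Mirror of op::gru::compute_shape's output-dimension arithmetic;
// gru_out_dims is a hypothetical name for illustration only.
std::vector<std::size_t> gru_out_dims(std::vector<std::size_t> in_dims, // {seq_len, batch_size, input_size}
                                      std::size_t hidden_size,
                                      bool bidirectional)
{
    std::size_t num_directions = bidirectional ? 2 : 1;
    in_dims.insert(in_dims.begin() + 1, num_directions); // {seq_len, num_directions, batch_size, input_size}
    in_dims.back() = hidden_size;                        // {seq_len, num_directions, batch_size, hidden_size}
    return in_dims;
}

int main()
{
    // Same sizes as the tests below: seq_len = 5, batch = 3, input = 10, hidden = 20.
    for(auto d : gru_out_dims({5, 3, 10}, 20, true))
        std::cout << d << ' '; // prints: 5 2 3 20
    std::cout << '\n';
}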
@@ -21,17 +21,30 @@ struct rewrite_rnn
void apply(program& prog) const;

private:
-std::vector<instruction_ref> rnn_cell(bool is_forward,
-program& prog,
-instruction_ref ins,
-instruction_ref input,
-instruction_ref w,
-instruction_ref r,
-instruction_ref bias,
-instruction_ref ih,
-operation& actv_func) const;
-std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
+// for vanilla rnn operators
+void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
+std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
+program& prog,
+instruction_ref ins,
+instruction_ref input,
+instruction_ref w,
+instruction_ref r,
+instruction_ref bias,
+instruction_ref ih,
+operation& actv_func) const;
+std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
+
+// for gru operators
+void apply_gru(program& prog, instruction_ref ins) const;
+std::vector<instruction_ref> gru_cell(bool is_forward,
+program& prog,
+instruction_ref ins,
+std::vector<instruction_ref> inputs,
+int linear_before_reset,
+const operation& actv_func1,
+const operation& actv_func2) const;
+
+std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
};

} // namespace MIGRAPHX_INLINE_NS
...
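The two activation functions passed to gru_cell are the gate nonlinearity f and the candidate-state nonlinearity g of the GRU recurrence (sigmoid and tanh by default). Per the ONNX GRU operator spec, with * the matrix product and (.) the elementwise product, the cell computes:

zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)   when linear_before_reset = 0
ht = g(Xt*(Wh^T) + rt (.) (Ht-1*(Rh^T) + Rbh) + Wbh)   when linear_before_reset != 0
Ht = (1 - zt) (.) ht + zt (.) Ht-1

so the linear_before_reset attribute only controls whether the reset gate rt is applied before or after the recurrence weight Rh in the candidate state.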
@@ -88,6 +88,7 @@ struct onnx_parser
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
+add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("Pad", &onnx_parser::parse_pad);

// init the activation function map
@@ -715,8 +716,7 @@ struct onnx_parser
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
-migraphx::shape w_shape = args[1]->get_shape();
-std::size_t hidden_size = w_shape.lens()[1];
+std::size_t hidden_size = args[1]->get_shape().lens()[1];

if(contains(attributes, "hidden_size"))
{
@@ -734,14 +734,14 @@ struct onnx_parser
direction = attributes.at("direction").s();
}

-op::rnn::rnn_direction_t dirct = op::rnn::forward;
+op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
-dirct = op::rnn::bidirectional;
+dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
-dirct = op::rnn::reverse;
+dirct = op::rnn_direction::reverse;
}

std::vector<std::string> vec_names{"tanh"};
@@ -763,7 +763,7 @@ struct onnx_parser
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
-if(dirct == op::rnn::bidirectional)
+if(dirct == op::rnn_direction::bidirectional)
{
if(vec_names.size() == 1)
{
@@ -801,6 +801,125 @@ struct onnx_parser
return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
}
}
// handle the direction attribute
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
}
// need 4 activation functions
if(dirct == op::rnn_direction::bidirectional)
{
// 4 activation functions are used in the bidirectional
// scenario. The onnx operator spec does not say how a
// shorter list should be extended, so we use this rule:
// if 1 actv function is provided, repeat it 4 times. If
// 2 actv functions are provided, assume forward and
// reverse use the same pair. If 3 actv functions are
// provided, assume the 3rd one is repeated once and used
// by the reverse direction.
// This may need to change later
if(vec_names.size() == 1)
{
vec_names.insert(vec_names.end(), 3, vec_names.at(0));
}
else if(vec_names.size() == 2)
{
// repeat the activation functions
vec_names.push_back(vec_names.at(0));
vec_names.push_back(vec_names.at(1));
}
else if(vec_names.size() == 3)
{
vec_names.push_back(vec_names.at(2));
}
}
else
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0)
{
MIGRAPHX_THROW("GRU: activation function " + std::string(name) + " not supported");
}
});
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
return map_actv_funcs[name];
});
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int linear_before_reset = 0;
if(contains(attributes, "linear_before_reset"))
{
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
}
// append undefined operators to make 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 6 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
...
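The activation-list expansion that parse_gru performs above can be read as a standalone rule. Below is a minimal sketch of that rule, assuming a hypothetical helper name expand_gru_activations (not part of this patch):

#include <cassert>
#include <string>
#include <vector>

// Sketch of parse_gru's activation-name expansion;
// expand_gru_activations is a hypothetical name for illustration.
std::vector<std::string> expand_gru_activations(std::vector<std::string> names, bool bidirectional)
{
    if(bidirectional)
    {
        if(names.size() == 1) // repeat the single function four times
            names.insert(names.end(), 3, names.at(0));
        else if(names.size() == 2) // forward and reverse share the same pair
        {
            names.push_back(names.at(0));
            names.push_back(names.at(1));
        }
        else if(names.size() == 3) // repeat the 3rd for the reverse direction
            names.push_back(names.at(2));
    }
    else if(names.size() == 1)
    {
        names.push_back(names.at(0));
    }
    return names;
}

int main()
{
    assert((expand_gru_activations({"sigmoid", "tanh"}, true) ==
            std::vector<std::string>{"sigmoid", "tanh", "sigmoid", "tanh"}));
    assert((expand_gru_activations({"relu"}, false) ==
            std::vector<std::string>{"relu", "relu"}));
}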
@@ -491,7 +491,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-migraphx::op::rnn::bidirectional,
+migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
@@ -523,7 +523,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-migraphx::op::rnn::forward,
+migraphx::op::rnn_direction::forward,
clip},
seq,
w,
@@ -555,7 +555,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-migraphx::op::rnn::reverse,
+migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
@@ -583,7 +583,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-migraphx::op::rnn::reverse,
+migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
@@ -615,7 +615,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-migraphx::op::rnn::reverse,
+migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
@@ -630,6 +630,429 @@ TEST_CASE(rnn_test)
}
}
TEST_CASE(gru_test)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// forward
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
}
// bidirectional
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_args)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// 3 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
}
// 4 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
}
// 5 arguments
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_actv_funcs)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirection, 0 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
}
// bidirection, 1 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
}
// bidirection, 2 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
}
// bidirection, 3 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
}
// forward, 0 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
}
// reverse, 1 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
}
}
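For context, here is how one of the parsed fixtures above could be run end to end. This is a hedged sketch, not part of the patch: the cpu target header and name, get_parameter_shapes, generate_argument, and eval are assumed from the rest of the codebase rather than from this diff.

#include <migraphx/onnx.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>

int main()
{
    // Parse the same fixture gru_test uses, then compile and evaluate it.
    auto p = migraphx::parse_onnx("onnx_gru_forward.onnx");
    p.compile(migraphx::cpu::target{});

    // Fill every graph parameter with generated data of the right shape.
    migraphx::program::parameter_map params;
    for(auto&& x : p.get_parameter_shapes())
        params[x.first] = migraphx::generate_argument(x.second);

    // result holds the concatenated hidden states for every time step.
    auto result = p.eval(params);
    (void)result;
}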
TEST_CASE(flatten_test)
{
migraphx::program p;
...