Unverified Commit 6972ad26 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #170 from ROCmSoftwarePlatform/gru_operator

Gru operator
parents fb8fda8f bce629f1
......@@ -1164,20 +1164,20 @@ struct outline
argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
};
struct rnn
// indicate rnn computation direction
enum class rnn_direction
{
forward,
reverse,
bidirectional,
};
enum rnn_direction_t
{
forward,
reverse,
bidirectional,
};
struct rnn
{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{tanh{}, tanh{}};
rnn_direction_t direction = forward;
float clip = 0.0f;
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const
......@@ -1190,7 +1190,7 @@ struct rnn
}
std::size_t num_directions = 1;
if(direction == bidirectional)
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
......@@ -1222,6 +1222,43 @@ struct rnn_last_output
}
};
struct gru
{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int linear_before_reset = 0;
std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[2].lens();
if(hidden_size != hidden_dims[2])
{
MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
}
std::size_t num_directions = 1;
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
if(num_directions != hidden_dims[0])
{
MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
}
std::vector<std::size_t> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size;
return {inputs[0].type(), out_dims};
}
};
struct undefined
{
std::string name() const { return "undefined"; }
......
......@@ -21,17 +21,30 @@ struct rewrite_rnn
void apply(program& prog) const;
private:
std::vector<instruction_ref> rnn_cell(bool is_forward,
// for vanilla rnn operators
void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators
void apply_gru(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const;
std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2) const;
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -88,6 +88,7 @@ struct onnx_parser
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("Pad", &onnx_parser::parse_pad);
// init the activation function map
......@@ -715,8 +716,7 @@ struct onnx_parser
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
migraphx::shape w_shape = args[1]->get_shape();
std::size_t hidden_size = w_shape.lens()[1];
std::size_t hidden_size = args[1]->get_shape().lens()[1];
if(contains(attributes, "hidden_size"))
{
......@@ -734,14 +734,14 @@ struct onnx_parser
direction = attributes.at("direction").s();
}
op::rnn::rnn_direction_t dirct = op::rnn::forward;
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn::bidirectional;
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn::reverse;
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names{"tanh"};
......@@ -763,7 +763,7 @@ struct onnx_parser
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
if(dirct == op::rnn::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
if(vec_names.size() == 1)
{
......@@ -801,6 +801,125 @@ struct onnx_parser
return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(
names.begin(), names.end(), vec_names.begin(), [](auto& str) { return str; });
}
// need 4 activation functions
if(dirct == op::rnn_direction::bidirectional)
{
// 4 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1 four times. If 2 actv functins are provided,
// assume forward and reverse use the same pair of actv
// functions. For the case of 3 actv functions provided,
// assume the 3rd one is repeated once and used by the
// reverse direction.
// This may need change later
if(vec_names.size() == 1)
{
vec_names.insert(vec_names.end(), 3, vec_names.at(0));
}
else if(vec_names.size() == 2)
{
// repeat the activation functions
vec_names.push_back(vec_names.at(0));
vec_names.push_back(vec_names.at(1));
}
else if(vec_names.size() == 3)
{
vec_names.push_back(vec_names.at(2));
}
}
else
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
if(map_actv_funcs.count(name) == 0)
{
MIGRAPHX_THROW("GRU: activation function " + std::string(name) + " not supported");
}
});
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
return map_actv_funcs[name];
});
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int linear_before_reset = 0;
if(contains(attributes, "linear_before_reset"))
{
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
}
// append undefined opeator to make 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 6 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
......@@ -10,63 +10,75 @@ inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const
{
std::unordered_map<instruction_ref, instruction_ref> map_last_output;
for(auto ins : iterator_for(prog))
{
// rewrite rnn operator
if(ins->name() == "rnn")
{
// could be 3 to 6 inputs, but the parse_rnn function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have only 3 arguments
// when writing their program.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0);
auto actv_funcs = compute_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction;
if(dicrt == op::rnn::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = rnn_cell(true,
apply_vanilla_rnn(prog, ins);
}
if(ins->name() == "gru")
{
apply_gru(prog, ins);
}
}
}
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have num of arguments
// when writing their program.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0);
auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn_direction dicrt = rnn_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = vanilla_rnn_cell(true,
prog,
ins,
args[0],
......@@ -75,7 +87,7 @@ void rewrite_rnn::apply(program& prog) const
bias_forward,
ih_forward,
actv_funcs.at(0));
auto ret_reverse = rnn_cell(false,
auto ret_reverse = vanilla_rnn_cell(false,
prog,
ins,
args[0],
......@@ -85,108 +97,98 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse,
actv_funcs.at(1));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
instruction_ref hidden_output{};
if(ret_forward[0] == prog.end())
{
hidden_output = prog.replace_instruction(
ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
hidden_output = prog.replace_instruction(
ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
map_last_output[hidden_output] = last_output;
}
else
{
bool is_forward = (dicrt == op::rnn::forward);
// input weight matrix
auto w = args[1];
// hidden state weight matrix
auto r = args[2];
// process bias and initial hidden state
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{
bias = args[3];
}
// process intial hidden state
instruction_ref ih;
if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret =
rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
instruction_ref hidden_output{};
if(ret[0] == prog.end())
{
hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
hidden_output =
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
map_last_output[hidden_output] = last_output;
}
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dicrt == op::rnn_direction::forward);
// input weight matrix
auto w = args[1];
// hidden state weight matrix
auto r = args[2];
// rewrite the rnn_last_output operator that right after the rnn
// operator. Intuitively, we can do a slice on the input to get
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
if(ins->name() == "rnn_last_output")
// process bias and initial hidden state
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
auto inputs = ins->inputs();
assert(inputs.size() == 1);
auto arg = inputs[0];
if(map_last_output.count(arg) == 0)
{
MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
}
bias = args[3];
}
// process intial hidden state
instruction_ref ih;
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret =
vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
prog.replace_instruction(ins, map_last_output[arg]);
// search its output to find if there are rnn_last_output operator
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
}
std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const
{
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
......@@ -266,13 +268,14 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
return {hidden_out, last_out};
}
std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
// before rewrite the rnn operator, need to ensure
// we have 2 actv funcs. If less than 2, use the
// algorithm in parse_rnn to make 2 actv functions
if(rnn_op.direction == op::rnn::bidirectional)
// could be 3 to 6 inputs, but the parse_gru function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have any num of arguments
// when writing their program.
if(rnn_op.direction == op::rnn_direction::bidirectional)
{
if(rnn_op.actv_funcs.empty())
{
......@@ -302,5 +305,364 @@ std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) cons
}
}
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins);
// could be 3 to 6 inputs, but the parse_gru function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have num of arguments
// when writing their program.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator());
op::rnn_direction dicrt = gru_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional)
{
// w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// r weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// intial hidden state
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = gru_cell(true,
prog,
ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
auto ret_reverse = gru_cell(false,
prog,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten
// from gru operator is a concat
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
bool is_forward = (dicrt == op::rnn_direction::forward);
// weight matrix
auto w = args[1];
auto r = args[2];
// bias
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
}
// intial hidden state
instruction_ref ih{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret = gru_cell(is_forward,
prog,
ins,
{args[0], w, r, bias, ih},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
// replace the corresponding rnn_last_output instruction
// with the last_output, if rnn_last_output exists
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
}
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2) const
{
assert(inputs.size() == 5);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
instruction_ref hidden_states = prog.end();
instruction_ref last_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<int> data(s.elements(), 1);
auto l1 = prog.add_literal(migraphx::literal{s, data});
// weight matrix
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
// initial states
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
instruction_ref brcst_bz{};
instruction_ref brcst_br{};
instruction_ref brcst_wbh{};
instruction_ref brcst_rbh{};
instruction_ref brcst_bh{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
}
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
if(bias != prog.end())
{
xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
}
auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
if(bias != prog.end())
{
xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
}
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h;
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
if(bias != prog.end())
{
ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
}
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
}
}
auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih);
sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
if(i < seq_len - 1)
{
if(is_forward)
{
hidden_states =
(seq_index == 0)
? last_output
: prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
}
else
{
hidden_states =
(seq_index == seq_len - 1)
? last_output
: prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
}
}
}
return {hidden_states, last_output};
}
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
auto gru_op = any_cast<op::gru>(ins->get_operator());
// before rewrite the gru operator, need to ensure
// we have 4 actv funcs, even though a user does not
// specifiy any actv func. If less than 4, use the
// algorithm in parse_gru to make 4 actv functions
if(gru_op.direction == op::rnn_direction::bidirectional)
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0)};
else if(gru_op.actv_funcs.size() == 2)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1),
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1)};
else if(gru_op.actv_funcs.size() == 3)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1),
gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(0)};
else
return gru_op.actv_funcs;
}
else
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}};
else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
else
return gru_op.actv_funcs;
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -1361,521 +1361,6 @@ TEST_CASE(min_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(rnn_forward)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> wf_data{0.4691,
0.3185,
-0.2227,
0.4423,
-0.0609,
-0.2803,
0.1744,
0.3146,
0.4049,
-0.3973,
-0.0890,
-0.1636};
std::vector<float> rf_data{-0.0456,
0.1061,
0.1574,
-0.4928,
-0.4300,
-0.1909,
-0.0225,
-0.2668,
0.1840,
-0.4453,
-0.4896,
0.1302,
-0.0929,
0.3545,
-0.4981,
0.0616};
std::vector<float> biasf_data{
-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
float clip = 0.0f;
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wf_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
0.55168478,
-0.5888475,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027,
0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wf_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(rnn_reverse)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> wr_data{-0.0296,
-0.1341,
0.1761,
-0.2325,
-0.0717,
0.1852,
0.2720,
0.1471,
-0.1097,
0.3363,
-0.0587,
-0.2302};
std::vector<float> rr_data{0.2528,
-0.2333,
0.3973,
0.1593,
-0.0388,
0.1702,
0.3829,
-0.0712,
-0.1668,
0.3074,
-0.2854,
0.4049,
-0.3737,
-0.1051,
0.4482,
-0.2841};
std::vector<float> biasr_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
float clip = 0.0f;
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wr_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wr_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::reverse, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(rnn_bidirectional)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> wf_data{0.4691,
0.3185,
-0.2227,
0.4423,
-0.0609,
-0.2803,
0.1744,
0.3146,
0.4049,
-0.3973,
-0.0890,
-0.1636};
std::vector<float> wr_data{-0.0296,
-0.1341,
0.1761,
-0.2325,
-0.0717,
0.1852,
0.2720,
0.1471,
-0.1097,
0.3363,
-0.0587,
-0.2302};
std::vector<float> rf_data{-0.0456,
0.1061,
0.1574,
-0.4928,
-0.4300,
-0.1909,
-0.0225,
-0.2668,
0.1840,
-0.4453,
-0.4896,
0.1302,
-0.0929,
0.3545,
-0.4981,
0.0616};
std::vector<float> rr_data{0.2528,
-0.2333,
0.3973,
0.1593,
-0.0388,
0.1702,
0.3829,
-0.0712,
-0.1668,
0.3074,
-0.2854,
0.4049,
-0.3737,
-0.1051,
0.4482,
-0.2841};
std::vector<float> biasf_data{
-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990};
std::vector<float> biasr_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
float clip = 0.0f;
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w_data = wf_data;
w_data.insert(w_data.end(), wr_data.begin(), wr_data.end());
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r_data = rf_data;
r_data.insert(r_data.end(), rr_data.begin(), rr_data.end());
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias_data = biasf_data;
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::bidirectional, clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236,
0.13104209, -0.18736027, -0.29385301, 0.16796815, 0.51075965, 0.40258689,
-0.13818839, 0.44124447, 0.14365635, 0.14803654, 0.03445704, 0.19167931,
-0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477,
-0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031,
-0.20639211, 0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w_data = wf_data;
w_data.insert(w_data.end(), wr_data.begin(), wr_data.end());
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r_data = rf_data;
r_data.insert(r_data.end(), rr_data.begin(), rr_data.end());
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias_data = biasf_data;
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w_data = wf_data;
w_data.insert(w_data.end(), wr_data.begin(), wr_data.end());
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r_data = rf_data;
r_data.insert(r_data.end(), rr_data.begin(), rr_data.end());
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias_data = biasf_data;
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r,
bias);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(pad_test)
{
migraphx::program p;
......
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
TEST_CASE(rnn_forward)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{0.4691,
0.3185,
-0.2227,
0.4423,
-0.0609,
-0.2803,
0.1744,
0.3146,
0.4049,
-0.3973,
-0.0890,
-0.1636};
std::vector<float> r_data{-0.0456,
0.1061,
0.1574,
-0.4928,
-0.4300,
-0.1909,
-0.0225,
-0.2668,
0.1840,
-0.4453,
-0.4896,
0.1302,
-0.0929,
0.3545,
-0.4981,
0.0616};
std::vector<float> bias_data{
-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990};
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
float clip = 0.0f;
// concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
0.55168478,
-0.5888475,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027,
0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// rnn last output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// multiple rnn_last_output operators
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// 3 args
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// seq_len = 1
{
seq_len = 1;
std::vector<float> input_1(seq_len * batch_size * input_size, 0);
input_1[0] = input_1[1] = 1.0;
migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape_1, input_1});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
0.55168478,
-0.5888475,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(rnn_reverse)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{-0.0296,
-0.1341,
0.1761,
-0.2325,
-0.0717,
0.1852,
0.2720,
0.1471,
-0.1097,
0.3363,
-0.0587,
-0.2302};
std::vector<float> r_data{0.2528,
-0.2333,
0.3973,
0.1593,
-0.0388,
0.1702,
0.3829,
-0.0712,
-0.1668,
0.3074,
-0.2854,
0.4049,
-0.3737,
-0.1051,
0.4482,
-0.2841};
std::vector<float> bias_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
// concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// rnn last output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(rnn_bidirectional)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{0.4691, 0.3185, -0.2227, 0.4423, -0.0609, -0.2803,
0.1744, 0.3146, 0.4049, -0.3973, -0.0890, -0.1636,
-0.0296, -0.1341, 0.1761, -0.2325, -0.0717, 0.1852,
0.2720, 0.1471, -0.1097, 0.3363, -0.0587, -0.2302};
std::vector<float> r_data{-0.0456, 0.1061, 0.1574, -0.4928, -0.4300, -0.1909, -0.0225,
-0.2668, 0.1840, -0.4453, -0.4896, 0.1302, -0.0929, 0.3545,
-0.4981, 0.0616, 0.2528, -0.2333, 0.3973, 0.1593, -0.0388,
0.1702, 0.3829, -0.0712, -0.1668, 0.3074, -0.2854, 0.4049,
-0.3737, -0.1051, 0.4482, -0.2841};
std::vector<float> bias_data{-0.4938,
0.4355,
-0.3186,
0.2094,
0.1037,
-0.1071,
0.4504,
-0.3990,
-0.3188,
0.1341,
-0.4446,
0.1389,
0.3117,
0.3664,
0.2352,
0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
float clip = 0.0f;
// concatenation of hidden state for program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236,
0.13104209, -0.18736027, -0.29385301, 0.16796815, 0.51075965, 0.40258689,
-0.13818839, 0.44124447, 0.14365635, 0.14803654, 0.03445704, 0.19167931,
-0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477,
-0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031,
-0.20639211, 0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// last rnn output for program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// 4 args
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// 3 args
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{
0.6570473, 0.36392266, 0.45342238, -0.45127486, 0., 0., 0., 0.,
-0.16225325, -0.29515147, 0.39617197, 0.27068236, 0., 0., 0., 0.,
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// concatenation of hidden state for program output
{
seq_len = 1;
std::vector<float> input_1(seq_len * batch_size * input_size, 0);
input_1[0] = input_1[1] = 1.0;
migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape_1, input_1});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
0.55168478,
-0.5888475,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027,
-0.16915828,
0.1938169,
0.20667936,
0.58609703,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_forward)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418,
0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640,
-0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498,
0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331,
0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529,
-0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131,
0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721,
-0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179,
-0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706,
-0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801,
0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934,
0.3645, -0.4310, -0.3480, 0.0702, -0.1558};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946,
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494,
0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// concatenation of hidden states for output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
-0.27298412, 0.42363745, -0.09368783, 0.4823072, -0.02183238, -0.6873896,
0.16144305, 0.31932795, 0.6104771, 0.79759157, -0.31791314, 0.5249062,
0.08800987, 0.46404213, -0.11872687, -0.26210734, 0.34448293, -0.0176422,
0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 0.23280787,
-0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// last output for output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.3969709,
0.43360898,
0.35775262,
0.23280787,
-0.52179873,
-0.21944991,
0.4535257,
-0.13735442,
0.51757574,
0.50380427};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// two rnn_last_output operators after gru
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.3969709,
0.43360898,
0.35775262,
0.23280787,
-0.52179873,
-0.21944991,
0.4535257,
-0.13735442,
0.51757574,
0.50380427};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// last output for output, linear_before_reset = 0
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.53291196,
0.50160867,
0.39010462,
0.39292926,
-0.5960838,
-0.38451535,
0.454239,
-0.10620412,
0.6014447,
0.43445644};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_forward_args)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418,
0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640,
-0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498,
0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331,
0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529,
-0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131,
0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721,
-0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179,
-0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706,
-0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801,
0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934,
0.3645, -0.4310, -0.3480, 0.0702, -0.1558};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946,
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494,
0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// 3 args
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.114674, -0.129581, -0.218156, -0.140788, -0.114242,
-0.346569, 0.321367, -0.0838253, 0.102097, 0.00232137,
-0.149055, 0.0590743, -0.0533094, -0.0446122, -0.112588,
0.0153261, 0.168883, -0.326836, 0.0843562, 0.160872,
-0.232523, 0.00214573, 0.231693, -0.160475, -0.518952,
0.0467166, 0.12327, -0.374162, 0.137778, 0.251976};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 4 args (bias is used)
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.273619, 0.0931375, -0.104717, 0.0203752, -0.0797887,
-0.493948, 0.472118, -0.0336318, 0.332706, 0.0182268,
-0.341684, 0.38063, 0.0589275, 0.2644, -0.115737,
-0.152324, 0.442277, -0.201626, 0.408909, 0.12905,
-0.416866, 0.377186, 0.32922, 0.162214, -0.519973,
-0.140072, 0.465076, -0.229563, 0.500164, 0.195166};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 4 args (ih is used)
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
und,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.0801064, 0.27025, -0.20704, 0.333579, -0.0452438,
-0.56265, 0.061061, 0.262172, 0.405193, 0.775226,
-0.100683, 0.258729, -0.0187297, 0.215815, -0.108936,
-0.0941018, 0.129665, -0.159421, 0.190636, 0.597412,
-0.197, 0.0885705, 0.269396, -0.0414511, -0.515137,
-0.03075, 0.158326, -0.296488, 0.177983, 0.519498};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_forward_actv_funcs)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418,
0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640,
-0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498,
0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331,
0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529,
-0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131,
0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721,
-0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179,
-0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706,
-0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801,
0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934,
0.3645, -0.4310, -0.3480, 0.0702, -0.1558};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946,
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494,
0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// no activation function specified, so default is used.
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs = p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip, 1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.3969709,
0.43360898,
0.35775262,
0.23280787,
-0.52179873,
-0.21944991,
0.4535257,
-0.13735442,
0.51757574,
0.50380427};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 1 activation function (sigmoid) specified
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.26905832, 0.5669211, 0.20464146, 0.67195725, 0.24752215,
0.11411376, 0.12353572, 0.4245067, 0.73908687, 0.8644615,
0.34754312, 0.61424744, 0.36769435, 0.6499579, 0.3168031,
0.3296533, 0.3055136, 0.42514813, 0.6851256, 0.7967266,
0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663,
0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 1 activation function (tanh) specified
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs = p.add_instruction(
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.49333298,
-0.06104589,
0.5629142,
-0.97955984,
-0.9314696,
-0.03033514,
0.5280315,
-0.27354342,
0.65615714,
0.53612584};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// seq length of 1
{
migraphx::program p;
seq_len = 1;
migraphx::shape in_shape_one{migraphx::shape::float_type,
{seq_len, batch_size, input_size}};
std::vector<float> input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504};
auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.27298412,
0.42363745,
-0.09368783,
0.4823072,
-0.02183238,
-0.6873896,
0.16144305,
0.31932795,
0.6104771,
0.79759157};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_reverse)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418,
0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640,
-0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498,
0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331,
0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529,
-0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131,
0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721,
-0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179,
-0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706,
-0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801,
0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934,
0.3645, -0.4310, -0.3480, 0.0702, -0.1558};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946,
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494,
0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// concatenation of hidden states for output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125,
-0.600874, 0.542386, -0.0856531, 0.55703, 0.54711,
-0.276245, 0.521348, 0.302874, 0.394353, -0.334369,
-0.187861, 0.213553, -0.0708377, 0.545435, 0.654301,
-0.329512, 0.476095, 0.284044, 0.392077, -0.369226,
-0.3275, -0.027301, 0.143774, 0.655686, 0.782831};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// last output for output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.263403,
0.317655,
-0.00634162,
0.200443,
-0.349125,
-0.600874,
0.542386,
-0.0856531,
0.55703,
0.54711};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// last output for output, linear_before_reset = 0
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.388654,
0.384975,
0.0179455,
0.350101,
-0.456872,
-0.690085,
0.534512,
-0.0558191,
0.646604,
0.463943};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// no activation function specified, so default is used.
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125,
-0.600874, 0.542386, -0.0856531, 0.55703, 0.54711,
-0.276245, 0.521348, 0.302874, 0.394353, -0.334369,
-0.187861, 0.213553, -0.0708377, 0.545435, 0.654301,
-0.329512, 0.476095, 0.284044, 0.392077, -0.369226,
-0.3275, -0.027301, 0.143774, 0.655686, 0.782831};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// seq length of 1
{
migraphx::program p;
seq_len = 1;
migraphx::shape in_shape_one{migraphx::shape::float_type,
{seq_len, batch_size, input_size}};
std::vector<float> input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504};
auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.272984,
0.423637,
-0.0936878,
0.482307,
-0.0218324,
-0.68739,
0.161443,
0.319328,
0.610477,
0.797592};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_bidirectional)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418,
0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109,
0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732,
0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294,
0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778,
-0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353,
0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408,
-0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714,
-0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996,
-0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
-0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063,
0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194,
-0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082,
0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609,
0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339,
-0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534,
-0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305,
0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440,
0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074,
0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677,
-0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618,
-0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997,
0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027,
0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955,
-0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
-0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363,
-0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817,
-0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416,
0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317,
0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377,
0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348,
0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340,
0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// concatenation of hidden states for output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273,
-0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531,
-0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435,
0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906,
-0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945,
0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681,
0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849,
0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665,
0.079043, 0.322652, 0.752701, 0.243775};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// last output for output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533,
-0.311839, -0.12802, -0.16643, -0.393849, 0.648851,
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// last output for output, linear_before_reset = 0
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
-0.09280921, 0.18506107, 0.32247013, 0.17034212, -0.00115255, -0.29865006, -0.04513004,
-0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289,
-0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_bidirectional_args)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418,
0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109,
0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732,
0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294,
0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778,
-0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353,
0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408,
-0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714,
-0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996,
-0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
-0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063,
0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194,
-0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082,
0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609,
0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339,
-0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534,
-0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305,
0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440,
0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074,
0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677,
-0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618,
-0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997,
0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027,
0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955,
-0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
-0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363,
-0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817,
-0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416,
0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317,
0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377,
0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348,
0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340,
0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// 3 args
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0863793, -0.227845, 0.0283059, -0.258645, 0.14187, 0.43541, 0.190748,
-0.530196, -0.440444, 0.293767, 0.0402142, 0.0788687, -0.013, -0.233298,
-0.0739615, 0.467104, 0.446285, 0.306097, 0.125636, 0.272524, 0.0949838,
0.0522264, -0.0872712, -0.084203, 0.140013, 0.12739, -0.0111171, -0.431119,
-0.468382, 0.388067, -0.109174, -0.119064, -0.0242958, -0.180555, 0.118983,
0.341578, 0.275472, 0.0853083, 0.332205, -0.0498387, 0.140338, 0.0319435,
0.247019, 0.275848, -0.158223, 0.0495464, -0.0681034, -0.418158, -0.523234,
0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407,
0.198708, 0.0695644, 0.211621, 0.00246037};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 4 args (bias is used)
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
-0.156667, -0.248473, 0.0255282, -0.24566, 0.211589, 0.192707, 0.253025,
-0.515283, -0.414174, 0.227127, 0.124773, 0.284532, -0.203929, -0.120517,
-0.2794, 0.547635, 0.518549, 0.0447674, 0.258461, 0.0502881, -0.219516,
0.0927382, -0.0760062, -0.0906231, 0.237615, -0.215638, 0.0128074, -0.425813,
-0.433378, 0.375383, -0.0381738, 0.117793, -0.180851, -0.0841245, -0.116649,
0.419469, 0.393515, -0.076395, 0.427436, -0.264071, -0.185829, 0.0483585,
0.242955, 0.25233, 0.0148512, -0.304127, -0.0616653, -0.411568, -0.491748,
0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008,
0.248674, -0.0295413, 0.291437, -0.165005};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 4 args (ih is used)
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
und,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.248571, 0.0982155, 0.00808877, 0.0986508, 0.0969705, 0.434692, -0.141696,
-0.164271, -0.121157, 0.863222, -0.0718357, 0.137711, 0.109221, -0.00207995,
0.0331223, 0.262705, 0.346587, 0.457158, 0.240744, 0.404261, 0.222779,
0.179757, -0.0845316, 0.0690347, 0.10204, 0.100155, -0.190286, -0.122062,
-0.274379, 0.547281, -0.226753, -0.0397069, 0.120404, 0.171299, 0.259989,
0.0864604, 0.111322, 0.331784, 0.604653, 0.181017, 0.237426, 0.0911999,
0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354,
0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917,
-0.0339407, 0.413089, 0.721238, 0.431879};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_bidirectional_actv_funcs)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418,
0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109,
0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732,
0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294,
0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778,
-0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353,
0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408,
-0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714,
-0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996,
-0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
-0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063,
0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194,
-0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082,
0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609,
0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339,
-0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534,
-0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305,
0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440,
0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074,
0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677,
-0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618,
-0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997,
0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027,
0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955,
-0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
-0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363,
-0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817,
-0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416,
0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317,
0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377,
0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348,
0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340,
0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// no activation function specified, so default is used.
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs = p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip, 1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533,
-0.311839, -0.12802, -0.16643, -0.393849, 0.648851,
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 1 activation function (sigmoid) specified
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.325495, 0.469214, 0.164517, 0.585327, 0.328398, 0.457928, 0.065011, 0.35986,
0.545029, 0.859425, 0.427923, 0.667133, 0.41591, 0.540971, 0.365475, 0.482058,
0.565495, 0.556993, 0.607649, 0.543627, 0.428915, 0.537405, 0.306046, 0.518399,
0.403561, 0.410694, 0.301163, 0.407397, 0.471334, 0.726446, 0.309389, 0.612072,
0.360619, 0.590861, 0.366545, 0.367001, 0.433829, 0.501275, 0.72481, 0.512745,
0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275,
0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646,
0.132732, 0.477083, 0.802206, 0.626802};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 1 activation function (tanh) specified
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0919632, -0.398302, -0.0267752, -0.326771, 0.401983, 0.949841, 0.557779,
-0.745259, -1.52726, 0.946066, 0.330446, 0.301982, -0.443763, -0.0655817,
-0.326473, 0.861394, 0.560799, -0.101768, 0.145142, 0.128956, -0.329758,
0.458253, -0.339208, 0.289109, 0.36728, -1.09574, -0.181394, -0.575781,
-0.823083, 0.804262, -0.0965933, 0.20405, -0.430215, 0.00884668, 0.0716857,
0.844222, 0.516472, -0.191571, 0.596968, -0.545405, -0.336693, -0.0280516,
0.339058, 1.00367, 0.12655, -0.0984504, -0.174945, -0.5365, 0.183188,
0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419,
0.759629, 0.000258222, 0.350835, -0.682684};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 3 activation functions specified
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs = p.add_instruction(
migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.351019, 0.474363, 0.570719, 0.717703, 0.468843,
1.15142, 0.457633, 0.300962, 0.361245, 0.666199,
0.330446, 0.301982, -0.443763, -0.0655817, -0.326473,
0.861394, 0.560799, -0.101768, 0.145142, 0.128956};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// 4 activation functions all specified
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{},
migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273,
-0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531,
-0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435,
0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906,
-0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945,
0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681,
0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849,
0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665,
0.079043, 0.322652, 0.752701, 0.243775};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// seq length of 1
{
migraphx::program p;
seq_len = 1;
migraphx::shape in_shape_one{migraphx::shape::float_type,
{seq_len, batch_size, input_size}};
std::vector<float> input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504};
auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683,
0.214342, -0.0454273, -0.135177, -0.0800739, 0.903659,
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
-0.144492, -0.0115366, 0.409153, 0.487015, 0.550755};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -1163,7 +1163,7 @@ struct test_rnn_forward
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1205,7 +1205,7 @@ struct test_rnn_forward10
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1246,7 +1246,7 @@ struct test_rnn_reverse
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1286,7 +1286,7 @@ struct test_rnn_reverse2
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1321,7 +1321,7 @@ struct test_rnn_3args
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1355,7 +1355,7 @@ struct test_rnn_4args
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1392,7 +1392,7 @@ struct test_rnn_5args
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1433,7 +1433,7 @@ struct test_rnn_bidirectional
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1474,7 +1474,7 @@ struct test_rnn_bidirectional10
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1512,7 +1512,7 @@ struct test_rnn_bi_3args
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1523,6 +1523,597 @@ struct test_rnn_bi_3args
}
};
struct test_gru_forward_last
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_gru_forward_hs
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_gru_forward_3args_und
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
return p;
}
};
struct test_gru_forward_3args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
return p;
}
};
struct test_gru_forward_seq1
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r);
return p;
}
};
struct test_gru_forward_default_actv
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
return p;
}
};
struct test_gru_forward_default_actv1
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_gru_reverse_last
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_gru_reverse_3args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r);
return p;
}
};
struct test_gru_bidirct_last
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_gru_bidirct_hs
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_gru_bidirct_3args_und
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
und,
und,
und);
return p;
}
};
struct test_gru_bidirct_3args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
return p;
}
};
struct test_gru_bidirct_seq1
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r);
return p;
}
};
struct test_gru_bidirct_default_actv
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r);
return p;
}
};
struct test_gru_bidirct_default_actv1
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 8;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
int main()
{
verify_program<test_relu_lrn>();
......@@ -1597,4 +2188,20 @@ int main()
verify_program<test_rnn_bidirectional>();
verify_program<test_rnn_bidirectional10>();
verify_program<test_rnn_bi_3args>();
verify_program<test_gru_forward_last>();
verify_program<test_gru_forward_hs>();
verify_program<test_gru_forward_3args_und>();
verify_program<test_gru_forward_3args>();
verify_program<test_gru_forward_seq1>();
verify_program<test_gru_forward_default_actv>();
verify_program<test_gru_forward_default_actv1>();
verify_program<test_gru_reverse_last>();
verify_program<test_gru_reverse_3args>();
verify_program<test_gru_bidirct_last>();
verify_program<test_gru_bidirct_hs>();
verify_program<test_gru_bidirct_3args_und>();
verify_program<test_gru_bidirct_3args>();
verify_program<test_gru_bidirct_seq1>();
verify_program<test_gru_bidirct_default_actv>();
verify_program<test_gru_bidirct_default_actv1>();
}
......@@ -491,7 +491,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -523,7 +523,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -555,7 +555,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -583,7 +583,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -615,7 +615,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -630,6 +630,429 @@ TEST_CASE(rnn_test)
}
}
TEST_CASE(gru_test)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// forward
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
}
// bidirectional
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{},
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_args)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// 3 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
}
// 4 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
r,
bias,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
}
// 5 arguments
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_actv_funcs)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirection, 0 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
}
// bidirection, 1 actv function
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
}
// bidirection, 2 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
}
// bidirection, 3 actv functions
{
nd = 2;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
}
// forward, 0 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
}
// reverse, 1 actv function
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
auto r =
p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(flatten_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment