Commit 0da78d29 authored by Shucai Xiao

merge gru_operator changes

parents d53a69f6 ce7b4b17
@@ -60,6 +60,30 @@ struct batch_norm_inference
}
};
struct lrn
{
float alpha = 0.0001;
float beta = 0.75;
float bias = 1.0;
int size = 1;
std::string name() const { return "lrn"; }
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"),
f(self.beta, "beta"),
f(self.bias, "bias"),
f(self.size, "size"));
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
};
struct convolution
{
std::array<std::size_t, 2> padding = {{0, 0}};
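The new lrn operator is shape-preserving: compute_shape just echoes its single input. For orientation, a minimal, hypothetical usage sketch (not part of this commit; it mirrors the CPU test added further down and assumes the usual public headers):

#include <migraphx/program.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>

int main()
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {1, 5, 1, 1}};
    auto l = p.add_literal(migraphx::literal{s, {-2.0f, 1.0f, 0.0f, 1.0f, 2.0f}});
    // field order matches the reflect() pack above: alpha, beta, bias, size
    p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, l);
    return 0;
}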
@@ -1140,20 +1164,20 @@ struct outline
argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
};
-struct rnn
-{
-enum rnn_direction_t
-{
-forward,
-reverse,
-bidirectional,
-};
// indicate rnn computation direction
enum class rnn_direction
{
forward,
reverse,
bidirectional,
};

struct rnn
{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{tanh{}, tanh{}};
-rnn_direction_t direction = forward;
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const
@@ -1166,7 +1190,7 @@ struct rnn
}
std::size_t num_directions = 1;
-if(direction == bidirectional)
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
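Consolidating the per-op enums into one scoped enum (see the gru and lstm hunks below) makes direction values type-safe and spelled identically across the three recurrent ops; a short sketch of the resulting call sites:

// one direction type now serves rnn, gru, and lstm
migraphx::op::rnn r;
r.direction = migraphx::op::rnn_direction::reverse;       // was op::rnn::reverse
migraphx::op::gru g;
g.direction = migraphx::op::rnn_direction::bidirectional; // was op::gru::bidirectional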
@@ -1200,18 +1224,11 @@ struct rnn_last_output
struct gru
{
-enum gru_direction_t
-{
-forward,
-reverse,
-bidirectional,
-};
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
-gru_direction_t direction = forward;
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int linear_before_reset = 0;
std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const
@@ -1224,7 +1241,7 @@ struct gru
}
std::size_t num_directions = 1;
-if(direction == bidirectional)
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
@@ -1242,32 +1259,11 @@ struct gru
}
};
-struct gru_last_output
-{
-std::string name() const { return "gru_last_output"; }
-shape compute_shape(std::vector<shape> inputs) const
-{
-check_shapes{inputs, *this}.has(1);
-auto dims = inputs[0].lens();
-// remove the first dimension, remaing are output shape
-dims.erase(dims.begin());
-return {inputs[0].type(), dims};
-}
-};
struct lstm
{
-enum lstm_direction_t
-{
-forward,
-reverse,
-bidirectional,
-};
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
-lstm_direction_t direction = forward;
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int input_forget = 0;
@@ -1282,7 +1278,7 @@ struct lstm
}
std::size_t num_directions = 1;
-if(direction == bidirectional)
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
@@ -1300,20 +1296,6 @@ struct lstm
}
};
-struct lstm_last_output
-{
-std::string name() const { return "lstm_last_output"; }
-shape compute_shape(std::vector<shape> inputs) const
-{
-check_shapes{inputs, *this}.has(1);
-auto dims = inputs[0].lens();
-// remove the first dimension, remaing are output shape
-dims.erase(dims.begin());
-return {inputs[0].type(), dims};
-}
-};
struct lstm_last_cell_output
{
std::string name() const { return "lstm_last_cell_output"; }
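The deleted gru_last_output and lstm_last_output were effectively duplicates of the existing rnn_last_output, which the parser and rewriter below now use for all three ops. For reference, a sketch of the shared op as implied by the deleted bodies (rnn_last_output itself sits outside the shown hunks):

struct rnn_last_output
{
    std::string name() const { return "rnn_last_output"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        // drop seq_len from {seq_len, num_directions, batch, hidden},
        // leaving the last-step output shape
        auto dims = inputs[0].lens();
        dims.erase(dims.begin());
        return {inputs[0].type(), dims};
    }
};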
...
@@ -64,6 +64,7 @@ struct onnx_parser
add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{});
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu);
@@ -88,6 +89,7 @@ struct onnx_parser
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
// init the activation function map
@@ -537,6 +539,25 @@ struct onnx_parser
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 0.0001;
float beta = 0.75;
float bias = 1.0;
int size = 1;
if(contains(attributes, "alpha"))
alpha = parse_value(attributes.at("alpha")).at<float>();
if(contains(attributes, "beta"))
beta = parse_value(attributes.at("beta")).at<float>();
if(contains(attributes, "bias"))
bias = parse_value(attributes.at("bias")).at<float>();
if(contains(attributes, "size"))
size = parse_value(attributes.at("size")).at<int>();
op::lrn op{alpha, beta, bias, size};
return prog.add_instruction(op, args.front());
}
instruction_ref parse_imagescaler(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
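The fallbacks in parse_lrn match the ONNX defaults for LRN (alpha = 0.0001, beta = 0.75, bias = 1.0); size carries no ONNX default, so real models supply it on the node. A hypothetical example of the mapping:

// an ONNX node LRN(size = 5, alpha = 0.0002) with beta and bias omitted
// parses to the equivalent of:
op::lrn op{0.0002, 0.75, 1.0, 5};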
@@ -714,14 +735,14 @@ struct onnx_parser
direction = attributes.at("direction").s();
}
-op::rnn::rnn_direction_t dirct = op::rnn::forward;
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
-dirct = op::rnn::bidirectional;
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
-dirct = op::rnn::reverse;
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names{"tanh"};
@@ -743,7 +764,7 @@ struct onnx_parser
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
-if(dirct == op::rnn::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
if(vec_names.size() == 1)
{
@@ -803,14 +824,14 @@ struct onnx_parser
direction = attributes.at("direction").s();
}
-op::gru::gru_direction_t dirct = op::gru::forward;
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
-dirct = op::gru::bidirectional;
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
-dirct = op::gru::reverse;
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
@@ -824,7 +845,7 @@ struct onnx_parser
}
// need 4 activation functions
-if(dirct == op::gru::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
// 4 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
@@ -895,7 +916,7 @@ struct onnx_parser
std::move(args));
// second output for last gru output
-auto last_output = prog.add_instruction(op::gru_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
@@ -922,18 +943,18 @@ struct onnx_parser
direction = attributes.at("direction").s();
}
-op::lstm::lstm_direction_t dirct = op::lstm::forward;
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
-dirct = op::lstm::bidirectional;
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
-dirct = op::lstm::reverse;
dirct = op::rnn_direction::reverse;
}
else if(direction == "forward")
{
-dirct = op::lstm::forward;
dirct = op::rnn_direction::forward;
}
else
{
@@ -951,7 +972,7 @@ struct onnx_parser
}
// need 6 activation functions for bidirectional directions
-if(dirct == op::lstm::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
@@ -1034,7 +1055,7 @@ struct onnx_parser
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
-auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
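For orientation, these instructions line up with the three ONNX LSTM outputs; a sketch of the presumed tail of parse_lstm (the actual return statement is outside the shown hunk):

// Y   -> hidden_states               (all steps, all directions)
// Y_h -> op::rnn_last_output{}       (last-step hidden state)
// Y_c -> op::lstm_last_cell_output{} (last-step cell state)
return {hidden_states, last_output, last_cell_output};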
...
@@ -43,11 +43,11 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0);
auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
-op::rnn::rnn_direction_t dicrt = rnn_op.direction;
op::rnn_direction dicrt = rnn_op.direction;
instruction_ref last_output{};
-if(dicrt == op::rnn::bidirectional)
if(dicrt == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
@@ -107,11 +107,9 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
-instruction_ref hidden_output{};
if(ret_forward[0] == prog.end())
{
-hidden_output = prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
@@ -119,13 +117,12 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
-hidden_output = prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
{
-bool is_forward = (dicrt == op::rnn::forward);
bool is_forward = (dicrt == op::rnn_direction::forward);
// input weight matrix
auto w = args[1];
@@ -157,16 +154,15 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
-instruction_ref hidden_output{};
if(ret[0] == prog.end())
{
-hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
-hidden_output = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
@@ -282,7 +278,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have any num of arguments
// when writing their program.
-if(rnn_op.direction == op::rnn::bidirectional)
if(rnn_op.direction == op::rnn_direction::bidirectional)
{
if(rnn_op.actv_funcs.empty())
{
@@ -329,10 +325,10 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator());
-op::gru::gru_direction_t dicrt = gru_op.direction;
op::rnn_direction dicrt = gru_op.direction;
instruction_ref last_output{};
-if(dicrt == op::gru::bidirectional)
if(dicrt == op::rnn_direction::bidirectional)
{
// w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
@@ -402,7 +398,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
else
{
-bool is_forward = (dicrt == op::gru::forward);
bool is_forward = (dicrt == op::rnn_direction::forward);
// weight matrix
auto w = args[1];
auto r = args[2];
@@ -447,14 +443,14 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
}
-// replace the corresponding gru_last_output instruction
-// with the last_output, if gru_last_output exists
-// while loop to handle case of multiple gru_last_output operators
// replace the corresponding rnn_last_output instruction
// with the last_output, if rnn_last_output exists
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
-return i->name() == "gru_last_output";
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
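Condensed, the loop above rewires every rnn_last_output reader of the gru instruction onto the freshly computed last_output; a sketch of the equivalent logic, assuming program::replace_instruction(ins, rep) semantics:

// replace each consumer named "rnn_last_output" with last_output;
// the committed code re-runs find_if instead, which stays safe while
// the output list is being mutated
for(auto out : ins->outputs())
{
    if(out->name() == "rnn_last_output")
        prog.replace_instruction(out, last_output);
}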
@@ -638,7 +634,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
// we have 4 actv funcs, even though a user does not
// specify any actv func. If less than 4, use the
// algorithm in parse_gru to make 4 actv functions
-if(gru_op.direction == op::gru::bidirectional)
if(gru_op.direction == op::rnn_direction::bidirectional)
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
@@ -689,11 +685,11 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
auto actv_funcs = lstm_actv_funcs(ins);
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
-op::lstm::lstm_direction_t dirct = lstm_op.direction;
op::rnn_direction dirct = lstm_op.direction;
instruction_ref last_output{};
instruction_ref last_cell_output{};
-if(dirct == op::lstm::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
// input weight matrix
// input weight matrix
@@ -799,7 +795,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
}
else
{
-bool is_forward = (dirct == op::lstm::forward);
bool is_forward = (dirct == op::rnn_direction::forward);
// weight matrices
auto w = args[1];
auto r = args[2];
@@ -1100,7 +1096,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
// algorithm in parse_lstm to make 6 actv functions
const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size();
-if(lstm_op.direction == op::lstm::bidirectional)
if(lstm_op.direction == op::rnn_direction::bidirectional)
{
switch(num_actv_funcs)
{
...
@@ -103,6 +103,43 @@ struct cpu_batch_norm_inference
}
};
struct cpu_lrn
{
op::lrn op;
std::string name() const { return "cpu::lrn"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
int n_batch = output_shape.lens()[0];
int channels = output_shape.lens()[1];
int height = output_shape.lens()[2];
int width = output_shape.lens()[3];
float alphaoverarea = op.alpha / op.size;
int radius = (op.size - 1) / 2;
par_dfor(n_batch, height, width)([&](int b, int h, int w) {
float scale = 0;
dfor(channels)([&](int c) {
auto start = (c - radius) < 0 ? 0 : (c - radius);
auto end = (c + radius) > channels ? channels : (c + radius);
for(auto k = start; k < end; ++k)
{
scale += std::pow(input(b, k, h, w), 2);
}
scale *= alphaoverarea;
scale += op.bias;
scale = std::pow(scale, -op.beta);
output(b, c, h, w) = input(b, c, h, w) * scale;
});
});
});
return result;
}
};
struct cpu_convolution
{
op::convolution op;
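In math terms, the loops in cpu_lrn above compute cross-channel LRN. Writing n for size, k for bias, C for the channel count, and r = (n - 1) / 2 for the half-window, the intended per-element result is:

y_{b,c,h,w} = x_{b,c,h,w} \left( k + \frac{\alpha}{n} \sum_{c' = \max(0,\, c - r)}^{\min(C - 1,\, c + r)} x_{b,c',h,w}^{2} \right)^{-\beta}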
@@ -681,6 +718,7 @@ struct cpu_apply
apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
...
@@ -61,6 +61,7 @@ add_library(migraphx_gpu
elu.cpp
pad.cpp
gather.cpp
lrn.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
...
#ifndef MIGRAPHX_GUARD_RTGLIB_LRN_HPP
#define MIGRAPHX_GUARD_RTGLIB_LRN_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct miopen_lrn
{
shared<lrn_descriptor> ldesc;
std::string name() const { return "gpu::lrn"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
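Returning shapes.size() - 1 from output_alias follows the MIGraphX convention for ops that write into a preallocated buffer: the trailing input is the destination, and aliasing it lets the memory planner treat that allocation as the result. A hypothetical illustration (assumes shared<lrn_descriptor> is default-constructible):

migraphx::shape xs{migraphx::shape::float_type, {1, 5, 4, 4}};
// after lowering, a gpu::lrn instruction carries {x, y_alloc} as inputs
std::vector<migraphx::shape> shapes = {xs, xs};
migraphx::gpu::miopen_lrn lrn_op{};
assert(lrn_op.output_alias(shapes) == 1);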
@@ -23,6 +23,8 @@ using fusion_plan_descriptor = MIGRAPHX_MANAGE_PTR(miopenFusionPlanDescriptor_t,
miopenDestroyFusionPlan);
using fused_operator_args = MIGRAPHX_MANAGE_PTR(miopenOperatorArgs_t, miopenDestroyOperatorArgs);
using lrn_descriptor = MIGRAPHX_MANAGE_PTR(miopenLRNDescriptor_t, miopenDestroyLRNDescriptor);
template <class Result, class F, class... Ts>
Result make_obj(F f, Ts... xs)
{
@@ -89,6 +91,13 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
return p;
}
inline lrn_descriptor make_lrn(const migraphx::op::lrn& op)
{
auto ldesc = make_obj<lrn_descriptor>(&miopenCreateLRNDescriptor);
miopenSetLRNDescriptor(ldesc.get(), miopenLRNCrossChannel, op.size, op.alpha, op.beta, op.bias);
return ldesc;
}
inline activation_descriptor make_relu()
{
auto ad = make_obj<activation_descriptor>(&miopenCreateActivationDescriptor);
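make_lrn hides the usual MIOpen descriptor lifecycle behind MIGRAPHX_MANAGE_PTR; spelled out manually, the calls it automates look roughly like this (error handling omitted):

miopenLRNDescriptor_t raw;
miopenCreateLRNDescriptor(&raw);
// arguments: mode, lrnN (window size), lrnAlpha, lrnBeta, lrnK (bias)
miopenSetLRNDescriptor(raw, miopenLRNCrossChannel, op.size, op.alpha, op.beta, op.bias);
// ... use raw in miopenLRNForward ...
miopenDestroyLRNDescriptor(raw);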
...
@@ -43,6 +43,7 @@
#include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/gather.hpp>
#include <migraphx/gpu/lrn.hpp>
#include <utility>
#include <functional>
#include <algorithm>
@@ -99,6 +100,7 @@ struct miopen_apply
add_extend_op<hip_gather, op::gather>("gather");
add_extend_op<hip_pad, op::pad>("pad");
add_lrn_op();
add_convolution_op();
add_pooling_op();
add_batch_norm_inference_op();
@@ -159,6 +161,17 @@ struct miopen_apply
});
}
void add_lrn_op()
{
apply_map.emplace("lrn", [=](instruction_ref ins) {
auto&& op = any_cast<op::lrn>(ins->get_operator());
auto ldesc = make_lrn(op);
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, miopen_lrn{std::move(ldesc)}, ins->inputs().at(0), output);
});
}
template <class T>
void add_generic_op(std::string name)
{
...
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_lrn::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1);
}
argument miopen_lrn::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
// MIOpen blending factors (y = alpha * op(x) + beta * y), not the LRN parameters
float alpha = 1;
float beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
miopenLRNForward(ctx.get_stream().get_miopen(),
ldesc.get(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
y_desc.get(),
args[1].implicit(),
false,
nullptr);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -732,6 +732,20 @@ TEST_CASE(leaky_relu_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(lrn_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 5, 1, 1}};
auto l = p.add_literal(migraphx::literal{s, {-2.0f, 1.0f, 0.f, 1.0f, 2.0f}});
p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1, 5}, l);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(5);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075};
EXPECT(migraphx::verify_range(results_vector, gold));
}
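The gold values follow the LRN formula given earlier with n = 5, r = 2, alpha = 0.0001, k = 1. For channel 1, for example, the window covers channels 0-3 of the input {-2, 1, 0, 1, 2}, so:

y_1 = 1 \cdot \left( 1 + \frac{0.0001}{5} (4 + 1 + 0 + 1) \right)^{-0.75} \approx \frac{1}{1.00009}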
TEST_CASE(imagescaler_test)
{
migraphx::program p;
...
@@ -69,7 +69,7 @@ TEST_CASE(rnn_forward)
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
-migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
@@ -111,14 +111,14 @@ TEST_CASE(rnn_forward)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
-migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip},
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
@@ -148,14 +148,14 @@ TEST_CASE(rnn_forward)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
-migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip},
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
@@ -183,7 +183,10 @@ TEST_CASE(rnn_forward)
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto out_hs = p.add_instruction(
-migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip}, seq, w, r);
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
@@ -195,6 +198,47 @@ TEST_CASE(rnn_forward)
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// seq_len = 1
{
seq_len = 1;
std::vector<float> input_1(seq_len * batch_size * input_size, 0);
input_1[0] = input_1[1] = 1.0;
migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape_1, input_1});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
0.55168478,
-0.5888475,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}

TEST_CASE(rnn_reverse)
@@ -253,13 +297,14 @@ TEST_CASE(rnn_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
-migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::reverse, clip},
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -294,14 +339,14 @@ TEST_CASE(rnn_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
-migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::reverse, clip},
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
@@ -378,7 +423,7 @@ TEST_CASE(rnn_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
-migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::bidirectional, clip},
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
@@ -399,6 +444,7 @@ TEST_CASE(rnn_bidirectional)
-0.20639211, 0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}

// last rnn output for program output
{
migraphx::program p;
@@ -409,15 +455,17 @@ TEST_CASE(rnn_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
-p.add_instruction(migraphx::op::rnn{
-hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
@@ -457,7 +505,7 @@ TEST_CASE(rnn_bidirectional)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
-migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
@@ -500,7 +548,7 @@ TEST_CASE(rnn_bidirectional)
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
-migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
@@ -519,6 +567,53 @@ TEST_CASE(rnn_bidirectional)
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// concatenation of hidden state for program output
{
seq_len = 1;
std::vector<float> input_1(seq_len * batch_size * input_size, 0);
input_1[0] = input_1[1] = 1.0;
migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape_1, input_1});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
0.55168478,
-0.5888475,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027,
-0.16915828,
0.1938169,
0.20667936,
0.58609703,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
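As a cross-check, the sixteen gold values fit the rnn output layout {seq_len, num_directions, batch_size, hidden_size}: the first eight match the forward-only seq_len = 1 gold above, and the last eight come from the reverse direction. A hypothetical extra assertion (not in the commit) spelling that out:

// the forward slice of the bidirectional result equals the
// forward-only gold values from the earlier sub-case
EXPECT(std::equal(hs_data_gold.begin(), hs_data_gold.begin() + 8, hs_data.begin()));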
}

TEST_CASE(gru_forward)
@@ -588,7 +683,7 @@ TEST_CASE(gru_forward)
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
@@ -625,7 +720,7 @@ TEST_CASE(gru_forward)
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
@@ -634,7 +729,7 @@ TEST_CASE(gru_forward)
bias,
und,
ih);
-p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -654,7 +749,7 @@ TEST_CASE(gru_forward)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}

-// two gru_last_output operators after gru
// two rnn_last_output operators after gru
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
@@ -666,7 +761,7 @@ TEST_CASE(gru_forward)
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
@@ -675,8 +770,8 @@ TEST_CASE(gru_forward)
bias,
und,
ih);
-p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
-p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -708,7 +803,7 @@ TEST_CASE(gru_forward)
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip,
0},
seq,
@@ -717,7 +812,7 @@ TEST_CASE(gru_forward)
bias,
und,
ih);
-p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -736,6 +831,64 @@ TEST_CASE(gru_forward)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_forward_args)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418,
0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640,
-0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498,
0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331,
0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529,
-0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131,
0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721,
-0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179,
-0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706,
-0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801,
0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934,
0.3645, -0.4310, -0.3480, 0.0702, -0.1558};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946,
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494,
0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// 3 args
{
@@ -745,7 +898,7 @@ TEST_CASE(gru_forward)
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
@@ -776,7 +929,7 @@ TEST_CASE(gru_forward)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
@@ -809,7 +962,7 @@ TEST_CASE(gru_forward)
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
@@ -833,6 +986,64 @@ TEST_CASE(gru_forward)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_forward_actv_funcs)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418,
0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640,
-0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498,
0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331,
0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529,
-0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131,
0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721,
-0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179,
-0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706,
-0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801,
0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934,
0.3645, -0.4310, -0.3480, 0.0702, -0.1558};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946,
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494,
0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// no activation function specified, so default is used.
{
@@ -844,14 +1055,14 @@ TEST_CASE(gru_forward)
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs = p.add_instruction(
-migraphx::op::gru{hidden_size, {}, migraphx::op::gru::forward, clip, 1},
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip, 1},
seq,
w,
r,
bias,
und,
ih);
-p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -880,15 +1091,17 @@ TEST_CASE(gru_forward)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
-p.add_instruction(migraphx::op::gru{
-hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::forward, clip, 1},
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -915,14 +1128,14 @@ TEST_CASE(gru_forward)
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs = p.add_instruction(
migraphx::op::gru{
-hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip, 1},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1},
seq,
w,
r,
bias,
und,
ih);
-p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -957,7 +1170,7 @@ TEST_CASE(gru_forward)
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip,
1},
seq,
@@ -1055,7 +1268,7 @@ TEST_CASE(gru_reverse)
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip,
1},
seq,
@@ -1092,7 +1305,7 @@ TEST_CASE(gru_reverse)
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip,
1},
seq,
@@ -1101,7 +1314,7 @@ TEST_CASE(gru_reverse)
bias,
und,
ih);
-p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -1133,7 +1346,7 @@ TEST_CASE(gru_reverse)
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip,
0},
seq,
@@ -1142,7 +1355,7 @@ TEST_CASE(gru_reverse)
bias,
und,
ih);
-p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -1171,13 +1384,14 @@ TEST_CASE(gru_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(
-migraphx::op::gru{hidden_size, {}, migraphx::op::gru::reverse, clip, 1},
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 1},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
@@ -1208,7 +1422,7 @@ TEST_CASE(gru_reverse)
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip,
1},
seq,
@@ -1324,7 +1538,7 @@ TEST_CASE(gru_bidirectional)
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq,
...@@ -1365,7 +1579,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1365,7 +1579,7 @@ TEST_CASE(gru_bidirectional)
auto concat_hs = auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
1}, 1},
seq, seq,
...@@ -1374,7 +1588,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1374,7 +1588,7 @@ TEST_CASE(gru_bidirectional)
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -1400,7 +1614,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1400,7 +1614,7 @@ TEST_CASE(gru_bidirectional)
auto concat_hs = auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
0}, 0},
seq, seq,
...@@ -1409,7 +1623,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1409,7 +1623,7 @@ TEST_CASE(gru_bidirectional)
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -1422,6 +1636,82 @@ TEST_CASE(gru_bidirectional) ...@@ -1422,6 +1636,82 @@ TEST_CASE(gru_bidirectional)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
}
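The linear_before_reset flag that the surrounding cases toggle between 0 and 1 changes where the reset gate lands in the hidden-gate update. A minimal scalar sketch, assuming the standard ONNX GRU semantics; the names x_wh, h_rh, rb_h, and r_t are illustrative, not identifiers from this codebase:

#include <cmath>
#include <cstdio>

// Scalar sketch of the ONNX GRU hidden-gate update. Inputs are the
// precomputed products: x_wh = x_t*W_h + Wb_h, h_rh = H_{t-1}*R_h,
// rb_h = Rb_h, r_t = reset-gate output; tanh is the default g activation.
float gru_hidden_gate(float x_wh, float h_rh, float rb_h, float r_t,
                      int linear_before_reset)
{
    if(linear_before_reset == 0)
        // reset applied to H_{t-1} before the recurrence product:
        // g(x_t*W_h + (r_t . H_{t-1})*R_h + Rb_h + Wb_h)
        return std::tanh(x_wh + r_t * h_rh + rb_h);
    // reset applied after the recurrence product:
    // g(x_t*W_h + r_t . (H_{t-1}*R_h + Rb_h) + Wb_h)
    return std::tanh(x_wh + r_t * (h_rh + rb_h));
}

int main()
{
    std::printf("%f %f\n",
                gru_hidden_gate(0.5f, 0.25f, 0.1f, 0.8f, 0),
                gru_hidden_gate(0.5f, 0.25f, 0.1f, 0.8f, 1));
}

In the scalar form the two variants differ only in whether Rb_h is scaled by the reset gate; in the full matrix form the order of the reset multiply and the R_h product differs as well.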
TEST_CASE(gru_bidirectional_args)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418,
0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109,
0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732,
0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294,
0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778,
-0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353,
0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408,
-0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714,
-0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996,
-0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
-0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063,
0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194,
-0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082,
0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609,
0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339,
-0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534,
-0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305,
0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440,
0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074,
0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677,
-0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618,
-0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997,
0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027,
0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955,
-0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
-0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363,
-0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817,
-0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416,
0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317,
0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377,
0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348,
0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340,
0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// 3 args // 3 args
{ {
...@@ -1431,7 +1721,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1431,7 +1721,7 @@ TEST_CASE(gru_bidirectional)
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
0}, 0},
seq, seq,
...@@ -1466,7 +1756,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1466,7 +1756,7 @@ TEST_CASE(gru_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
1}, 1},
seq, seq,
...@@ -1503,7 +1793,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1503,7 +1793,7 @@ TEST_CASE(gru_bidirectional)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
1}, 1},
seq, seq,
...@@ -1530,6 +1820,82 @@ TEST_CASE(gru_bidirectional) ...@@ -1530,6 +1820,82 @@ TEST_CASE(gru_bidirectional)
-0.0339407, 0.413089, 0.721238, 0.431879}; -0.0339407, 0.413089, 0.721238, 0.431879};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
}
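The next test case drives a bidirectional GRU with anywhere from zero to four activation functions. One plausible scheme for expanding a short list to the two-per-direction count, assuming the ONNX defaults of sigmoid for the gates and tanh for the hidden update; this is a sketch, not necessarily the padding rule MIGraphX applies:

#include <cstddef>
#include <string>
#include <vector>

// Sketch only: pad a user-supplied activation list out to
// 2 * num_directions entries, cycling through the ONNX GRU defaults.
std::vector<std::string>
pad_gru_activations(std::vector<std::string> actvs, bool bidirectional)
{
    const std::size_t needed = bidirectional ? 4 : 2;
    const char* defaults[2]  = {"sigmoid", "tanh"};
    while(actvs.size() < needed)
        actvs.emplace_back(defaults[actvs.size() % 2]);
    return actvs;
}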
TEST_CASE(gru_bidirectional_actv_funcs)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418,
0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109,
0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732,
0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294,
0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778,
-0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353,
0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408,
-0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714,
-0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996,
-0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
-0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063,
0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194,
-0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082,
0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609,
0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339,
-0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534,
-0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305,
0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440,
0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074,
0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677,
-0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618,
-0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997,
0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027,
0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955,
-0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
-0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363,
-0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817,
-0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416,
0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317,
0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377,
0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348,
0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340,
0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// no activation functions specified, so the defaults are used. // no activation functions specified, so the defaults are used.
{ {
...@@ -1541,14 +1907,14 @@ TEST_CASE(gru_bidirectional) ...@@ -1541,14 +1907,14 @@ TEST_CASE(gru_bidirectional)
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs = p.add_instruction( auto concat_hs = p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::gru::bidirectional, clip, 1}, migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip, 1},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -1571,15 +1937,17 @@ TEST_CASE(gru_bidirectional) ...@@ -1571,15 +1937,17 @@ TEST_CASE(gru_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::sigmoid{}},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::bidirectional, clip, 0}, migraphx::op::rnn_direction::bidirectional,
seq, clip,
w, 0},
r, seq,
bias, w,
und, r,
ih); bias,
und,
ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -1606,15 +1974,17 @@ TEST_CASE(gru_bidirectional) ...@@ -1606,15 +1974,17 @@ TEST_CASE(gru_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip, 1}, migraphx::op::rnn_direction::bidirectional,
seq, clip,
w, 1},
r, seq,
bias, w,
und, r,
ih); bias,
und,
ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -1646,7 +2016,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1646,7 +2016,7 @@ TEST_CASE(gru_bidirectional)
auto concat_hs = p.add_instruction( auto concat_hs = p.add_instruction(
migraphx::op::gru{hidden_size, migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
1}, 1},
seq, seq,
...@@ -1655,7 +2025,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1655,7 +2025,7 @@ TEST_CASE(gru_bidirectional)
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -1682,7 +2052,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1682,7 +2052,7 @@ TEST_CASE(gru_bidirectional)
migraphx::op::tanh{}, migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::sigmoid{},
migraphx::op::tanh{}}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
1}, 1},
seq, seq,
...@@ -1726,7 +2096,7 @@ TEST_CASE(gru_bidirectional) ...@@ -1726,7 +2096,7 @@ TEST_CASE(gru_bidirectional)
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip, clip,
1}, 1},
seq, seq,
......
...@@ -669,6 +669,18 @@ struct test_elu ...@@ -669,6 +669,18 @@ struct test_elu
} }
}; };
struct test_relu_lrn
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}});
auto y = p.add_instruction(migraphx::op::relu{}, x);
p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, y);
return p;
}
};
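test_relu_lrn above hands the new lrn operator the {alpha, beta, bias, size} values 0.0001, 0.75, 1.0, 5. A scalar sketch of the cross-channel normalization those fields parameterize, assuming the usual ONNX LRN formula y = x / (bias + (alpha / size) * square_sum)^beta over a window of size channels:

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch only: LRN response for channel c of a 1-D channel vector,
// square-summing over a window of `size` channels centred on c.
float lrn_at(const std::vector<float>& x, int c,
             float alpha, float beta, float bias, int size)
{
    const int n  = static_cast<int>(x.size());
    const int lo = std::max(0, c - (size - 1) / 2);
    const int hi = std::min(n - 1, c + size / 2);
    float sq_sum = 0.0f;
    for(int i = lo; i <= hi; ++i)
        sq_sum += x[i] * x[i];
    return x[c] / std::pow(bias + (alpha / size) * sq_sum, beta);
}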
struct test_conv_pooling struct test_conv_pooling
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -1144,7 +1156,7 @@ struct test_rnn_forward ...@@ -1144,7 +1156,7 @@ struct test_rnn_forward
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1186,7 +1198,7 @@ struct test_rnn_forward10 ...@@ -1186,7 +1198,7 @@ struct test_rnn_forward10
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1227,7 +1239,7 @@ struct test_rnn_reverse ...@@ -1227,7 +1239,7 @@ struct test_rnn_reverse
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1267,7 +1279,7 @@ struct test_rnn_reverse2 ...@@ -1267,7 +1279,7 @@ struct test_rnn_reverse2
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1302,7 +1314,7 @@ struct test_rnn_3args ...@@ -1302,7 +1314,7 @@ struct test_rnn_3args
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1336,7 +1348,7 @@ struct test_rnn_4args ...@@ -1336,7 +1348,7 @@ struct test_rnn_4args
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1373,7 +1385,7 @@ struct test_rnn_5args ...@@ -1373,7 +1385,7 @@ struct test_rnn_5args
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1414,7 +1426,7 @@ struct test_rnn_bidirectional ...@@ -1414,7 +1426,7 @@ struct test_rnn_bidirectional
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1455,7 +1467,7 @@ struct test_rnn_bidirectional10 ...@@ -1455,7 +1467,7 @@ struct test_rnn_bidirectional10
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1493,7 +1505,7 @@ struct test_rnn_bi_3args ...@@ -1493,7 +1505,7 @@ struct test_rnn_bi_3args
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1534,7 +1546,7 @@ struct test_gru_forward_last ...@@ -1534,7 +1546,7 @@ struct test_gru_forward_last
auto output = auto output =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1542,7 +1554,7 @@ struct test_gru_forward_last ...@@ -1542,7 +1554,7 @@ struct test_gru_forward_last
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p; return p;
} }
...@@ -1577,7 +1589,7 @@ struct test_gru_forward_hs ...@@ -1577,7 +1589,7 @@ struct test_gru_forward_hs
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1613,7 +1625,7 @@ struct test_gru_forward_3args_und ...@@ -1613,7 +1625,7 @@ struct test_gru_forward_3args_und
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1648,7 +1660,7 @@ struct test_gru_forward_3args ...@@ -1648,7 +1660,7 @@ struct test_gru_forward_3args
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1680,7 +1692,7 @@ struct test_gru_forward_seq1 ...@@ -1680,7 +1692,7 @@ struct test_gru_forward_seq1
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1711,7 +1723,10 @@ struct test_gru_forward_default_actv ...@@ -1711,7 +1723,10 @@ struct test_gru_forward_default_actv
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction( p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::gru::forward, clip}, seq, w, r); migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
return p; return p;
} }
...@@ -1746,7 +1761,7 @@ struct test_gru_forward_default_actv1 ...@@ -1746,7 +1761,7 @@ struct test_gru_forward_default_actv1
p.add_instruction( p.add_instruction(
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::forward, clip}, hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r, r,
...@@ -1788,7 +1803,7 @@ struct test_gru_reverse_last ...@@ -1788,7 +1803,7 @@ struct test_gru_reverse_last
auto output = auto output =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1796,7 +1811,7 @@ struct test_gru_reverse_last ...@@ -1796,7 +1811,7 @@ struct test_gru_reverse_last
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p; return p;
} }
...@@ -1824,7 +1839,7 @@ struct test_gru_reverse_3args ...@@ -1824,7 +1839,7 @@ struct test_gru_reverse_3args
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1864,7 +1879,7 @@ struct test_gru_bidirct_last ...@@ -1864,7 +1879,7 @@ struct test_gru_bidirct_last
auto output = auto output =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1872,7 +1887,7 @@ struct test_gru_bidirct_last ...@@ -1872,7 +1887,7 @@ struct test_gru_bidirct_last
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p; return p;
} }
...@@ -1907,7 +1922,7 @@ struct test_gru_bidirct_hs ...@@ -1907,7 +1922,7 @@ struct test_gru_bidirct_hs
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1943,7 +1958,7 @@ struct test_gru_bidirct_3args_und ...@@ -1943,7 +1958,7 @@ struct test_gru_bidirct_3args_und
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1978,7 +1993,7 @@ struct test_gru_bidirct_3args ...@@ -1978,7 +1993,7 @@ struct test_gru_bidirct_3args
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -2010,7 +2025,7 @@ struct test_gru_bidirct_seq1 ...@@ -2010,7 +2025,7 @@ struct test_gru_bidirct_seq1
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -2041,7 +2056,10 @@ struct test_gru_bidirct_default_actv ...@@ -2041,7 +2056,10 @@ struct test_gru_bidirct_default_actv
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction( p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::gru::bidirectional, clip}, seq, w, r); migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r);
return p; return p;
} }
...@@ -2074,15 +2092,16 @@ struct test_gru_bidirct_default_actv1 ...@@ -2074,15 +2092,16 @@ struct test_gru_bidirct_default_actv1
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::sigmoid{}},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
seq, clip},
w, seq,
r, w,
bias, r,
und, bias,
ih); und,
ih);
return p; return p;
} }
...@@ -2090,6 +2109,7 @@ struct test_gru_bidirct_default_actv1 ...@@ -2090,6 +2109,7 @@ struct test_gru_bidirct_default_actv1
int main() int main()
{ {
verify_program<test_relu_lrn>();
verify_program<test_pooling_autopad>(); verify_program<test_pooling_autopad>();
verify_program<test_abs>(); verify_program<test_abs>();
verify_program<test_concat>(); verify_program<test_concat>();
......
...@@ -491,7 +491,7 @@ TEST_CASE(rnn_test) ...@@ -491,7 +491,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -523,7 +523,7 @@ TEST_CASE(rnn_test) ...@@ -523,7 +523,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -555,7 +555,7 @@ TEST_CASE(rnn_test) ...@@ -555,7 +555,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -583,7 +583,7 @@ TEST_CASE(rnn_test) ...@@ -583,7 +583,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -615,7 +615,7 @@ TEST_CASE(rnn_test) ...@@ -615,7 +615,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -658,15 +658,16 @@ TEST_CASE(gru_test) ...@@ -658,15 +658,16 @@ TEST_CASE(gru_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip,
1},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -692,7 +693,7 @@ TEST_CASE(gru_test) ...@@ -692,7 +693,7 @@ TEST_CASE(gru_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -700,7 +701,7 @@ TEST_CASE(gru_test) ...@@ -700,7 +701,7 @@ TEST_CASE(gru_test)
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -723,24 +724,35 @@ TEST_CASE(gru_test) ...@@ -723,24 +724,35 @@ TEST_CASE(gru_test)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs, auto out_hs =
{migraphx::op::tanh{}, p.add_instruction(migraphx::op::gru{hs,
migraphx::op::sigmoid{}, {migraphx::op::tanh{},
migraphx::op::relu{}, migraphx::op::sigmoid{},
migraphx::op::tanh{}}, migraphx::op::relu{},
migraphx::op::gru::bidirectional, migraphx::op::tanh{}},
clip}, migraphx::op::rnn_direction::bidirectional,
seq, clip},
w, seq,
r, w,
bias, r,
seq_len, bias,
ih); seq_len,
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
}
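Each parser test above ends by appending the renamed rnn_last_output op. A sketch of what such an op plausibly computes over hidden states laid out as {seq_len, num_dirs, batch, hidden} (the layout the shape tests later in this diff expect); the choice of time step per direction is an assumption, not verified against the MIGraphX implementation:

#include <cstddef>
#include <vector>

// Sketch only: gather one time step per direction, yielding
// [num_dirs, batch, hidden]. Assumes a forward pass finishes at
// t = seq_len - 1 and a reverse pass at t = 0.
std::vector<float> last_output(const std::vector<float>& hs,
                               std::size_t seq_len,
                               std::size_t num_dirs,
                               std::size_t batch,
                               std::size_t hidden,
                               bool reverse)
{
    std::vector<float> out(num_dirs * batch * hidden);
    for(std::size_t d = 0; d < num_dirs; ++d)
    {
        // In a bidirectional run, direction index 1 is the reverse pass.
        const std::size_t t = (reverse || d == 1) ? 0 : seq_len - 1;
        for(std::size_t b = 0; b < batch; ++b)
            for(std::size_t h = 0; h < hidden; ++h)
                out[(d * batch + b) * hidden + h] =
                    hs[((t * num_dirs + d) * batch + b) * hidden + h];
    }
    return out;
}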
TEST_CASE(gru_test_args)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// 3 arguments // 3 arguments
{ {
...@@ -757,7 +769,7 @@ TEST_CASE(gru_test) ...@@ -757,7 +769,7 @@ TEST_CASE(gru_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -765,7 +777,7 @@ TEST_CASE(gru_test) ...@@ -765,7 +777,7 @@ TEST_CASE(gru_test)
und, und,
und, und,
und); und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -789,7 +801,7 @@ TEST_CASE(gru_test) ...@@ -789,7 +801,7 @@ TEST_CASE(gru_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -797,7 +809,7 @@ TEST_CASE(gru_test) ...@@ -797,7 +809,7 @@ TEST_CASE(gru_test)
bias, bias,
und, und,
und); und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -823,7 +835,7 @@ TEST_CASE(gru_test) ...@@ -823,7 +835,7 @@ TEST_CASE(gru_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -831,12 +843,21 @@ TEST_CASE(gru_test) ...@@ -831,12 +843,21 @@ TEST_CASE(gru_test)
bias, bias,
seq_len, seq_len,
und); und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
}
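The 3-, 4-, and 5-argument cases above all normalize to the same 6-input gru instruction, with migraphx::op::undefined{} standing in for whichever optional inputs the ONNX model omits. A small illustrative sketch of that padding; the helper names are hypothetical:

#include <string>
#include <vector>

// Illustrative only: the 6-input order these tests pass to
// add_instruction.
std::vector<std::string> gru_input_names()
{
    return {"seq", "w", "r", "bias", "sequence_lens", "initial_h"};
}

// Pad a shorter argument list to the full 6-input form with the
// placeholder the tests create via migraphx::op::undefined{}.
std::vector<std::string> pad_gru_args(std::vector<std::string> args)
{
    while(args.size() < 6)
        args.emplace_back("undefined");
    return args;
}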
TEST_CASE(gru_test_actv_funcs)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirectional, 0 actv functions // bidirectional, 0 actv functions
{ {
nd = 2; nd = 2;
...@@ -854,15 +875,15 @@ TEST_CASE(gru_test) ...@@ -854,15 +875,15 @@ TEST_CASE(gru_test)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::bidirectional, clip}, migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -886,14 +907,15 @@ TEST_CASE(gru_test) ...@@ -886,14 +907,15 @@ TEST_CASE(gru_test)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip}, migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -919,7 +941,7 @@ TEST_CASE(gru_test) ...@@ -919,7 +941,7 @@ TEST_CASE(gru_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -927,7 +949,7 @@ TEST_CASE(gru_test) ...@@ -927,7 +949,7 @@ TEST_CASE(gru_test)
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -953,7 +975,7 @@ TEST_CASE(gru_test) ...@@ -953,7 +975,7 @@ TEST_CASE(gru_test)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -961,7 +983,7 @@ TEST_CASE(gru_test) ...@@ -961,7 +983,7 @@ TEST_CASE(gru_test)
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -984,14 +1006,15 @@ TEST_CASE(gru_test) ...@@ -984,14 +1006,15 @@ TEST_CASE(gru_test)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::forward, clip}, auto out_hs =
seq, p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
w, seq,
r, w,
bias, r,
seq_len, bias,
ih); seq_len,
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -1015,14 +1038,15 @@ TEST_CASE(gru_test) ...@@ -1015,14 +1038,15 @@ TEST_CASE(gru_test)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::gru::reverse, clip}, migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -1173,4 +1197,18 @@ TEST_CASE(pad_test) ...@@ -1173,4 +1197,18 @@ TEST_CASE(pad_test)
migraphx::parse_onnx("pad_test.onnx"); migraphx::parse_onnx("pad_test.onnx");
} }
TEST_CASE(lrn_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}});
migraphx::op::lrn op;
op.size = 5;
op.alpha = 0.0001;
op.beta = 0.75;
op.bias = 1.0;
p.add_instruction(op, l0);
auto prog = migraphx::parse_onnx("lrn_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -282,15 +282,16 @@ TEST_CASE(rnn) ...@@ -282,15 +282,16 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::shape{migraphx::shape::float_type,
migraphx::op::rnn{ {seq_len, num_dirct, batch_size, hidden_size}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip}, migraphx::op::rnn{
in_shape, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
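The expected shape above, and in the blocks that follow, comes down to one rule: the hidden-state output is {seq_len, num_directions, batch_size, hidden_size}, with num_directions doubled only in bidirectional mode. A minimal sketch:

#include <array>
#include <cstddef>

// The single rule behind the expect_shape calls in these tests.
std::array<std::size_t, 4> rnn_out_dims(std::size_t seq_len,
                                        std::size_t batch_size,
                                        std::size_t hidden_size,
                                        bool bidirectional)
{
    const std::size_t num_directions = bidirectional ? 2 : 1;
    return {seq_len, num_directions, batch_size, hidden_size};
}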
{ {
...@@ -307,15 +308,16 @@ TEST_CASE(rnn) ...@@ -307,15 +308,16 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::shape{migraphx::shape::float_type,
migraphx::op::rnn{ {seq_len, num_dirct, batch_size, hidden_size}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::reverse, clip}, migraphx::op::rnn{
in_shape, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -332,16 +334,17 @@ TEST_CASE(rnn) ...@@ -332,16 +334,17 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type,
migraphx::shape{migraphx::shape::float_type, {seq_len, num_dirct, batch_size, hidden_size}},
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::op::rnn{hidden_size,
migraphx::op::rnn{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -358,14 +361,15 @@ TEST_CASE(rnn) ...@@ -358,14 +361,15 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( throws_shape(migraphx::op::rnn{hidden_size + 1,
migraphx::op::rnn{ {migraphx::op::tanh{}},
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip}, migraphx::op::rnn_direction::forward,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -382,14 +386,15 @@ TEST_CASE(rnn) ...@@ -382,14 +386,15 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( throws_shape(migraphx::op::rnn{hidden_size,
migraphx::op::rnn{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -408,7 +413,7 @@ TEST_CASE(rnn) ...@@ -408,7 +413,7 @@ TEST_CASE(rnn)
throws_shape( throws_shape(
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -435,15 +440,16 @@ TEST_CASE(gru) ...@@ -435,15 +440,16 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::shape{migraphx::shape::float_type,
migraphx::op::gru{ {seq_len, num_dirct, batch_size, hidden_size}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip}, migraphx::op::gru{
in_shape, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -462,15 +468,16 @@ TEST_CASE(gru) ...@@ -462,15 +468,16 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::shape{migraphx::shape::float_type,
migraphx::op::gru{ {seq_len, num_dirct, batch_size, hidden_size}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::reverse, clip}, migraphx::op::gru{
in_shape, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -489,16 +496,17 @@ TEST_CASE(gru) ...@@ -489,16 +496,17 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type,
migraphx::shape{migraphx::shape::float_type, {seq_len, num_dirct, batch_size, hidden_size}},
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -517,14 +525,15 @@ TEST_CASE(gru) ...@@ -517,14 +525,15 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape( throws_shape(migraphx::op::gru{hidden_size + 1,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip}, migraphx::op::rnn_direction::forward,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -543,14 +552,15 @@ TEST_CASE(gru) ...@@ -543,14 +552,15 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape( throws_shape(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -571,7 +581,7 @@ TEST_CASE(gru) ...@@ -571,7 +581,7 @@ TEST_CASE(gru)
throws_shape( throws_shape(
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
......