Commit 0da78d29 authored by Shucai Xiao

merge gru_operator changes

parents d53a69f6 ce7b4b17
@@ -60,6 +60,30 @@ struct batch_norm_inference
     }
 };
 
+struct lrn
+{
+    float alpha = 0.0001;
+    float beta  = 0.75;
+    float bias  = 1.0;
+    int size    = 1;
+
+    std::string name() const { return "lrn"; }
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.alpha, "alpha"),
+                    f(self.beta, "beta"),
+                    f(self.bias, "bias"),
+                    f(self.size, "size"));
+    }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        return inputs.front();
+    }
+};
+
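For reference, these four attributes feed the cross-channel LRN computation that the CPU kernel later in this diff spells out; with N(c) the size-wide channel window around channel c:

    y_{b,c,h,w} = \frac{x_{b,c,h,w}}{\bigl(\mathit{bias} + \tfrac{\alpha}{\mathit{size}} \sum_{k \in N(c)} x_{b,k,h,w}^{2}\bigr)^{\beta}}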
 struct convolution
 {
     std::array<std::size_t, 2> padding = {{0, 0}};
@@ -1140,19 +1164,19 @@ struct outline
     argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
 };
 
-struct rnn
-{
-    enum rnn_direction_t
-    {
-        forward,
-        reverse,
-        bidirectional,
-    };
+// indicate rnn computation direction
+enum class rnn_direction
+{
+    forward,
+    reverse,
+    bidirectional,
+};
 
+struct rnn
+{
     std::size_t hidden_size = 1;
     std::vector<operation> actv_funcs{tanh{}, tanh{}};
-    rnn_direction_t direction = forward;
+    rnn_direction direction = rnn_direction::forward;
     float clip = 0.0f;
 
     std::string name() const { return "rnn"; }
@@ -1166,7 +1190,7 @@ struct rnn
         }
 
         std::size_t num_directions = 1;
-        if(direction == bidirectional)
+        if(direction == rnn_direction::bidirectional)
        {
             num_directions = 2;
         }
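The practical effect of hoisting the direction enum out of the individual operators is visible at every call site below: per-operator qualifiers give way to one shared scoped enum. A minimal before/after sketch (illustrative only):

    op::gru g;
    // g.direction = op::gru::bidirectional;        // old spelling, one enum per op
    g.direction = op::rnn_direction::bidirectional; // new spelling, shared by rnn/gru/lstm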
@@ -1200,16 +1224,9 @@ struct rnn_last_output
 struct gru
 {
-    enum gru_direction_t
-    {
-        forward,
-        reverse,
-        bidirectional,
-    };
-
     std::size_t hidden_size = 1;
     std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
-    gru_direction_t direction = forward;
+    rnn_direction direction = rnn_direction::forward;
     float clip = 0.0f;
     int linear_before_reset = 0;
@@ -1224,7 +1241,7 @@ struct gru
         }
 
         std::size_t num_directions = 1;
-        if(direction == bidirectional)
+        if(direction == rnn_direction::bidirectional)
         {
             num_directions = 2;
         }
@@ -1242,32 +1259,11 @@ struct gru
     }
 };
 
-struct gru_last_output
-{
-    std::string name() const { return "gru_last_output"; }
-
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs, *this}.has(1);
-        auto dims = inputs[0].lens();
-
-        // remove the first dimension, remaining are output shape
-        dims.erase(dims.begin());
-
-        return {inputs[0].type(), dims};
-    }
-};
-
 struct lstm
 {
-    enum lstm_direction_t
-    {
-        forward,
-        reverse,
-        bidirectional,
-    };
-
     std::size_t hidden_size = 1;
     std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
-    lstm_direction_t direction = forward;
+    rnn_direction direction = rnn_direction::forward;
     float clip = 0.0f;
     int input_forget = 0;
@@ -1282,7 +1278,7 @@ struct lstm
         }
 
         std::size_t num_directions = 1;
-        if(direction == bidirectional)
+        if(direction == rnn_direction::bidirectional)
         {
             num_directions = 2;
         }
@@ -1300,20 +1296,6 @@ struct lstm
     }
 };
 
-struct lstm_last_output
-{
-    std::string name() const { return "lstm_last_output"; }
-
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs, *this}.has(1);
-        auto dims = inputs[0].lens();
-
-        // remove the first dimension, remaining are output shape
-        dims.erase(dims.begin());
-
-        return {inputs[0].type(), dims};
-    }
-};
-
 struct lstm_last_cell_output
 {
     std::string name() const { return "lstm_last_cell_output"; }
...
@@ -64,6 +64,7 @@ struct onnx_parser
         add_variadic_op("Max", op::max{});
         add_variadic_op("Min", op::min{});
+        add_mem_op("LRN", &onnx_parser::parse_lrn);
         add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
         add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
         add_mem_op("Elu", &onnx_parser::parse_elu);
@@ -88,6 +89,7 @@ struct onnx_parser
         add_mem_op("Transpose", &onnx_parser::parse_transpose);
         add_mem_op("RNN", &onnx_parser::parse_rnn);
         add_mem_op("GRU", &onnx_parser::parse_gru);
+        add_mem_op("LSTM", &onnx_parser::parse_lstm);
         add_mem_op("Pad", &onnx_parser::parse_pad);
 
         // init the activation function map
@@ -537,6 +539,25 @@ struct onnx_parser
         return prog.add_instruction(op, args.front());
     }
 
+    instruction_ref
+    parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        float alpha = 0.0001;
+        float beta  = 0.75;
+        float bias  = 1.0;
+        int size    = 1;
+        if(contains(attributes, "alpha"))
+            alpha = parse_value(attributes.at("alpha")).at<float>();
+        if(contains(attributes, "beta"))
+            beta = parse_value(attributes.at("beta")).at<float>();
+        if(contains(attributes, "bias"))
+            bias = parse_value(attributes.at("bias")).at<float>();
+        if(contains(attributes, "size"))
+            size = parse_value(attributes.at("size")).at<int>();
+        op::lrn op{alpha, beta, bias, size};
+        return prog.add_instruction(op, args.front());
+    }
+
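Concretely, an ONNX LRN node with size=5 and the remaining attributes left at their defaults parses to the equivalent of the following sketch (the parameter name and input shape here are invented for illustration):

    migraphx::program p;
    auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}});
    // argument order mirrors op::lrn's members: alpha, beta, bias, size
    p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, x);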
     instruction_ref parse_imagescaler(const std::string&,
                                       attribute_map attributes,
                                       std::vector<instruction_ref> args)
@@ -714,14 +735,14 @@ struct onnx_parser
             direction = attributes.at("direction").s();
         }
 
-        op::rnn::rnn_direction_t dirct = op::rnn::forward;
+        op::rnn_direction dirct = op::rnn_direction::forward;
         if(direction == "bidirectional")
         {
-            dirct = op::rnn::bidirectional;
+            dirct = op::rnn_direction::bidirectional;
         }
         else if(direction == "reverse")
         {
-            dirct = op::rnn::reverse;
+            dirct = op::rnn_direction::reverse;
         }
 
         std::vector<std::string> vec_names{"tanh"};
@@ -743,7 +764,7 @@ struct onnx_parser
         // one is for forward, and the other is for reverse.
         // if only one actv function is provided, we use it in both
         // forward and reverse direction
-        if(dirct == op::rnn::bidirectional)
+        if(dirct == op::rnn_direction::bidirectional)
         {
             if(vec_names.size() == 1)
             {
@@ -803,14 +824,14 @@ struct onnx_parser
             direction = attributes.at("direction").s();
         }
 
-        op::gru::gru_direction_t dirct = op::gru::forward;
+        op::rnn_direction dirct = op::rnn_direction::forward;
         if(direction == "bidirectional")
         {
-            dirct = op::gru::bidirectional;
+            dirct = op::rnn_direction::bidirectional;
         }
         else if(direction == "reverse")
         {
-            dirct = op::gru::reverse;
+            dirct = op::rnn_direction::reverse;
         }
 
         std::vector<std::string> vec_names = {"sigmoid", "tanh"};
@@ -824,7 +845,7 @@ struct onnx_parser
         }
 
         // need 4 activation functions
-        if(dirct == op::gru::bidirectional)
+        if(dirct == op::rnn_direction::bidirectional)
         {
             // 4 activation functions are used in the bidirectional
             // scenario. No spec is provided in onnx::operator. we
@@ -895,7 +916,7 @@ struct onnx_parser
                               std::move(args));
 
         // second output for last gru output
-        auto last_output = prog.add_instruction(op::gru_last_output{}, hidden_states);
+        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
 
         return {hidden_states, last_output};
     }
@@ -922,18 +943,18 @@ struct onnx_parser
             direction = attributes.at("direction").s();
         }
 
-        op::lstm::lstm_direction_t dirct = op::lstm::forward;
+        op::rnn_direction dirct = op::rnn_direction::forward;
         if(direction == "bidirectional")
         {
-            dirct = op::lstm::bidirectional;
+            dirct = op::rnn_direction::bidirectional;
         }
         else if(direction == "reverse")
         {
-            dirct = op::lstm::reverse;
+            dirct = op::rnn_direction::reverse;
         }
         else if(direction == "forward")
         {
-            dirct = op::lstm::forward;
+            dirct = op::rnn_direction::forward;
         }
         else
         {
@@ -951,7 +972,7 @@ struct onnx_parser
         }
 
         // need 6 activation functions for bidirectional directions
-        if(dirct == op::lstm::bidirectional)
+        if(dirct == op::rnn_direction::bidirectional)
         {
             // 6 activation functions are used in the bidirectional
             // scenario. No spec is provided in onnx::operator. we
@@ -1034,7 +1055,7 @@ struct onnx_parser
             op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
 
         // second output for last lstm output
-        auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states);
+        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
 
         // third output for last cell output
         auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
...
@@ -45,9 +45,9 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
     auto actv_funcs = vanilla_rnn_actv_funcs(ins);
     auto rnn_op     = any_cast<op::rnn>(ins->get_operator());
 
-    op::rnn::rnn_direction_t dicrt = rnn_op.direction;
+    op::rnn_direction dicrt = rnn_op.direction;
     instruction_ref last_output{};
-    if(dicrt == op::rnn::bidirectional)
+    if(dicrt == op::rnn_direction::bidirectional)
     {
         // input weight matrix
         auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
@@ -107,10 +107,8 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
         // The following logic is to ensure the last instruction rewritten from
         // rnn operator is a concat instruction
         // sequence len is 1
-        instruction_ref hidden_output{};
         if(ret_forward[0] == prog.end())
         {
-            hidden_output =
-                prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
+            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
         }
         else
@@ -119,13 +117,12 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
                 prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
             ret_reverse[0] =
                 prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
-            hidden_output =
-                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
+            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
         }
     }
     else
     {
-        bool is_forward = (dicrt == op::rnn::forward);
+        bool is_forward = (dicrt == op::rnn_direction::forward);
 
         // input weight matrix
         auto w = args[1];
@@ -157,16 +154,15 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
         // following logic is to ensure the last instruction is a
         // concat instruction
         // sequence len is 1
-        instruction_ref hidden_output{};
         if(ret[0] == prog.end())
         {
-            hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
+            prog.replace_instruction(ins, op::concat{0}, ret[1]);
         }
         else
         {
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            hidden_output = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
+            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
         }
     }
@@ -282,7 +278,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
     // append undefined operators to make 6 arguments when parsing
     // an onnx file. Another case is user can have any num of arguments
     // when writing their program.
-    if(rnn_op.direction == op::rnn::bidirectional)
+    if(rnn_op.direction == op::rnn_direction::bidirectional)
     {
         if(rnn_op.actv_funcs.empty())
         {
@@ -330,9 +326,9 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
     std::vector<float> data(ih_shape.elements(), 0.0);
 
     auto gru_op = any_cast<op::gru>(ins->get_operator());
-    op::gru::gru_direction_t dicrt = gru_op.direction;
+    op::rnn_direction dicrt = gru_op.direction;
     instruction_ref last_output{};
-    if(dicrt == op::gru::bidirectional)
+    if(dicrt == op::rnn_direction::bidirectional)
     {
         // w weight matrix
         auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
@@ -402,7 +398,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
     }
     else
     {
-        bool is_forward = (dicrt == op::gru::forward);
+        bool is_forward = (dicrt == op::rnn_direction::forward);
         // weight matrix
         auto w = args[1];
         auto r = args[2];
@@ -447,14 +443,14 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
         }
     }
 
-    // replace the corresponding gru_last_output instruction
-    // with the last_output, if gru_last_output exists
-    // while loop to handle case of multiple gru_last_output operators
+    // replace the corresponding rnn_last_output instruction
+    // with the last_output, if rnn_last_output exists
+    // while loop to handle case of multiple rnn_last_output operators
     auto last_output_it = ins->outputs().begin();
     while(last_output_it != ins->outputs().end())
     {
         last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
-            return i->name() == "gru_last_output";
+            return i->name() == "rnn_last_output";
         });
 
         if(last_output_it != ins->outputs().end())
@@ -638,7 +634,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
     // we have 4 actv funcs, even though a user does not
     // specify any actv func. If less than 4, use the
     // algorithm in parse_gru to make 4 actv functions
-    if(gru_op.direction == op::gru::bidirectional)
+    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
         if(gru_op.actv_funcs.empty())
             return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
@@ -689,11 +685,11 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
     auto actv_funcs = lstm_actv_funcs(ins);
 
     auto lstm_op = any_cast<op::lstm>(ins->get_operator());
-    op::lstm::lstm_direction_t dirct = lstm_op.direction;
+    op::rnn_direction dirct = lstm_op.direction;
     instruction_ref last_output{};
     instruction_ref last_cell_output{};
-    if(dirct == op::lstm::bidirectional)
+    if(dirct == op::rnn_direction::bidirectional)
     {
         // input weight matrix
         // input weight matrix
@@ -799,7 +795,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
     }
     else
     {
-        bool is_forward = (dirct == op::lstm::forward);
+        bool is_forward = (dirct == op::rnn_direction::forward);
         // weight matrices
         auto w = args[1];
         auto r = args[2];
@@ -1100,7 +1096,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
     // algorithm in parse_lstm to make 6 actv functions
     const auto& actv_funcs     = lstm_op.actv_funcs;
     std::size_t num_actv_funcs = actv_funcs.size();
-    if(lstm_op.direction == op::lstm::bidirectional)
+    if(lstm_op.direction == op::rnn_direction::bidirectional)
     {
         switch(num_actv_funcs)
         {
...
@@ -103,6 +103,43 @@ struct cpu_batch_norm_inference
     }
 };
 
+struct cpu_lrn
+{
+    op::lrn op;
+    std::string name() const { return "cpu::lrn"; }
+    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+    argument compute(context&, shape output_shape, std::vector<argument> args) const
+    {
+        argument result{output_shape};
+        visit_all(result, args[0])([&](auto output, auto input) {
+            int n_batch         = output_shape.lens()[0];
+            int channels        = output_shape.lens()[1];
+            int height          = output_shape.lens()[2];
+            int width           = output_shape.lens()[3];
+            float alphaoverarea = op.alpha / op.size;
+            int radius          = (op.size - 1) / 2;
+
+            par_dfor(n_batch, height, width)([&](int b, int h, int w) {
+                float scale = 0;
+                dfor(channels)([&](int c) {
+                    auto start = (c - radius) < 0 ? 0 : (c - radius);
+                    auto end   = (c + radius) > channels ? channels : (c + radius);
+                    for(auto k = start; k < end; ++k)
+                    {
+                        scale += std::pow(input(b, k, h, w), 2);
+                    }
+                    scale *= alphaoverarea;
+                    scale += op.bias;
+                    scale = std::pow(scale, -op.beta);
+                    output(b, c, h, w) = input(b, c, h, w) * scale;
+                });
+            });
+        });
+        return result;
+    }
+};
+
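A hedged restatement of the window arithmetic above, for a single spatial position with one value per channel (lrn_scale is a made-up helper name; unlike the kernel above, which carries `scale` across channel iterations, this sketch resets the running sum for every output channel):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    float lrn_scale(const std::vector<float>& x, int c, float alpha, float beta, float bias, int size)
    {
        int channels = static_cast<int>(x.size());
        int radius   = (size - 1) / 2;
        int start    = std::max(c - radius, 0);        // clamp window at channel 0
        int end      = std::min(c + radius, channels); // clamp window at channel count
        float sum    = 0.0f;
        for(int k = start; k < end; ++k)
            sum += x[k] * x[k]; // sum of squares over the channel window
        // output = input * (bias + (alpha / size) * sum)^(-beta)
        return std::pow(bias + (alpha / size) * sum, -beta);
    }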
 struct cpu_convolution
 {
     op::convolution op;
@@ -681,6 +718,7 @@ struct cpu_apply
         apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
         apply_map["batch_norm_inference"] =
             extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
+        apply_map["lrn"]        = extend_op<cpu_lrn, op::lrn>();
         apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
         apply_map["pad"]        = extend_op<cpu_pad, op::pad>();
         apply_map["concat"]     = extend_op<cpu_concat, op::concat>();
...
@@ -61,6 +61,7 @@ add_library(migraphx_gpu
     elu.cpp
     pad.cpp
     gather.cpp
+    lrn.cpp
 )
 set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
 rocm_clang_tidy_check(migraphx_gpu)
...
+#ifndef MIGRAPHX_GUARD_RTGLIB_LRN_HPP
+#define MIGRAPHX_GUARD_RTGLIB_LRN_HPP
+
+#include <migraphx/gpu/lowering.hpp>
+#include <migraphx/manage_ptr.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/operators.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/shape_for_each.hpp>
+#include <migraphx/config.hpp>
+#include <migraphx/gpu/miopen.hpp>
+#include <migraphx/gpu/hip.hpp>
+#include <migraphx/dfor.hpp>
+#include <migraphx/gpu/device/contiguous.hpp>
+#include <migraphx/gpu/device/add.hpp>
+#include <migraphx/iterator_for.hpp>
+#include <migraphx/gpu/rocblas.hpp>
+#include <migraphx/gpu/context.hpp>
+#include <utility>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace gpu {
+
+struct miopen_lrn
+{
+    shared<lrn_descriptor> ldesc;
+
+    std::string name() const { return "gpu::lrn"; }
+    shape compute_shape(const std::vector<shape>& inputs) const;
+    argument
+    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
+    int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
+};
+
+} // namespace gpu
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+
+#endif
@@ -23,6 +23,8 @@ using fusion_plan_descriptor = MIGRAPHX_MANAGE_PTR(miopenFusionPlanDescriptor_t,
                                                    miopenDestroyFusionPlan);
 using fused_operator_args = MIGRAPHX_MANAGE_PTR(miopenOperatorArgs_t, miopenDestroyOperatorArgs);
 
+using lrn_descriptor = MIGRAPHX_MANAGE_PTR(miopenLRNDescriptor_t, miopenDestroyLRNDescriptor);
+
 template <class Result, class F, class... Ts>
 Result make_obj(F f, Ts... xs)
 {
@@ -89,6 +91,13 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
     return p;
 }
 
+inline lrn_descriptor make_lrn(const migraphx::op::lrn& op)
+{
+    auto ldesc = make_obj<lrn_descriptor>(&miopenCreateLRNDescriptor);
+    miopenSetLRNDescriptor(ldesc.get(), miopenLRNCrossChannel, op.size, op.alpha, op.beta, op.bias);
+    return ldesc;
+}
+
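A small usage sketch of this helper, with the default ONNX attribute values used elsewhere in the diff (variable names are illustrative):

    // build a MIOpen cross-channel LRN descriptor from the migraphx op
    migraphx::op::lrn op{0.0001, 0.75, 1.0, 5};
    auto ldesc = make_lrn(op); // wraps miopenCreateLRNDescriptor/miopenSetLRNDescriptor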
 inline activation_descriptor make_relu()
 {
     auto ad = make_obj<activation_descriptor>(&miopenCreateActivationDescriptor);
...
@@ -43,6 +43,7 @@
 #include <migraphx/gpu/concat.hpp>
 #include <migraphx/gpu/pad.hpp>
 #include <migraphx/gpu/gather.hpp>
+#include <migraphx/gpu/lrn.hpp>
 #include <utility>
 #include <functional>
 #include <algorithm>
@@ -99,6 +100,7 @@ struct miopen_apply
         add_extend_op<hip_gather, op::gather>("gather");
         add_extend_op<hip_pad, op::pad>("pad");
+        add_lrn_op();
         add_convolution_op();
         add_pooling_op();
         add_batch_norm_inference_op();
@@ -159,6 +161,17 @@ struct miopen_apply
         });
     }
 
+    void add_lrn_op()
+    {
+        apply_map.emplace("lrn", [=](instruction_ref ins) {
+            auto&& op   = any_cast<op::lrn>(ins->get_operator());
+            auto ldesc  = make_lrn(op);
+            auto output = insert_allocation(ins, ins->get_shape());
+            return prog->replace_instruction(
+                ins, miopen_lrn{std::move(ldesc)}, ins->inputs().at(0), output);
+        });
+    }
+
     template <class T>
     void add_generic_op(std::string name)
     {
...
+#include <migraphx/gpu/lrn.hpp>
+#include <migraphx/operators.hpp>
+#include <migraphx/manage_ptr.hpp>
+#include <migraphx/gpu/miopen.hpp>
+#include <utility>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace gpu {
+
+shape miopen_lrn::compute_shape(const std::vector<shape>& inputs) const
+{
+    check_shapes{inputs, *this}.has(2).not_broadcasted();
+    return inputs.at(1);
+}
+
+argument miopen_lrn::compute(context& ctx,
+                             const shape& output_shape,
+                             const std::vector<argument>& args) const
+{
+    float alpha = 1;
+    float beta  = 0;
+    auto x_desc = make_tensor(args[0].get_shape());
+    auto y_desc = make_tensor(output_shape);
+    miopenLRNForward(ctx.get_stream().get_miopen(),
+                     ldesc.get(),
+                     &alpha,
+                     x_desc.get(),
+                     args[0].implicit(),
+                     &beta,
+                     y_desc.get(),
+                     args[1].implicit(),
+                     false,
+                     nullptr);
+
+    return args[1];
+}
+
+} // namespace gpu
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
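One easy-to-misread detail: the local alpha/beta here are not the LRN attributes (those live in the descriptor) but the usual MIOpen-style blending scalars, so with these values the call writes the plain LRN result into the pre-allocated output that output_alias exposes. Conceptually, assuming the standard blending convention:

    y \leftarrow \alpha_{\text{blend}} \cdot \mathrm{LRN}(x) + \beta_{\text{blend}} \cdot y, \qquad \alpha_{\text{blend}} = 1,\ \beta_{\text{blend}} = 0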
@@ -732,6 +732,20 @@ TEST_CASE(leaky_relu_test)
     EXPECT(migraphx::verify_range(results_vector, gold));
 }
 
+TEST_CASE(lrn_test)
+{
+    migraphx::program p;
+    migraphx::shape s{migraphx::shape::float_type, {1, 5, 1, 1}};
+    auto l = p.add_literal(migraphx::literal{s, {-2.0f, 1.0f, 0.f, 1.0f, 2.0f}});
+    p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1, 5}, l);
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({});
+    std::vector<float> results_vector(5);
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075};
+    EXPECT(migraphx::verify_range(results_vector, gold));
+}
+
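As a quick check on the first gold value: for channel 0 the size-5 window covers channels 0 through 2, i.e. values {-2, 1, 0}, so the denominator is

    \left(1 + \tfrac{0.0001}{5}\,(4 + 1 + 0)\right)^{0.75} = 1.0001^{0.75} \approx 1.000075

matching gold[0] = -2 / 1.000075; the other entries follow the same recipe with their own channel windows.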
 TEST_CASE(imagescaler_test)
 {
     migraphx::program p;
...
This diff is collapsed.
@@ -669,6 +669,18 @@ struct test_elu
     }
 };
 
+struct test_relu_lrn
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 2}});
+        auto y = p.add_instruction(migraphx::op::relu{}, x);
+        p.add_instruction(migraphx::op::lrn{0.0001, 0.75, 1.0, 5}, y);
+        return p;
+    }
+};
+
 struct test_conv_pooling
 {
     migraphx::program create_program() const
@@ -1144,7 +1156,7 @@ struct test_rnn_forward
         auto output =
             p.add_instruction(migraphx::op::rnn{hidden_size,
                                                 {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                                migraphx::op::rnn::forward,
+                                                migraphx::op::rnn_direction::forward,
                                                 clip},
                               seq,
                               w,
@@ -1186,7 +1198,7 @@ struct test_rnn_forward10
         auto output =
             p.add_instruction(migraphx::op::rnn{hidden_size,
                                                 {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                                migraphx::op::rnn::forward,
+                                                migraphx::op::rnn_direction::forward,
                                                 clip},
                               seq,
                               w,
@@ -1227,7 +1239,7 @@ struct test_rnn_reverse
         p.add_instruction(migraphx::op::rnn{hidden_size,
                                             {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                            migraphx::op::rnn::reverse,
+                                            migraphx::op::rnn_direction::reverse,
                                             clip},
                           seq,
                           w,
@@ -1267,7 +1279,7 @@ struct test_rnn_reverse2
         p.add_instruction(migraphx::op::rnn{hidden_size,
                                             {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                            migraphx::op::rnn::reverse,
+                                            migraphx::op::rnn_direction::reverse,
                                             clip},
                           seq,
                           w,
@@ -1302,7 +1314,7 @@ struct test_rnn_3args
         p.add_instruction(migraphx::op::rnn{hidden_size,
                                             {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                            migraphx::op::rnn::reverse,
+                                            migraphx::op::rnn_direction::reverse,
                                             clip},
                           seq,
                           w,
@@ -1336,7 +1348,7 @@ struct test_rnn_4args
         p.add_instruction(migraphx::op::rnn{hidden_size,
                                             {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                            migraphx::op::rnn::reverse,
+                                            migraphx::op::rnn_direction::reverse,
                                             clip},
                           seq,
                           w,
@@ -1373,7 +1385,7 @@ struct test_rnn_5args
         auto output =
             p.add_instruction(migraphx::op::rnn{hidden_size,
                                                 {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                                migraphx::op::rnn::forward,
+                                                migraphx::op::rnn_direction::forward,
                                                 clip},
                               seq,
                               w,
@@ -1414,7 +1426,7 @@ struct test_rnn_bidirectional
         auto output =
             p.add_instruction(migraphx::op::rnn{hidden_size,
                                                 {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                                migraphx::op::rnn::bidirectional,
+                                                migraphx::op::rnn_direction::bidirectional,
                                                 clip},
                               seq,
                               w,
@@ -1455,7 +1467,7 @@ struct test_rnn_bidirectional10
         auto output =
             p.add_instruction(migraphx::op::rnn{hidden_size,
                                                 {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                                migraphx::op::rnn::bidirectional,
+                                                migraphx::op::rnn_direction::bidirectional,
                                                 clip},
                               seq,
                               w,
@@ -1493,7 +1505,7 @@ struct test_rnn_bi_3args
         auto output =
             p.add_instruction(migraphx::op::rnn{hidden_size,
                                                 {migraphx::op::tanh{}, migraphx::op::tanh{}},
-                                                migraphx::op::rnn::bidirectional,
+                                                migraphx::op::rnn_direction::bidirectional,
                                                 clip},
                               seq,
                               w,
@@ -1534,7 +1546,7 @@ struct test_gru_forward_last
         auto output =
             p.add_instruction(migraphx::op::gru{hidden_size,
                                                 {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                                migraphx::op::gru::forward,
+                                                migraphx::op::rnn_direction::forward,
                                                 clip},
                               seq,
                               w,
@@ -1542,7 +1554,7 @@ struct test_gru_forward_last
                               bias,
                               und,
                               ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, output);
+        p.add_instruction(migraphx::op::rnn_last_output{}, output);
 
         return p;
     }
@@ -1577,7 +1589,7 @@ struct test_gru_forward_hs
         p.add_instruction(migraphx::op::gru{hidden_size,
                                             {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                            migraphx::op::gru::forward,
+                                            migraphx::op::rnn_direction::forward,
                                             clip},
                           seq,
                           w,
@@ -1613,7 +1625,7 @@ struct test_gru_forward_3args_und
         auto und = p.add_instruction(migraphx::op::undefined{});
         p.add_instruction(migraphx::op::gru{hidden_size,
                                             {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                            migraphx::op::gru::forward,
+                                            migraphx::op::rnn_direction::forward,
                                             clip},
                           seq,
                           w,
@@ -1648,7 +1660,7 @@ struct test_gru_forward_3args
         auto r = p.add_parameter("r", r_shape);
         p.add_instruction(migraphx::op::gru{hidden_size,
                                             {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                            migraphx::op::gru::forward,
+                                            migraphx::op::rnn_direction::forward,
                                             clip},
                           seq,
                           w,
@@ -1680,7 +1692,7 @@ struct test_gru_forward_seq1
         auto r = p.add_parameter("r", r_shape);
         p.add_instruction(migraphx::op::gru{hidden_size,
                                             {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                            migraphx::op::gru::forward,
+                                            migraphx::op::rnn_direction::forward,
                                             clip},
                           seq,
                           w,
@@ -1711,7 +1723,10 @@ struct test_gru_forward_default_actv
         auto w = p.add_parameter("w", w_shape);
         auto r = p.add_parameter("r", r_shape);
         p.add_instruction(
-            migraphx::op::gru{hidden_size, {}, migraphx::op::gru::forward, clip}, seq, w, r);
+            migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
+            seq,
+            w,
+            r);
 
         return p;
     }
@@ -1746,7 +1761,7 @@ struct test_gru_forward_default_actv1
         p.add_instruction(
             migraphx::op::gru{
-                hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::forward, clip},
+                hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
             seq,
             w,
             r,
@@ -1788,7 +1803,7 @@ struct test_gru_reverse_last
         auto output =
             p.add_instruction(migraphx::op::gru{hidden_size,
                                                 {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                                migraphx::op::gru::reverse,
+                                                migraphx::op::rnn_direction::reverse,
                                                 clip},
                               seq,
                               w,
@@ -1796,7 +1811,7 @@ struct test_gru_reverse_last
                               bias,
                               und,
                               ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, output);
+        p.add_instruction(migraphx::op::rnn_last_output{}, output);
 
         return p;
     }
@@ -1824,7 +1839,7 @@ struct test_gru_reverse_3args
         auto r = p.add_parameter("r", r_shape);
         p.add_instruction(migraphx::op::gru{hidden_size,
                                             {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                            migraphx::op::gru::reverse,
+                                            migraphx::op::rnn_direction::reverse,
                                             clip},
                           seq,
                           w,
@@ -1864,7 +1879,7 @@ struct test_gru_bidirct_last
         auto output =
             p.add_instruction(migraphx::op::gru{hidden_size,
                                                 {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                                migraphx::op::gru::bidirectional,
+                                                migraphx::op::rnn_direction::bidirectional,
                                                 clip},
                               seq,
                               w,
@@ -1872,7 +1887,7 @@ struct test_gru_bidirct_last
                               bias,
                               und,
                               ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, output);
+        p.add_instruction(migraphx::op::rnn_last_output{}, output);
 
         return p;
     }
@@ -1907,7 +1922,7 @@ struct test_gru_bidirct_hs
         p.add_instruction(migraphx::op::gru{hidden_size,
                                             {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                            migraphx::op::gru::bidirectional,
+                                            migraphx::op::rnn_direction::bidirectional,
                                             clip},
                           seq,
                           w,
@@ -1943,7 +1958,7 @@ struct test_gru_bidirct_3args_und
         auto und = p.add_instruction(migraphx::op::undefined{});
         p.add_instruction(migraphx::op::gru{hidden_size,
                                             {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                            migraphx::op::gru::bidirectional,
+                                            migraphx::op::rnn_direction::bidirectional,
                                             clip},
                           seq,
                           w,
@@ -1978,7 +1993,7 @@ struct test_gru_bidirct_3args
         auto r = p.add_parameter("r", r_shape);
         p.add_instruction(migraphx::op::gru{hidden_size,
                                             {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                            migraphx::op::gru::bidirectional,
+                                            migraphx::op::rnn_direction::bidirectional,
                                             clip},
                           seq,
                           w,
@@ -2010,7 +2025,7 @@ struct test_gru_bidirct_seq1
         auto r = p.add_parameter("r", r_shape);
         p.add_instruction(migraphx::op::gru{hidden_size,
                                             {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                                            migraphx::op::gru::bidirectional,
+                                            migraphx::op::rnn_direction::bidirectional,
                                             clip},
                           seq,
                           w,
@@ -2041,7 +2056,10 @@ struct test_gru_bidirct_default_actv
         auto w = p.add_parameter("w", w_shape);
         auto r = p.add_parameter("r", r_shape);
         p.add_instruction(
-            migraphx::op::gru{hidden_size, {}, migraphx::op::gru::bidirectional, clip}, seq, w, r);
+            migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
+            seq,
+            w,
+            r);
 
         return p;
     }
@@ -2074,9 +2092,10 @@ struct test_gru_bidirct_default_actv1
         auto ih = p.add_parameter("ih", ih_shape);
         auto und = p.add_instruction(migraphx::op::undefined{});
-        p.add_instruction(
-            migraphx::op::gru{
-                hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::bidirectional, clip},
+        p.add_instruction(migraphx::op::gru{hidden_size,
+                                            {migraphx::op::sigmoid{}},
+                                            migraphx::op::rnn_direction::bidirectional,
+                                            clip},
                           seq,
                           w,
                           r,
@@ -2090,6 +2109,7 @@ struct test_gru_bidirct_default_actv1
 int main()
 {
+    verify_program<test_relu_lrn>();
     verify_program<test_pooling_autopad>();
     verify_program<test_abs>();
     verify_program<test_concat>();
...
@@ -491,7 +491,7 @@ TEST_CASE(rnn_test)
         auto out_hs =
             p.add_instruction(migraphx::op::rnn{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::rnn::bidirectional,
+                                                migraphx::op::rnn_direction::bidirectional,
                                                 clip},
                               seq,
                               w,
@@ -523,7 +523,7 @@ TEST_CASE(rnn_test)
         auto out_hs =
             p.add_instruction(migraphx::op::rnn{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::rnn::forward,
+                                                migraphx::op::rnn_direction::forward,
                                                 clip},
                               seq,
                               w,
@@ -555,7 +555,7 @@ TEST_CASE(rnn_test)
         auto out_hs =
             p.add_instruction(migraphx::op::rnn{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::rnn::reverse,
+                                                migraphx::op::rnn_direction::reverse,
                                                 clip},
                               seq,
                               w,
@@ -583,7 +583,7 @@ TEST_CASE(rnn_test)
         auto out_hs =
             p.add_instruction(migraphx::op::rnn{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::rnn::reverse,
+                                                migraphx::op::rnn_direction::reverse,
                                                 clip},
                               seq,
                               w,
@@ -615,7 +615,7 @@ TEST_CASE(rnn_test)
         auto out_hs =
             p.add_instruction(migraphx::op::rnn{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::rnn::reverse,
+                                                migraphx::op::rnn_direction::reverse,
                                                 clip},
                               seq,
                               w,
@@ -658,15 +658,16 @@ TEST_CASE(gru_test)
         auto out_hs =
             p.add_instruction(migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::gru::forward,
-                                                clip},
+                                                migraphx::op::rnn_direction::forward,
+                                                clip,
+                                                1},
                               seq,
                               w,
                               r,
                               bias,
                               seq_len,
                               ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
         EXPECT(p == prog);
@@ -692,7 +693,7 @@ TEST_CASE(gru_test)
         auto out_hs =
             p.add_instruction(migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::gru::reverse,
+                                                migraphx::op::rnn_direction::reverse,
                                                 clip},
                               seq,
                               w,
@@ -700,7 +701,7 @@ TEST_CASE(gru_test)
                               bias,
                               seq_len,
                               ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
         EXPECT(p == prog);
@@ -723,12 +724,13 @@ TEST_CASE(gru_test)
             p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
         auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
 
-        auto out_hs = p.add_instruction(migraphx::op::gru{hs,
+        auto out_hs =
+            p.add_instruction(migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{},
                                                  migraphx::op::sigmoid{},
                                                  migraphx::op::relu{},
                                                  migraphx::op::tanh{}},
-                                                migraphx::op::gru::bidirectional,
+                                                migraphx::op::rnn_direction::bidirectional,
                                                 clip},
                               seq,
                               w,
                               r,
@@ -736,11 +738,21 @@ TEST_CASE(gru_test)
                               bias,
                               seq_len,
                               ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
         EXPECT(p == prog);
     }
+}
+
+TEST_CASE(gru_test_args)
+{
+    std::size_t sl = 5;  // sequence len
+    std::size_t bs = 3;  // batch size
+    std::size_t hs = 20; // hidden size
+    std::size_t is = 10; // input size
+    std::size_t nd = 2;  // num directions
+    float clip     = 0.0f;
 
     // 3 arguments
     {
@@ -757,7 +769,7 @@ TEST_CASE(gru_test)
         auto out_hs =
             p.add_instruction(migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::gru::forward,
+                                                migraphx::op::rnn_direction::forward,
                                                 clip},
                               seq,
                               w,
@@ -765,7 +777,7 @@ TEST_CASE(gru_test)
                               und,
                               und,
                               und);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
         EXPECT(p == prog);
@@ -789,7 +801,7 @@ TEST_CASE(gru_test)
         auto out_hs =
             p.add_instruction(migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::gru::reverse,
+                                                migraphx::op::rnn_direction::reverse,
                                                 clip},
                               seq,
                               w,
@@ -797,7 +809,7 @@ TEST_CASE(gru_test)
                               bias,
                               und,
                               und);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
         EXPECT(p == prog);
@@ -823,7 +835,7 @@ TEST_CASE(gru_test)
         auto out_hs =
             p.add_instruction(migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::gru::bidirectional,
+                                                migraphx::op::rnn_direction::bidirectional,
                                                 clip},
                               seq,
                               w,
@@ -831,12 +843,21 @@ TEST_CASE(gru_test)
                               bias,
                               seq_len,
                               und);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
         EXPECT(p == prog);
     }
+}
+
+TEST_CASE(gru_test_actv_funcs)
+{
+    std::size_t sl = 5;  // sequence len
+    std::size_t bs = 3;  // batch size
+    std::size_t hs = 20; // hidden size
+    std::size_t is = 10; // input size
+    std::size_t nd = 2;  // num directions
+    float clip     = 0.0f;
 
     // bidirection, 0 actv function
     {
         nd = 2;
@@ -854,15 +875,15 @@ TEST_CASE(gru_test)
             p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
         auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
 
-        auto out_hs =
-            p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::bidirectional, clip},
+        auto out_hs = p.add_instruction(
+            migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
             seq,
             w,
             r,
             bias,
             seq_len,
             ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
         EXPECT(p == prog);
@@ -886,14 +907,15 @@ TEST_CASE(gru_test)
         auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
 
         auto out_hs = p.add_instruction(
-            migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
+            migraphx::op::gru{
+                hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
             seq,
             w,
             r,
             bias,
             seq_len,
             ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
         EXPECT(p == prog);
@@ -919,7 +941,7 @@ TEST_CASE(gru_test)
         auto out_hs =
             p.add_instruction(migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
-                                                migraphx::op::gru::bidirectional,
+                                                migraphx::op::rnn_direction::bidirectional,
                                                 clip},
                               seq,
                               w,
@@ -927,7 +949,7 @@ TEST_CASE(gru_test)
                               bias,
                               seq_len,
                               ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
         EXPECT(p == prog);
@@ -953,7 +975,7 @@ TEST_CASE(gru_test)
         auto out_hs = p.add_instruction(
             migraphx::op::gru{hs,
                               {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
-                              migraphx::op::gru::bidirectional,
+                              migraphx::op::rnn_direction::bidirectional,
                               clip},
             seq,
             w,
@@ -961,7 +983,7 @@ TEST_CASE(gru_test)
             bias,
             seq_len,
             ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
         EXPECT(p == prog);
@@ -984,14 +1006,15 @@ TEST_CASE(gru_test)
             p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
         auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
 
-        auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::forward, clip},
+        auto out_hs =
+            p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
                               seq,
                               w,
                               r,
                               bias,
                               seq_len,
                               ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
         EXPECT(p == prog);
@@ -1015,14 +1038,15 @@ TEST_CASE(gru_test)
         auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
 
         auto out_hs = p.add_instruction(
-            migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::gru::reverse, clip},
+            migraphx::op::gru{
+                hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
            seq,
            w,
            r,
            bias,
            seq_len,
            ih);
-        p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
 
         auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
         EXPECT(p == prog);
@@ -1173,4 +1197,18 @@ TEST_CASE(pad_test)
     migraphx::parse_onnx("pad_test.onnx");
 }
 
+TEST_CASE(lrn_test)
+{
+    migraphx::program p;
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 28, 24, 24}});
+    migraphx::op::lrn op;
+    op.size  = 5;
+    op.alpha = 0.0001;
+    op.beta  = 0.75;
+    op.bias  = 1.0;
+    p.add_instruction(op, l0);
+    migraphx::parse_onnx("lrn_test.onnx");
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -282,10 +282,11 @@ TEST_CASE(rnn) ...@@ -282,10 +282,11 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -307,10 +308,11 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
...@@ -332,11 +334,12 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
...@@ -358,9 +361,10 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(migraphx::op::rnn{hidden_size + 1,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
in_shape,
w_shape,
r_shape,
...@@ -382,9 +386,10 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
...@@ -408,7 +413,7 @@ TEST_CASE(rnn)
throws_shape(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
...@@ -435,10 +440,11 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
...@@ -462,10 +468,11 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
...@@ -489,11 +496,12 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
...@@ -517,9 +525,10 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(migraphx::op::gru{hidden_size + 1,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
in_shape,
w_shape,
r_shape,
...@@ -543,9 +552,10 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
...@@ -571,7 +581,7 @@ TEST_CASE(gru)
throws_shape(
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......
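A closing note on the shape-test hunks: every expect_shape call asserts the same full-output layout, {seq_len, num_dirct, batch_size, hidden_size}, where the direction count is 2 only for rnn_direction::bidirectional, and the throws_shape cases pass a hidden_size or direction inconsistent with the weight shapes. A compact, illustrative reduction of that rule using plain integers rather than migraphx::shape (hypothetical helper, not library code):

#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical reduction of the tests' shape rule: the operator's hidden
// size must match the one implied by the recurrence weights, and
// bidirectional runs produce two direction slices.
std::vector<std::size_t> full_output_dims(std::size_t seq_len,
                                          std::size_t batch_size,
                                          std::size_t op_hidden_size,
                                          std::size_t weight_hidden_size,
                                          bool bidirectional)
{
    if(op_hidden_size != weight_hidden_size)
        throw std::runtime_error("hidden size mismatch");
    std::size_t num_directions = bidirectional ? 2 : 1;
    return {seq_len, num_directions, batch_size, op_hidden_size};
}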