Unverified Commit c180f601 authored by Paul Fultz II, committed by GitHub

Merge pull request #163 from ROCmSoftwarePlatform/rnn_operator

Rnn operator
parents 15d3cf62 d3a09f1a
...
@@ -11,6 +11,7 @@ add_library(migraphx
    eliminate_contiguous.cpp
    eliminate_concat.cpp
    fwd_conv_batchnorm_rewrite.cpp
    rewrite_rnn.cpp
    env.cpp
    generate.cpp
    instruction.cpp
...
...
@@ -12,7 +12,7 @@ void auto_contiguous::apply(program& p) const
    for(auto ins : iterator_for(p))
    {
        shape s = ins->get_shape();
-       if(not s.standard())
        if(not s.standard() and s.elements() != 0)
        {
            auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins);
            p.replace_instruction(ins, c);
...
...
@@ -430,7 +430,6 @@ struct concat
        }
        return result;
    }
-   int output_alias(const std::vector<shape>&) const { return 0; }
};

struct slice
...
@@ -616,11 +615,16 @@ struct reshape
    {
        if(dims[i] == 0)
            rdims[i] = idims[i];

        // since rdims uses the size_t type, assigning -1 would wrap to the
        // maximum value of size_t and make the later computation incorrect
        if(dims[i] == -1)
            rdims[i] = 1;
    }
    if(n_neg_dims > 0)
    {
        size_t missing_dim =
-           -inputs.front().elements() /
            inputs.front().elements() /
            std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
        for(std::size_t i = 0; i < rdims.size(); i++)
        {
@@ -628,11 +632,7 @@ struct reshape
            rdims[i] = missing_dim;
        }
    }
-   if(dims.back() == -1)
-   {
-       rdims.pop_back();
-       std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
-   }
    shape s{inputs.front().type(), rdims};
    if(s.elements() != inputs.front().elements())
        MIGRAPHX_THROW("Wrong number of elements for reshape");
...
@@ -764,8 +764,6 @@ struct gather
        return result;
    }
-   int output_alias(const std::vector<shape>&) const { return 0; }
};

struct dot
...
@@ -1131,6 +1129,76 @@ struct outline
argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
};
struct rnn
{
enum rnn_direction_t
{
forward,
reverse,
bidirectional,
};
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{tanh{}, tanh{}};
rnn_direction_t direction = forward;
float clip = 0.0f;
std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[2].lens();
if(hidden_size != hidden_dims[2])
{
MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
}
std::size_t num_directions = 1;
if(direction == bidirectional)
{
num_directions = 2;
}
if(num_directions != hidden_dims[0])
{
MIGRAPHX_THROW("RNN: num_direction mismatch in attribute and input");
}
std::vector<std::size_t> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size;
return {inputs[0].type(), out_dims};
}
};
struct rnn_last_output
{
std::string name() const { return "rnn_last_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension; the remaining dims are the output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct undefined
{
std::string name() const { return "undefined"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(0);
return {};
}
argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...
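For reference, the dimension rule implemented by op::rnn::compute_shape above can be checked in isolation. The standalone sketch below mirrors that arithmetic; the function name is illustrative and not part of MIGraphX:

#include <cstddef>
#include <vector>

// For an input of {seq_len, batch_size, input_size}, the rnn output dims
// become {seq_len, num_directions, batch_size, hidden_size}.
std::vector<std::size_t> rnn_out_dims(std::vector<std::size_t> in_dims,
                                      std::size_t hidden_size,
                                      bool bidirectional)
{
    std::size_t num_directions = bidirectional ? 2 : 1;
    std::vector<std::size_t> out_dims(in_dims);
    out_dims.insert(out_dims.begin() + 1, num_directions);
    out_dims.back() = hidden_size;
    return out_dims;
}

// e.g. rnn_out_dims({5, 3, 10}, 20, true) yields {5, 2, 3, 20}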
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Rewrite rnn to gemm and add.
*/
struct rewrite_rnn
{
std::string name() const { return "rewrite_rnn"; }
void apply(program& prog) const;
private:
std::vector<instruction_ref> rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const;
std::vector<operation> compute_actv_funcs(instruction_ref ins) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
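As a minimal sketch of how this pass is invoked (the target files below wire it into their pass lists; here we assume only the apply(program&) interface declared above, and the helper name is illustrative):

#include <migraphx/program.hpp>
#include <migraphx/rewrite_rnn.hpp>

// Lower every rnn/rnn_last_output instruction in a program to
// gemm, add, activation, and concat instructions.
void lower_rnn(migraphx::program& p) { migraphx::rewrite_rnn{}.apply(p); }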
...
@@ -32,6 +32,7 @@ struct onnx_parser
bool is_pytorch = false;
std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> map_actv_funcs;
onnx_parser()
{
@@ -85,7 +86,20 @@ struct onnx_parser
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("Pad", &onnx_parser::parse_pad);
// init the activation function map
init_actv_func();
}
void init_actv_func()
{
map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
}
template <class F>
@@ -677,6 +691,96 @@ struct onnx_parser
}
}
std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
migraphx::shape w_shape = args[1]->get_shape();
std::size_t hidden_size = w_shape.lens()[1];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
}
}
// parse the direction attribute (default is forward)
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn::rnn_direction_t dirct = op::rnn::forward;
if(direction == "bidirectional")
{
dirct = op::rnn::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn::reverse;
}
std::vector<std::string> vec_names{"tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); });
}
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
if(map_actv_funcs.count(fn) == 0)
{
MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported");
}
});
// The bidirectional case requires two activation functions:
// one for the forward pass and one for the reverse pass.
// If only one activation function is provided, use it for
// both directions.
if(dirct == op::rnn::bidirectional)
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
return map_actv_funcs[fn];
});
// clip attribute (applying the clip is to be added later)
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
// if fewer than 6 arguments are given, append undefined
// operators so there are always 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), (6 - args.size()), ins);
}
// first output for the concatenation of hidden states
auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
std::move(args));
// second output for the last hidden state
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
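The activation-name handling above (default to tanh when the attribute is absent, duplicate a lone entry for the bidirectional case) can be summarized standalone; the function name below is illustrative:

#include <string>
#include <vector>

std::vector<std::string> resolve_activations(std::vector<std::string> names,
                                             bool bidirectional)
{
    // default activation when the attribute is absent
    if(names.empty())
        names.push_back("tanh");
    // bidirectional needs one function per direction; reuse a lone entry
    if(bidirectional and names.size() == 1)
        names.push_back(names.front());
    return names;
}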
void parse_from(std::istream& is)
{
onnx::ModelProto model;
...
@@ -723,6 +827,12 @@ struct onnx_parser
}
}
void parse_undefined(const std::string& name)
{
auto ins = prog.add_instruction(op::undefined{});
instructions[name] = ins;
}
void parse_node(const std::string& name)
{
if(name.empty())
@@ -737,12 +847,12 @@ struct onnx_parser
{
assert(name != input);
this->parse_node(input);
-   args.push_back(instructions.at(input));
}
-   else
else if(input.empty())
{
-   args.push_back(instructions.at(input));
this->parse_undefined(input);
}
args.push_back(instructions.at(input));
}
std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0)
...
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
...
@@ -134,6 +135,12 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref rep)
    assert(has_instruction(ins));
    assert(has_instruction(rep));
    assert(ins != rep);
    if(ins == std::prev(this->end()))
    {
        return replace_instruction(ins, op::identity{}, rep);
    }
    // TODO: Should it be an error if the output is empty?
    if(ins->outputs().empty())
    {
...
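The new guard above keeps the program's final position stable: replacing the trailing instruction routes through op::identity instead of erasing the tail. A small sketch of the equivalence, using only calls that appear elsewhere in this diff (the helper name is illustrative):

#include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>

// When `tail` is std::prev(p.end()), p.replace_instruction(tail, rep)
// now behaves like this, so passes that track the last instruction
// (e.g. the gpu output allocation below) keep a valid anchor.
migraphx::instruction_ref replace_tail(migraphx::program& p,
                                       migraphx::instruction_ref tail,
                                       migraphx::instruction_ref rep)
{
    return p.replace_instruction(tail, migraphx::op::identity{}, rep);
}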
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const
{
std::unordered_map<instruction_ref, instruction_ref> map_last_output;
for(auto ins : iterator_for(prog))
{
// rewrite rnn operator
if(ins->name() == "rnn")
{
// There can be 3 to 6 inputs. When parsing an onnx file,
// parse_rnn appends undefined operators so there are always 6
// arguments, but a user writing a program directly may pass
// as few as 3.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0);
auto actv_funcs = compute_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dirct = rnn_op.direction;
if(dirct == op::rnn::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// process the initial hidden state; it could be the 6th argument,
// or the 5th one (if the sequence length argument is omitted)
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = rnn_cell(true,
prog,
ins,
args[0],
w_forward,
r_forward,
bias_forward,
ih_forward,
actv_funcs.at(0));
auto ret_reverse = rnn_cell(false,
prog,
ins,
args[0],
w_reverse,
r_reverse,
bias_reverse,
ih_reverse,
actv_funcs.at(1));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic ensures that the last instruction rewritten
// from the rnn operator is a concat instruction.
// ret_forward[0] == prog.end() means the sequence length is 1.
instruction_ref hidden_output{};
if(ret_forward[0] == prog.end())
{
hidden_output = prog.replace_instruction(
ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
hidden_output = prog.replace_instruction(
ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
map_last_output[hidden_output] = last_output;
}
else
{
bool is_forward = (dirct == op::rnn::forward);
// input weight matrix
auto w = args[1];
// hidden state weight matrix
auto r = args[2];
// process bias
instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{
bias = args[3];
}
// process the initial hidden state
instruction_ref ih;
if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret =
rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// The following logic ensures that the last instruction is a
// concat instruction.
// ret[0] == prog.end() means the sequence length is 1.
instruction_ref hidden_output{};
if(ret[0] == prog.end())
{
hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
hidden_output =
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
map_last_output[hidden_output] = last_output;
}
}
// rewrite the rnn_last_output operator that comes right after the
// rnn operator. Intuitively, we could slice the rnn output to get
// the last output, but it already exists among the rewritten rnn
// instructions, so we can just reuse it as the output here
if(ins->name() == "rnn_last_output")
{
auto inputs = ins->inputs();
assert(inputs.size() == 1);
auto arg = inputs[0];
if(map_last_output.count(arg) == 0)
{
MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
}
prog.replace_instruction(ins, map_last_output[arg]);
}
}
}
std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const
{
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
if(bias != prog.end())
{
long hs = r->get_shape().lens()[2];
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
}
instruction_ref hidden_out = prog.end();
instruction_ref last_out{};
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
instruction_ref ht;
if(bias != prog.end())
{
ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
}
else
{
ht = xt_ht;
}
// apply activation function
ht = prog.insert_instruction(ins, actv_func, ht);
sih = ht;
// add back the sequence length and num_directions dimensions
// (axis 0 for sequence length, axis 1 for num_directions)
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
// the concatenation with the final last_out is performed in the
// apply() function to ensure the last instruction inserted for the
// rnn operator is a concat
if(i < seq_len - 1)
{
if(is_forward)
{
hidden_out =
(seq_index == 0)
? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
}
else
{
hidden_out =
(seq_index == seq_len - 1)
? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
}
}
}
return {hidden_out, last_out};
}
std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
{
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
// before rewriting the rnn operator, ensure there are two
// activation functions for the bidirectional case. If fewer are
// given, apply the same defaulting rules as parse_rnn
if(rnn_op.direction == op::rnn::bidirectional)
{
if(rnn_op.actv_funcs.empty())
{
// default is tanh
return {op::tanh{}, op::tanh{}};
}
else if(rnn_op.actv_funcs.size() == 1)
{
return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
}
else
{
return rnn_op.actv_funcs;
}
}
else
{
if(rnn_op.actv_funcs.empty())
{
// default is tanh
return {op::tanh{}};
}
else
{
return rnn_op.actv_funcs;
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
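For readers checking the algebra, the dot/add/activation instructions inserted by rnn_cell compute, per time step, h_t = f(x_t * W^T + h_{t-1} * R^T + Wb + Rb), with f defaulting to tanh. The plain-loop reference below spells out one step for a single batch row; names and layouts are illustrative, not MIGraphX API:

#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> rnn_step(const std::vector<float>& xt,  // [input_size]
                            const std::vector<float>& ht1, // [hidden_size]
                            const std::vector<float>& w,   // [hidden_size x input_size], row-major
                            const std::vector<float>& r,   // [hidden_size x hidden_size], row-major
                            const std::vector<float>& wb,  // [hidden_size]
                            const std::vector<float>& rb)  // [hidden_size]
{
    std::size_t hidden_size = ht1.size();
    std::size_t input_size  = xt.size();
    std::vector<float> ht(hidden_size);
    for(std::size_t i = 0; i < hidden_size; i++)
    {
        // xt * W^T + ht1 * R^T + (Wb + Rb), one output element at a time
        float acc = wb[i] + rb[i];
        for(std::size_t j = 0; j < input_size; j++)
            acc += w[i * input_size + j] * xt[j];
        for(std::size_t j = 0; j < hidden_size; j++)
            acc += r[i * hidden_size + j] * ht1[j];
        ht[i] = std::tanh(acc); // default activation
    }
    return ht;
}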
...
@@ -14,7 +14,8 @@ bool is_nonstandard_reshaper(instruction_ref ins)
{
    // clang-format off
    static const std::unordered_set<std::string> names = {
-       "reshape"
        "reshape",
        "contiguous"
    };
    // clang-format on
    return contains(names, ins->name()) and ins->inputs().front()->name() == "contiguous";
...
...
@@ -2,6 +2,8 @@
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/dead_code_elimination.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -11,7 +13,11 @@ std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&) const
{
-   return {auto_contiguous{}, lowering{}};
    return {auto_contiguous{},
            rewrite_rnn{},
            dead_code_elimination{},
            lowering{},
            dead_code_elimination{}};
}
} // namespace cpu
...
...
@@ -107,6 +107,7 @@ argument miopen_gemm::compute(context& ctx,
        ldc);
    });
    return args[2];
}
...
...
@@ -56,6 +56,7 @@ struct miopen_apply
    program* prog = nullptr;
    context ctx{};
    std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
    instruction_ref last{};
    void check_shape(shape x, instruction_ref i)
    {
@@ -66,6 +67,7 @@ struct miopen_apply
    void init()
    {
        this->last = instruction::get_output_alias(std::prev(prog->end()));
        add_miopen_simple_op<miopen_relu>("relu", make_relu);
        add_miopen_simple_op<miopen_sigmoid>("sigmoid", make_sigmoid);
        add_miopen_simple_op<miopen_abs>("abs", make_abs);
@@ -117,7 +119,7 @@ struct miopen_apply
    instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "")
    {
-       if(ins == --prog->end() and tag.empty())
        if(ins == last and tag.empty())
        {
            return prog->add_parameter("output", s);
        }
...
...
@@ -15,6 +15,7 @@
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
@@ -31,14 +32,16 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
        dead_code_elimination{},
        fwd_conv_batchnorm_rewrite{},
        dead_code_elimination{},
-       common_subexpression_elimination{},
        rewrite_rnn{},
        dead_code_elimination{},
        //common_subexpression_elimination{},
        //dead_code_elimination{},
        simplify_algebra{},
        dead_code_elimination{},
        constant_propagate{},
        dead_code_elimination{},
        auto_contiguous{},
-       simplify_reshapes{},
        //simplify_reshapes{},
        dead_code_elimination{},
        lowering{ctx},
        eliminate_concat{concat_gpu_optimization{}},
...
...
@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
float sigmoid(float x) { return 1 / (1 + expf(-x)); }
...
@@ -1346,6 +1347,521 @@ TEST_CASE(min_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(rnn_forward)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> wf_data{0.4691,
0.3185,
-0.2227,
0.4423,
-0.0609,
-0.2803,
0.1744,
0.3146,
0.4049,
-0.3973,
-0.0890,
-0.1636};
std::vector<float> rf_data{-0.0456,
0.1061,
0.1574,
-0.4928,
-0.4300,
-0.1909,
-0.0225,
-0.2668,
0.1840,
-0.4453,
-0.4896,
0.1302,
-0.0929,
0.3545,
-0.4981,
0.0616};
std::vector<float> biasf_data{
-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
float clip = 0.0f;
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wf_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.37780784,
0.61055139,
0.55168478,
-0.5888475,
-0.37144644,
0.31708236,
0.13104209,
-0.18736027,
0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wf_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rf_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasf_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(rnn_reverse)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> wr_data{-0.0296,
-0.1341,
0.1761,
-0.2325,
-0.0717,
0.1852,
0.2720,
0.1471,
-0.1097,
0.3363,
-0.0587,
-0.2302};
std::vector<float> rr_data{0.2528,
-0.2333,
0.3973,
0.1593,
-0.0388,
0.1702,
0.3829,
-0.0712,
-0.1668,
0.3074,
-0.2854,
0.4049,
-0.3737,
-0.1051,
0.4482,
-0.2841};
std::vector<float> biasr_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
float clip = 0.0f;
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wr_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, wr_data});
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, rr_data});
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, biasr_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::reverse, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(rnn_bidirectional)
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> wf_data{0.4691,
0.3185,
-0.2227,
0.4423,
-0.0609,
-0.2803,
0.1744,
0.3146,
0.4049,
-0.3973,
-0.0890,
-0.1636};
std::vector<float> wr_data{-0.0296,
-0.1341,
0.1761,
-0.2325,
-0.0717,
0.1852,
0.2720,
0.1471,
-0.1097,
0.3363,
-0.0587,
-0.2302};
std::vector<float> rf_data{-0.0456,
0.1061,
0.1574,
-0.4928,
-0.4300,
-0.1909,
-0.0225,
-0.2668,
0.1840,
-0.4453,
-0.4896,
0.1302,
-0.0929,
0.3545,
-0.4981,
0.0616};
std::vector<float> rr_data{0.2528,
-0.2333,
0.3973,
0.1593,
-0.0388,
0.1702,
0.3829,
-0.0712,
-0.1668,
0.3074,
-0.2854,
0.4049,
-0.3737,
-0.1051,
0.4482,
-0.2841};
std::vector<float> biasf_data{
-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990};
std::vector<float> biasr_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552};
std::vector<float> input(seq_len * batch_size * input_size, 0);
input[0] = input[1] = 1.0;
float clip = 0.0f;
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w_data = wf_data;
w_data.insert(w_data.end(), wr_data.begin(), wr_data.end());
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r_data = rf_data;
r_data.insert(r_data.end(), rr_data.begin(), rr_data.end());
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias_data = biasf_data;
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::bidirectional, clip},
seq,
w,
r,
bias,
und,
ih);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236,
0.13104209, -0.18736027, -0.29385301, 0.16796815, 0.51075965, 0.40258689,
-0.13818839, 0.44124447, 0.14365635, 0.14803654, 0.03445704, 0.19167931,
-0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477,
-0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031,
-0.20639211, 0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w_data = wf_data;
w_data.insert(w_data.end(), wr_data.begin(), wr_data.end());
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r_data = rf_data;
r_data.insert(r_data.end(), rr_data.begin(), rr_data.end());
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias_data = biasf_data;
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
{
std::vector<float> ih_data(num_dirct * batch_size * hidden_size, 0);
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w_data = wf_data;
w_data.insert(w_data.end(), wr_data.begin(), wr_data.end());
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r_data = rf_data;
r_data.insert(r_data.end(), rr_data.begin(), rr_data.end());
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias_data = biasf_data;
bias_data.insert(bias_data.end(), biasr_data.begin(), biasr_data.end());
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r,
bias);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477,
-0.29385301,
0.16796815,
0.51075965,
0.40258689,
-0.13818839,
0.44124447,
0.14365635,
0.14803654};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
}
TEST_CASE(pad_test)
{
migraphx::program p;
...
...
@@ -1116,6 +1116,394 @@ struct test_conv_bn_relu_pooling2
}
};
struct test_rnn_forward
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_rnn_forward10
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_rnn_reverse
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_rnn_reverse2
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
und,
ih);
return p;
}
};
struct test_rnn_3args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r);
return p;
}
};
struct test_rnn_4args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 5;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias);
return p;
}
};
struct test_rnn_5args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
clip},
seq,
w,
r,
bias,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_rnn_bidirectional
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 1;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_rnn_bidirectional10
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto bias = p.add_parameter("bias", b_shape);
auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{});
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
struct test_rnn_bi_3args
{
migraphx::program create_program() const
{
std::size_t batch_size = 2;
std::size_t seq_len = 10;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
auto seq = p.add_parameter("seq", in_shape);
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
};
int main()
{
verify_program<test_pooling_autopad>();
@@ -1179,4 +1567,14 @@ int main()
verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
verify_program<test_rnn_forward>();
verify_program<test_rnn_forward10>();
verify_program<test_rnn_reverse>();
verify_program<test_rnn_reverse2>();
verify_program<test_rnn_3args>();
verify_program<test_rnn_4args>();
verify_program<test_rnn_5args>();
verify_program<test_rnn_bidirectional>();
verify_program<test_rnn_bidirectional10>();
verify_program<test_rnn_bi_3args>();
}
...
@@ -466,6 +466,170 @@ TEST_CASE(shape_gather_test)
EXPECT(p == prog);
}
TEST_CASE(rnn_test)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirectional
{
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::bidirectional,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_bi.onnx");
EXPECT(p == prog);
}
// forward
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::forward,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_forward.onnx");
EXPECT(p == prog);
}
// reverse
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_reverse.onnx");
EXPECT(p == prog);
}
// 3 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
und,
und,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_3args.onnx");
EXPECT(p == prog);
}
// 5 arguments
{
nd = 1;
migraphx::program p;
auto seq =
p.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
auto w = p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, hs, is}});
auto r = p.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, hs, hs}});
auto bias =
p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 2 * hs}});
auto seq_len =
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
clip},
seq,
w,
r,
bias,
seq_len,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_rnn_5args.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(flatten_test)
{
migraphx::program p;
...