Commit a2ea4ecd authored by Shucai Xiao

add a pass for rnn operator.

parent 31b2c735
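
This commit adds an op::rnn operator, an ONNX RNN parser entry, and a rewrite_rnn pass that unrolls the operator into slice, squeeze, transpose, dot, add, activation, unsqueeze, and concat instructions; the pass is registered in both the CPU and GPU pass pipelines below. For each direction, the cell the rewrite computes per time step is the standard ONNX RNN recurrence (the clip attribute is parsed but not yet applied):

    H_t = f\left(X_t W^{T} + H_{t-1} R^{T} + W_b + R_b\right)

where f is tanh by default; relu and sigmoid are also accepted through the activations attribute.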
...@@ -11,6 +11,7 @@ add_library(migraphx
eliminate_contiguous.cpp
eliminate_concat.cpp
fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp
env.cpp
generate.cpp
instruction.cpp
......
...@@ -1052,6 +1052,51 @@ struct outline
argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
};
struct rnn
{
enum rnn_direction_t
{
forward,
reverse,
bidirectional,
};
std::size_t hidden_size = 1;
operation actv_func = tanh{};
rnn_direction_t direction = forward;
float clip = 0.0f;
std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[1].lens();
if (hidden_size != hidden_dims[1])
{
MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
}
std::size_t num_directions = 1;
if (direction == rnn_direction_t::bidirectional)
{
num_directions = 2;
}
if (num_directions != hidden_dims[0])
{
MIGRAPHX_THROW("RNN: num_direction does not match the direction attribute");
}
std::vector<std::size_t> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size;
return {inputs[0].type(), out_dims};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
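
As a quick illustration of the shape rule in op::rnn::compute_shape above, here is a minimal sketch (the concrete dimensions are made up for illustration): an input X of shape [seq_len, batch, input_size] and a weight W of shape [num_directions, hidden_size, input_size] produce an output of shape [seq_len, num_directions, batch, hidden_size].

```cpp
// Minimal sketch: exercising op::rnn::compute_shape directly.
// The dimension values {3, 2, 4} and hidden_size = 5 are illustrative only.
#include <migraphx/operators.hpp>
#include <iostream>

int main()
{
    migraphx::shape x{migraphx::shape::float_type, {3, 2, 4}}; // [seq_len, batch, input_size]
    migraphx::shape w{migraphx::shape::float_type, {1, 5, 4}}; // [num_directions, hidden_size, input_size]

    migraphx::op::rnn op;
    op.hidden_size = 5; // must match w.lens()[1]

    auto out = op.compute_shape({x, w}); // -> {3, 1, 2, 5}
    for(auto d : out.lens())
        std::cout << d << " ";
    std::cout << std::endl;
    return 0;
}
```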
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
 * Rewrite the rnn operator into gemm (dot), add, activation, and concat instructions.
 */
struct rewrite_rnn
{
std::string name() const { return "rewrite_rnn"; }
void apply(program& prog) const;
private:
std::vector<instruction_ref> rnn_oper(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref wx,
instruction_ref wh,
instruction_ref ih,
instruction_ref bias,
operation& actv_func) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -84,6 +84,7 @@ struct onnx_parser
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
}
template <class F>
...@@ -633,6 +634,65 @@ struct onnx_parser
}
}
instruction_ref
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
migraphx::shape w_shape = args[1]->get_shape();
std::size_t hidden_size = w_shape.lens()[1];
if(contains(attributes, "hidden_size"))
{
hidden_size = parse_value(attributes.at("hidden_size")).at<int>();
}
else
{
MIGRAPHX_THROW("RNN: hidden size attribute missing");
}
std::string activation_func = {"tanh"};
if(contains(attributes, "activations"))
{
activation_func = attributes.at("activations").strings(0);
}
std::unordered_map<std::string, operation> actv_func_map;
actv_func_map.insert(std::make_pair("tanh", op::tanh{}));
actv_func_map.insert(std::make_pair("relu", op::relu{}));
actv_func_map.insert(std::make_pair("sigmoid", op::sigmoid{}));
if (actv_func_map.count(activation_func) == 0)
{
MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
}
// parse the direction attribute (default is forward)
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn::rnn_direction_t dirct = op::rnn::forward;
if(direction == "bidirectional")
{
dirct = op::rnn::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn::reverse;
}
// clip is parsed here; applying it in the rewrite is to be added later
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
return prog.add_instruction(op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
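
A hedged end-to-end sketch of exercising the new parser entry (the model file name is hypothetical): parse_onnx maps an ONNX RNN node onto the single op::rnn instruction defined above, which the rewrite pass later expands.

```cpp
// Hypothetical sketch: parsing a model that contains an ONNX RNN node.
// "rnn_model.onnx" is a placeholder file name, not part of this commit.
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <iostream>

int main()
{
    migraphx::program p = migraphx::parse_onnx("rnn_model.onnx");
    std::cout << p << std::endl; // the printed graph contains an "rnn" instruction
    return 0;
}
```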
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const
{
for(auto ins : iterator_for(prog))
{
if(ins->name() != "rnn")
{
continue;
}
// There can be 3 to 5 inputs here: although the ONNX RNN operator defines 6
// inputs, the 5th one (sequence_lens) is left undefined and dropped by protobuf,
// so we only need to process up to 5 inputs.
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
shape wgt_shape = args[1]->get_shape();
std::size_t hidden_size = wgt_shape.lens()[1];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape s{type, {batch_size, hidden_size}};
std::vector<char> data(s.bytes(), 0);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dirct = rnn_op.direction;
if(dirct == op::rnn::rnn_direction_t::bidirectional)
{
std::vector<int64_t> perm{1, 0};
// process input weight matrix
// forward
auto xw_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto sxw_forward = prog.insert_instruction(ins, op::squeeze{{0}}, xw_forward);
auto trans_xw_forward = prog.insert_instruction(ins, op::transpose{perm}, sxw_forward);
// reverse
auto xw_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
auto sxw_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, xw_reverse);
auto trans_xw_reverse = prog.insert_instruction(ins, op::transpose{perm}, sxw_reverse);
// process hidden state weight matrix
auto hw_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto shw_forward = prog.insert_instruction(ins, op::squeeze{{0}}, hw_forward);
auto trans_hw_forward = prog.insert_instruction(ins, op::transpose{perm}, shw_forward);
auto hw_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
auto shw_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, hw_reverse);
auto trans_hw_reverse = prog.insert_instruction(ins, op::transpose{perm}, shw_reverse);
// process bias
instruction_ref bias_forward, bias_reverse;
bias_forward = bias_reverse = prog.end();
if(args.size() >= 4)
{
// forward
long h_size = static_cast<long>(hidden_size);
auto b_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
b_forward = prog.insert_instruction(ins, op::squeeze{{0}}, b_forward);
auto wbf = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_forward);
auto rbf = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
auto bf = prog.insert_instruction(ins, op::add{}, wbf, rbf);
bias_forward = prog.insert_instruction(ins, op::broadcast{1, s}, bf);
// reverse
auto b_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
b_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, b_reverse);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_reverse);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
bias_reverse = prog.insert_instruction(ins, op::broadcast{1, s}, br);
}
// process initial hidden state
instruction_ref ih_forward, ih_reverse;
if(args.size() >= 5)
{
// forward
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[4]);
ih_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ih_forward);
// reverse
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[4]);
ih_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ih_reverse);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{s, data});
ih_reverse = prog.add_literal(migraphx::literal{s, data});
}
auto ret_forward = rnn_oper(true,
prog,
ins,
args[0],
trans_xw_forward,
trans_hw_forward,
ih_forward,
bias_forward,
rnn_op.actv_func);
auto ret_reverse = rnn_oper(false,
prog,
ins,
args[0],
trans_xw_reverse,
trans_hw_reverse,
ih_reverse,
bias_reverse,
rnn_op.actv_func);
// add the num_directions dimension to each direction's output
ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
// concat the forward and reverse output
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
else
{
bool is_forward = (dirct == op::rnn::rnn_direction_t::forward);
std::vector<int64_t> perm{1, 0};
// process input weight matrix
auto sxw = prog.insert_instruction(ins, op::squeeze{{0}}, args[1]);
auto trans_xw = prog.insert_instruction(ins, op::transpose{perm}, sxw);
// process hidden state weight matrix
auto shw = prog.insert_instruction(ins, op::squeeze{{0}}, args[2]);
auto trans_hw = prog.insert_instruction(ins, op::transpose{perm}, shw);
// process bias and initial hidden state
instruction_ref bias = prog.end();
if(args.size() >= 4)
{
long h_size = static_cast<long>(hidden_size);
auto bwr = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, bwr);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, bwr);
auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
bias = prog.insert_instruction(ins, op::broadcast{1, s}, b);
}
// process initial hidden state
instruction_ref ih;
if(args.size() >= 5)
{
ih = prog.insert_instruction(ins, op::squeeze{{0}}, args[4]);
}
else
{
ih = prog.add_literal(migraphx::literal{s, data});
}
auto ret = rnn_oper(is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);
// add the num_directions dimension
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
}
}
}
std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref wx,
instruction_ref wh,
instruction_ref ih,
instruction_ref bias,
operation& actv_func) const
{
instruction_ref hidden_out, final_out;
migraphx::shape input_shape = input->get_shape();
std::size_t seq_len = input_shape.lens()[0];
long seq_index = is_forward ? 0 : seq_len - 1;
for(std::size_t i = 0; i < seq_len; i++)
{
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
auto x_w = prog.insert_instruction(ins, op::dot{}, xt, wx);
auto h_r = prog.insert_instruction(ins, op::dot{}, ih, wh);
auto x_h = prog.insert_instruction(ins, op::add{}, x_w, h_r);
instruction_ref before_actv;
if(bias != prog.end())
{
before_actv = prog.insert_instruction(ins, op::add{}, x_h, bias);
}
else
{
before_actv = x_h;
}
// apply activation function
ih = prog.insert_instruction(ins, actv_func, before_actv);
// add back the sequence length dimension
auto output = prog.insert_instruction(ins, op::unsqueeze{{0}}, ih);
final_out = output;
if(is_forward)
{
hidden_out = (seq_index == 0)
? output
: prog.insert_instruction(ins, op::concat{0}, hidden_out, output);
}
else
{
hidden_out = (seq_index == seq_len - 1)
? output
: prog.insert_instruction(ins, op::concat{0}, output, hidden_out);
}
seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
}
std::vector<instruction_ref> out_args;
out_args.push_back(hidden_out);
out_args.push_back(final_out);
return out_args;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
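
To see what apply() and rnn_oper() produce, the following is a minimal, self-contained sketch under stated assumptions: a single forward direction and only the three required inputs X, W, and R as generated literals (no bias or initial hidden state, so the pass substitutes a zero literal for the initial hidden state). Printing the program before and after shows the rnn instruction replaced by slice, squeeze, transpose, dot, add, tanh, unsqueeze, and concat instructions.

```cpp
// Sketch only: the shapes and seeds are illustrative, not from this commit.
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <iostream>

int main()
{
    migraphx::program p;
    migraphx::shape xs{migraphx::shape::float_type, {3, 2, 4}}; // [seq_len, batch, input_size]
    migraphx::shape ws{migraphx::shape::float_type, {1, 5, 4}}; // [num_directions, hidden_size, input_size]
    migraphx::shape rs{migraphx::shape::float_type, {1, 5, 5}}; // [num_directions, hidden_size, hidden_size]

    auto x = p.add_literal(migraphx::generate_literal(xs, 0));
    auto w = p.add_literal(migraphx::generate_literal(ws, 1));
    auto r = p.add_literal(migraphx::generate_literal(rs, 2));

    migraphx::op::rnn op;
    op.hidden_size = 5; // defaults: tanh activation, forward direction, clip = 0
    p.add_instruction(op, x, w, r);

    std::cout << "before:\n" << p << std::endl;
    migraphx::rewrite_rnn{}.apply(p);           // unroll rnn into dot/add/tanh/concat instructions
    migraphx::dead_code_elimination{}.apply(p); // optional cleanup of now-unused instructions
    std::cout << "after:\n" << p << std::endl;
    return 0;
}
```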
...@@ -2,6 +2,7 @@
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...@@ -11,7 +12,7 @@ std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&) const
{
return {auto_contiguous{}, lowering{}};
return {auto_contiguous{}, rewrite_rnn{}, lowering{}};
}
} // namespace cpu
......
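
Because rewrite_rnn now sits in the CPU pass list, compiling any parsed model for the CPU target runs the rewrite automatically. A brief sketch (the model file name is hypothetical):

```cpp
// Hypothetical sketch: the rnn rewrite now happens during CPU compilation.
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/cpu/target.hpp>
#include <iostream>

int main()
{
    auto p = migraphx::parse_onnx("rnn_model.onnx"); // placeholder file name
    p.compile(migraphx::cpu::target{});              // runs auto_contiguous, rewrite_rnn, lowering
    std::cout << p << std::endl;
    return 0;
}
```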
...@@ -15,6 +15,7 @@
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
...@@ -33,6 +34,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{},
common_subexpression_elimination{},
dead_code_elimination{},
rewrite_rnn{},
dead_code_elimination{},
simplify_algebra{},
dead_code_elimination{},
constant_propagate{},
......