"...resnet50_tensorflow.git" did not exist on "219f6f06e488f242f5c20357922cfe9e1fb1a6ae"
Commit 6d0742b6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

save implementation of gru operator

parent 67491293
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GRU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GRU_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Rewrite gru to gemm and add.
*/
struct rewrite_gru
{
std::string name() const { return "rewrite_gru"; }
void apply(program& prog) const;
private:
std::vector<instruction_ref> gru_oper(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref ih,
instruction_ref bias,
int linear_before_reset,
operation& actv_func1,
operation& actv_func2) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
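For reference, these are the ONNX GRU gate equations that this pass lowers to gemm and elementwise instructions; the ASCII comments in rewrite_gru.cpp below quote the same formulas. f defaults to sigmoid, g to tanh, and \odot is the elementwise product:

\begin{aligned}
z_t &= f(X_t W_z^T + H_{t-1} R_z^T + Wb_z + Rb_z) \\
r_t &= f(X_t W_r^T + H_{t-1} R_r^T + Wb_r + Rb_r) \\
\tilde h_t &= g(X_t W_h^T + (r_t \odot H_{t-1}) R_h^T + Rb_h + Wb_h) \quad \text{(linear\_before\_reset} = 0) \\
\tilde h_t &= g(X_t W_h^T + r_t \odot (H_{t-1} R_h^T + Rb_h) + Wb_h) \quad \text{(linear\_before\_reset} \ne 0) \\
H_t &= (1 - z_t) \odot \tilde h_t + z_t \odot H_{t-1}
\end{aligned}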
@@ -86,6 +86,7 @@ struct onnx_parser
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
// init the activation function map
init_actv_func();
@@ -651,8 +652,7 @@ struct onnx_parser
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
if(contains(attributes, "hidden_size"))
{
@@ -702,6 +702,77 @@ struct onnx_parser
std::move(args));
}
instruction_ref
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
// hidden_size is the last dim of R: [num_directions, 3*hidden_size, hidden_size]
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
}
}
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::gru::gru_direction_t dirct = op::gru::forward;
if(direction == "bidirectional")
{
dirct = op::gru::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::gru::reverse;
}
// default activation functions: one {sigmoid, tanh} pair per direction
std::vector<std::string> act_funcs = {"sigmoid", "tanh"};
if(dirct == op::gru::bidirectional)
{
act_funcs.insert(act_funcs.end(), {"sigmoid", "tanh"});
}
if(contains(attributes, "activations"))
{
act_funcs.clear();
for(const auto& name : attributes.at("activations").strings())
{
act_funcs.push_back(name);
}
}
std::size_t num_actv_funcs = (dirct == op::gru::bidirectional) ? 4 : 2;
if(act_funcs.size() != num_actv_funcs)
{
MIGRAPHX_THROW("GRU: wrong number of activation functions");
}
for(const auto& name : act_funcs)
{
if(actv_funcs.count(name) == 0)
{
MIGRAPHX_THROW("GRU: activation function " + name + " not supported");
}
}
// clip is parsed here, but clipping is not yet applied by rewrite_gru
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int linear_before_reset = 0;
if (contains(attributes, "linear_before_reset"))
{
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
}
std::vector<operation> vec_actv_funcs;
for(const auto& name : act_funcs)
{
vec_actv_funcs.push_back(actv_funcs.at(name));
}
return prog.add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
std::move(args));
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
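With the GRU entry registered in the parser, lowering a parsed program is a single apply call. A minimal sketch follows; the lower_gru wrapper is hypothetical, but rewrite_gru::apply(program&) is the interface declared in the header above:

#include <migraphx/program.hpp>
#include <migraphx/rewrite_gru.hpp>

// Replace every "gru" instruction in a parsed program with
// dot/add/mul/activation instructions.
void lower_gru(migraphx::program& p)
{
    migraphx::rewrite_gru{}.apply(p);
}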
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_gru::apply(program& prog) const
{
for(auto ins : iterator_for(prog))
{
if(ins->name() != "gru")
{
continue;
}
// onnx::gru can have 3 to 6 inputs, but input 4 (sequence_lens)
// is undefined here and ignored by protobuf, so initial_h shifts
// down and we process up to 5 inputs
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {batch_size, hidden_size}};
// zero buffer used as the default (all-zero) initial hidden state
std::vector<char> data(ih_shape.bytes(), 0);
auto gru_op = any_cast<op::gru>(ins->get_operator());
op::gru::gru_direction_t dirct = gru_op.direction;
if(dirct == op::gru::bidirectional)
{
// forward weight
auto uw_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_forward = prog.insert_instruction(ins, op::squeeze{{0}}, uw_forward);
auto ur_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ur_forward);
// reverse weight
auto uw_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, uw_reverse);
auto ur_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ur_reverse);
// process bias
instruction_ref bias_forward, bias_reverse;
bias_forward = bias_reverse = prog.end();
if(args.size() >= 4)
{
// forward bias
auto uwb_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_forward = prog.insert_instruction(ins, op::squeeze{{0}}, uwb_forward);
// backward bias
auto uwb_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, uwb_reverse);
}
// initial hidden state
instruction_ref ih_forward, ih_reverse;
if(args.size() >= 5)
{
// forward
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[4]);
ih_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ih_forward);
// reverse
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[4]);
ih_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ih_reverse);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward = gru_oper(true,
prog,
ins,
args[0],
w_forward,
r_forward,
ih_forward,
bias_forward,
gru_op.linear_before_reset,
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1));
auto ret_reverse = gru_oper(false,
prog,
ins,
args[0],
w_reverse,
r_reverse,
ih_reverse,
bias_reverse,
gru_op.linear_before_reset,
gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(3));
// TODO: concat ret_forward[1] and ret_reverse[1] to produce the
// final hidden state output (Y_h)
// add the dimension of num_direction
ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
// concat the forward and reverse output
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
else
{
bool is_forward = (dirct == op::gru::forward);
// weight matrix
auto w = prog.insert_instruction(ins, op::squeeze{{0}}, args[1]);
auto r = prog.insert_instruction(ins, op::squeeze{{0}}, args[2]);
// bias
instruction_ref bias = prog.end();
if(args.size() >= 4)
{
bias = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
}
// initial hidden state
instruction_ref ih;
if(args.size() >= 5)
{
ih = prog.insert_instruction(ins, op::squeeze{{0}}, args[4]);
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
}
auto ret = gru_oper(is_forward,
prog,
ins,
args[0],
w,
r,
ih,
bias,
gru_op.linear_before_reset,
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1));
// add the dimension of num_direction
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
}
}
}
std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref ih,
instruction_ref bias,
int linear_before_reset,
operation& actv_func1,
operation& actv_func2) const
{
instruction_ref hidden_out, final_out;
long seq_len = static_cast<long>(input->get_shape().lens()[0]);
long hs = static_cast<long>(r->get_shape().lens()[1]);
long seq_index = is_forward ? 0 : seq_len - 1;
// a tensor of ones matching the hidden-state shape, used to compute (1 - zt)
migraphx::shape s{input->get_shape().type(), ih->get_shape().lens()};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = prog.add_literal(migraphx::literal{s, ones});
// weight matrix
std::vector<int64_t> perm{1, 0};
auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, w);
auto twz = prog.insert_instruction(ins, op::transpose{perm}, wz);
auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, w);
auto twr = prog.insert_instruction(ins, op::transpose{perm}, wr);
auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, w);
auto twh = prog.insert_instruction(ins, op::transpose{perm}, wh);
auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, r);
auto trz = prog.insert_instruction(ins, op::transpose{perm}, rz);
auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, r);
auto trr = prog.insert_instruction(ins, op::transpose{perm}, rr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, r);
auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
// bias
instruction_ref br_bz, br_br, br_wbh, br_rbh, br_bh;
if (bias != prog.end())
{
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, bias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, bias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, bias);
br_wbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, bias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, bias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, bias);
br_rbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
br_bz = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
br_br = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, br);
br_bh = prog.insert_instruction(ins, op::add{}, br_wbh, br_rbh);
}
for(long i = 0; i < seq_len; i++)
{
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
auto xwzt = prog.insert_instruction(ins, op::dot{}, xt, twz);
auto hrzt = prog.insert_instruction(ins, op::dot{}, ih, trz);
auto xwhr_zt = prog.insert_instruction(ins, op::add{}, xwzt, hrzt);
if (bias != prog.end())
{
xwhr_zt = prog.insert_instruction(ins, op::add{}, xwhr_zt, br_bz);
}
auto zt = prog.insert_instruction(ins, actv_func1, xwhr_zt);
// equation rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xwrt = prog.insert_instruction(ins, op::dot{}, xt, twr);
auto hrrt = prog.insert_instruction(ins, op::dot{}, ih, trr);
auto xwhr_rt = prog.insert_instruction(ins, op::add{}, xwrt, hrrt);
if (bias != prog.end())
{
xwhr_rt = prog.insert_instruction(ins, op::add{}, xwhr_rt, br_br);
}
auto rt = prog.insert_instruction(ins, actv_func1, xwhr_rt);
instruction_ref xwhh_rt;
if (linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
auto rt_ht = prog.insert_instruction(ins, op::mul{}, rt, ih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht, trh);
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
if (bias != prog.end())
{
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_bh);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, trh);
if (bias != prog.end())
{
ih_rht = prog.insert_instruction(ins, op::add{}, ih_rht, br_rbh);
}
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ih_rht);
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
if (bias != prog.end())
{
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_wbh);
}
}
auto ht = prog.insert_instruction(ins, actv_func2, xwhh_rt);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, ih);
ih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
final_out = ih;
// collect each step's output as {1, batch_size, hidden_size} so the
// per-step outputs can be concatenated along the sequence axis
auto output = prog.insert_instruction(ins, op::unsqueeze{{0}}, ih);
if(is_forward)
{
hidden_out = (seq_index == 0)
? output
: prog.insert_instruction(ins, op::concat{0}, hidden_out, output);
}
else
{
hidden_out = (seq_index == seq_len - 1)
? output
: prog.insert_instruction(ins, op::concat{0}, output, hidden_out);
}
seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
}
return {hidden_out, final_out};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
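To sanity-check the lowering, here is a self-contained scalar reference for a single GRU step (batch size 1, default f = sigmoid and g = tanh), using the same packed layout that gru_oper slices: W is [3H x I], R is [3H x H], both row-major with gates in z, r, h order, and the bias is [6H] as Wb then Rb. The gru_step function and its helpers are illustrative, not part of MIGraphX:

#include <cmath>
#include <cstddef>
#include <vector>

// One GRU step on plain vectors. x has size I, h (= H_{t-1}) has size H.
std::vector<float> gru_step(const std::vector<float>& x,
                            const std::vector<float>& h,
                            const std::vector<float>& W,
                            const std::vector<float>& R,
                            const std::vector<float>& b,
                            std::size_t I,
                            std::size_t H,
                            bool linear_before_reset)
{
    auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };
    // x . W_g[j] + Wb_g[j] for gate g (0 = z, 1 = r, 2 = h), row j
    auto x_dot = [&](std::size_t g, std::size_t j) {
        float acc = b[g * H + j];
        for(std::size_t k = 0; k < I; ++k)
            acc += x[k] * W[(g * H + j) * I + k];
        return acc;
    };
    // hv . R_g[j] + Rb_g[j]
    auto h_dot = [&](std::size_t g, std::size_t j, const std::vector<float>& hv) {
        float acc = b[3 * H + g * H + j];
        for(std::size_t k = 0; k < H; ++k)
            acc += hv[k] * R[(g * H + j) * H + k];
        return acc;
    };

    std::vector<float> zt(H), rt(H), hh(H), ht(H);
    for(std::size_t j = 0; j < H; ++j)
    {
        zt[j] = sigmoid(x_dot(0, j) + h_dot(0, j, h)); // update gate
        rt[j] = sigmoid(x_dot(1, j) + h_dot(1, j, h)); // reset gate
    }
    if(!linear_before_reset)
    {
        // ht~ = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
        std::vector<float> rh(H);
        for(std::size_t k = 0; k < H; ++k)
            rh[k] = rt[k] * h[k];
        for(std::size_t j = 0; j < H; ++j)
            hh[j] = std::tanh(x_dot(2, j) + h_dot(2, j, rh));
    }
    else
    {
        // ht~ = g(Xt*(Wh^T) + rt (.) (Ht-1*(Rh^T) + Rbh) + Wbh)
        for(std::size_t j = 0; j < H; ++j)
            hh[j] = std::tanh(x_dot(2, j) + rt[j] * h_dot(2, j, h));
    }
    // Ht = (1 - zt) (.) ht~ + zt (.) Ht-1
    for(std::size_t j = 0; j < H; ++j)
        ht[j] = (1.0f - zt[j]) * hh[j] + zt[j] * h[j];
    return ht;
}

Running this step over a sequence (walking forward or backward the way gru_oper advances seq_index) gives a reference to compare against the rewritten program's output.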