Commit 2d7f3523 authored by Shucai Xiao

Rewrite the gru operator to support two outputs.

parent 1fbe8c48
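For reference, these are the per-time-step GRU equations that the rewritten `gru_cell` lowers to dot/mul/add instructions (a restatement of the equations quoted in the code comments below; `f` and `g` are the first and second activation functions, and the odot is element-wise multiplication):

```math
\begin{aligned}
z_t &= f(X_t W_z^T + H_{t-1} R_z^T + Wb_z + Rb_z)\\
r_t &= f(X_t W_r^T + H_{t-1} R_r^T + Wb_r + Rb_r)\\
h_t &= g(X_t W_h^T + (r_t \odot H_{t-1}) R_h^T + Rb_h + Wb_h) \quad (\text{linear\_before\_reset} = 0)\\
h_t &= g(X_t W_h^T + r_t \odot (H_{t-1} R_h^T + Rb_h) + Wb_h) \quad (\text{linear\_before\_reset} \neq 0)\\
H_t &= (1 - z_t) \odot h_t + z_t \odot H_{t-1}
\end{aligned}
```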
@@ -1167,6 +1167,20 @@ struct rnn_last_output
     }
 };
 
+struct gru_last_output
+{
+    std::string name() const { return "gru_last_output"; }
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        auto dims = inputs[0].lens();
+        // remove the first dimension; the remaining dims are the output shape
+        dims.erase(dims.begin());
+        return {inputs[0].type(), dims};
+    }
+};
+
 } // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -13,7 +13,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 struct program;
 
 /**
- * Rewrite rnn to gemm and add.
+ * Rewrite gru to gemm, mul, and add.
  */
 struct rewrite_gru
 {
@@ -21,14 +21,14 @@ struct rewrite_gru
     void apply(program& prog) const;
 
     private:
-    std::vector<instruction_ref> gru_oper(bool is_forward,
+    std::vector<instruction_ref> gru_cell(bool is_forward,
                                           program& prog,
                                           instruction_ref ins,
                                           instruction_ref input,
                                           instruction_ref wx,
                                           instruction_ref wh,
+                                          instruction_ref ih,
                                           instruction_ref bias,
-                                          instruction_ref ih,
                                           int linear_before_reset,
                                           operation& actv_func1,
                                           operation& actv_func2) const;
...
@@ -732,14 +732,14 @@ struct onnx_parser
             std::move(args));
         result.push_back(hidden_states);
 
-        // second out for the last hidden state
+        // second output for the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
         result.push_back(last_output);
 
         return result;
     }
 
-    instruction_ref
+    std::vector<instruction_ref>
     parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         migraphx::shape input_shape = args[0]->get_shape();
@@ -842,9 +842,18 @@ struct onnx_parser
             linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
         }
 
-        return prog.add_instruction(
+        std::vector<instruction_ref> result;
+        // first output for concatenation of hidden states
+        auto hidden_states = prog.add_instruction(
             op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
             std::move(args));
+        result.push_back(hidden_states);
+
+        // second output for last gru output
+        auto last_output = prog.add_instruction(op::gru_last_output{}, hidden_states);
+        result.push_back(last_output);
+
+        return result;
     }
 
     void parse_from(std::istream& is)
...
@@ -10,13 +10,11 @@ inline namespace MIGRAPHX_INLINE_NS {
 void rewrite_gru::apply(program& prog) const
 {
+    instruction_ref last_output = prog.end();
     for(auto ins : iterator_for(prog))
     {
-        if(ins->name() != "gru")
+        if(ins->name() == "gru")
         {
-            continue;
-        }
-
             // could be 3 to 5 inputs (though onnx::gru has 6 inputs,
             // the 5th one is undefined and ignored by protobuf), so
             // we need to process up to 5 inputs
@@ -26,52 +24,37 @@ void rewrite_gru::apply(program& prog) const
             std::size_t hidden_size = args[2]->get_shape().lens()[2];
             std::size_t batchs = seq_shape.lens()[1];
             shape::type_t type = seq_shape.type();
-            migraphx::shape ih_shape{type, {batchs, hidden_size}};
+            migraphx::shape ih_shape{type, {1, batchs, hidden_size}};
             std::vector<char> data(ih_shape.bytes(), 0);
 
             auto gru_op = any_cast<op::gru>(ins->get_operator());
             op::gru::gru_direction_t dicrt = gru_op.direction;
             if(dicrt == op::gru::bidirectional)
             {
-                // forward weight
-                auto uw_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
-                auto w_forward = prog.insert_instruction(ins, op::squeeze{{0}}, uw_forward);
-                auto ur_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
-                auto r_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ur_forward);
-
-                // reverse weight
-                auto uw_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
-                auto w_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, uw_reverse);
-                auto ur_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
-                auto r_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ur_reverse);
-
-                // process bias
+                // w weight matrix
+                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
+                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
+
+                // r weight matrix
+                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
+                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
+
+                // bias
                 instruction_ref bias_forward, bias_reverse;
                 bias_forward = bias_reverse = prog.end();
                 if(args.size() >= 4)
                 {
-                    // forward bias
-                    auto uwb_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
-                    bias_forward = prog.insert_instruction(ins, op::squeeze{{0}}, uwb_forward);
-
-                    // backward bias
-                    auto uwb_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
-                    bias_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, uwb_reverse);
+                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
+                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                 }
 
                 // initial hidden state
                 instruction_ref ih_forward, ih_reverse;
-                if(args.size() >= 5)
+                if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                 {
-                    // forward
-                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[4]);
-                    ih_forward = prog.insert_instruction(ins, op::squeeze{{0}}, ih_forward);
-
-                    // reverse
-                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[4]);
-                    ih_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, ih_reverse);
+                    auto arg_ih = (args.size() == 6) ? args[5] : args[4];
+                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
+                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
                 }
                 else
                 {
@@ -79,32 +62,31 @@ void rewrite_gru::apply(program& prog) const
                     ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                 }
 
-                auto ret_forward = gru_oper(true,
+                auto ret_forward = gru_cell(true,
                                             prog,
                                             ins,
                                             args[0],
                                             w_forward,
                                             r_forward,
+                                            ih_forward,
                                             bias_forward,
-                                            ih_forward,
                                             gru_op.linear_before_reset,
                                             gru_op.actv_funcs.at(0),
                                             gru_op.actv_funcs.at(1));
 
-                auto ret_reverse = gru_oper(false,
+                auto ret_reverse = gru_cell(false,
                                             prog,
                                             ins,
                                             args[0],
                                             w_reverse,
                                             r_reverse,
+                                            ih_reverse,
                                             bias_reverse,
-                                            ih_reverse,
                                             gru_op.linear_before_reset,
                                             gru_op.actv_funcs.at(2),
                                             gru_op.actv_funcs.at(3));
 
-                // auto final_output =
-                //     prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
+                last_output =
+                    prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
 
                 // add the dimension of num_direction
                 ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
@@ -117,61 +99,76 @@ void rewrite_gru::apply(program& prog) const
             {
                 bool is_forward = (dicrt == op::gru::forward) ? true : false;
                 // weight matrix
-                auto w = prog.insert_instruction(ins, op::squeeze{{0}}, args[1]);
-                auto r = prog.insert_instruction(ins, op::squeeze{{0}}, args[2]);
+                auto w = args[1];
+                auto r = args[2];
 
                 // bias
                 instruction_ref bias = prog.end();
                 if(args.size() >= 4)
                 {
-                    bias = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
+                    bias = args[3];
                 }
 
                 // initial hidden state
                 instruction_ref ih;
-                if(args.size() >= 5)
+                if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                 {
-                    ih = prog.insert_instruction(ins, op::squeeze{{0}}, args[4]);
+                    ih = (args.size() == 6) ? args[5] : args[4];
                 }
                 else
                 {
                     ih = prog.add_literal(migraphx::literal{ih_shape, data});
                 }
 
-                auto ret = gru_oper(is_forward,
+                auto ret = gru_cell(is_forward,
                                     prog,
                                     ins,
                                     args[0],
                                     w,
                                     r,
+                                    ih,
                                     bias,
-                                    ih,
                                     gru_op.linear_before_reset,
                                     gru_op.actv_funcs.at(0),
                                     gru_op.actv_funcs.at(1));
+                last_output = ret[1];
 
                 // add the dimension of num_direction
                 prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
             }
         }
 
+        // rewrite the gru_last_output operator that comes right after the gru
+        // operator. Intuitively, we could take a slice of its input to get
+        // the last output, but that value is already computed when rewriting
+        // the gru operator, so we can just reuse it as the output here
+        if(ins->name() == "gru_last_output")
+        {
+            if(last_output != prog.end())
+            {
+                prog.replace_instruction(ins, op::identity{}, last_output);
+                last_output = prog.end();
+            }
+        }
     }
 }
 
-std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
+std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
                                                    program& prog,
                                                    instruction_ref ins,
                                                    instruction_ref input,
                                                    instruction_ref w,
                                                    instruction_ref r,
+                                                   instruction_ref ih,
                                                    instruction_ref bias,
-                                                   instruction_ref ih,
                                                    int linear_before_reset,
                                                    operation& actv_func1,
                                                    operation& actv_func2) const
 {
-    instruction_ref hidden_out, final_out;
+    instruction_ref hidden_out, last_out;
     long seq_len = static_cast<long>(input->get_shape().lens()[0]);
-    long hs = static_cast<long>(r->get_shape().lens()[1]);
-    long seq_index = is_forward ? 0 : seq_len - 1;
+    long hs = static_cast<long>(r->get_shape().lens()[2]);
 
     migraphx::shape s(input->get_shape().type(),
                       {input->get_shape().lens()[1], static_cast<std::size_t>(hs)});
@@ -180,122 +177,136 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
 
     // weight matrix
     std::vector<int64_t> perm{1, 0};
-    auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, w);
-    auto twz = prog.insert_instruction(ins, op::transpose{perm}, wz);
-    auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, w);
-    auto twr = prog.insert_instruction(ins, op::transpose{perm}, wr);
-    auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, w);
-    auto twh = prog.insert_instruction(ins, op::transpose{perm}, wh);
-    auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, r);
-    auto trz = prog.insert_instruction(ins, op::transpose{perm}, rz);
-    auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, r);
-    auto trr = prog.insert_instruction(ins, op::transpose{perm}, rr);
-    auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, r);
-    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
+    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
+    auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
+    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
+    auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
+    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
+    auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
+    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
+
+    auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
+    auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
+    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
+    auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
+    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
+    auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
+    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
+
+    // initial hidden state
+    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
 
     // bias
-    instruction_ref br_bz, br_br, br_wbh, br_rbh, br_bh;
+    instruction_ref brcst_bz, brcst_br, brcst_wbh, brcst_rbh, brcst_bh;
     if(bias != prog.end())
     {
-        auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, bias);
-        auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, bias);
-        auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, bias);
-        br_wbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, wbh);
+        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
+        auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
+        auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
+        auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
+        brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);
 
-        auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, bias);
-        auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, bias);
-        auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, bias);
-        br_rbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, rbh);
+        auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
+        auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
+        auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
+        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);
 
         auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
-        br_bz = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bz);
+        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);
         auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
-        br_br = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, br);
+        brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);
         auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
-        br_bh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bh);
+        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
     }
 
+    long seq_index = is_forward ? 0 : seq_len - 1;
     for(long i = 0; i < seq_len; i++)
     {
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
 
         // equation zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
-        auto xwzt = prog.insert_instruction(ins, op::dot{}, xt, twz);
-        auto hrzt = prog.insert_instruction(ins, op::dot{}, ih, trz);
-        auto xwhr_zt = prog.insert_instruction(ins, op::add{}, xwzt, hrzt);
+        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
+        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
+        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
         if(bias != prog.end())
         {
-            xwhr_zt = prog.insert_instruction(ins, op::add{}, xwhr_zt, br_bz);
+            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
         }
-        auto zt = prog.insert_instruction(ins, actv_func1, xwhr_zt);
+        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
 
         // equation rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
-        auto xwrt = prog.insert_instruction(ins, op::dot{}, xt, twr);
-        auto hrrt = prog.insert_instruction(ins, op::dot{}, ih, trr);
-        auto xwhr_rt = prog.insert_instruction(ins, op::add{}, xwrt, hrrt);
+        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
+        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
+        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
         if(bias != prog.end())
         {
-            xwhr_rt = prog.insert_instruction(ins, op::add{}, xwhr_rt, br_br);
+            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
         }
-        auto rt = prog.insert_instruction(ins, actv_func1, xwhr_rt);
+        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
 
-        instruction_ref xwhh_rt;
+        instruction_ref xht_h;
         if(linear_before_reset == 0)
         {
             // equation ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
-            auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
-            auto rt_ht = prog.insert_instruction(ins, op::mul{}, rt, ih);
-            auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht, trh);
-            xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
+            auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
+            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
+            auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
+            xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
             if(bias != prog.end())
             {
-                xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_bh);
+                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
             }
         }
         else
         {
             // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
-            auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
-            auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, trh);
+            auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
+            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
             if(bias != prog.end())
             {
-                ih_rht = prog.insert_instruction(ins, op::add{}, ih_rht, br_rbh);
+                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
             }
-            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ih_rht);
-            xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
+            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
+            xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
             if(bias != prog.end())
             {
-                xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_wbh);
+                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
             }
         }
-        auto ht = prog.insert_instruction(ins, actv_func2, xwhh_rt);
+        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
 
         // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
-        auto z1t = prog.insert_instruction(ins, op::sub{}, l1, zt);
-        auto z1tht = prog.insert_instruction(ins, op::mul{}, z1t, ht);
-        auto ztht1 = prog.insert_instruction(ins, op::mul{}, zt, ih);
-        ih = prog.insert_instruction(ins, op::add{}, z1tht, ztht1);
-        final_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, ih);
+        auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
+        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
+        auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih);
+        sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
+        last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, sih);
 
         if(is_forward)
         {
             hidden_out = (seq_index == 0)
-                             ? final_out
-                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, final_out);
+                             ? last_out
+                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
         }
         else
         {
             hidden_out = (seq_index == seq_len - 1)
-                             ? final_out
-                             : prog.insert_instruction(ins, op::concat{0}, final_out, hidden_out);
+                             ? last_out
+                             : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
         }
         seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
     }
 
     std::vector<instruction_ref> out_args;
     out_args.push_back(hidden_out);
-    out_args.push_back(final_out);
+    out_args.push_back(last_out);
     return out_args;
 }
...
@@ -26,7 +26,7 @@ void rewrite_rnn::apply(program& prog) const
         std::size_t hidden_size = args[1]->get_shape().lens()[1];
         std::size_t batch_size = seq_shape.lens()[1];
         shape::type_t type = seq_shape.type();
-        migraphx::shape ih_shape{type, {batch_size, hidden_size}};
+        migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
         std::vector<char> data(ih_shape.bytes(), 0);
 
         auto rnn_op = any_cast<op::rnn>(ins->get_operator());
@@ -133,7 +133,7 @@ void rewrite_rnn::apply(program& prog) const
         }
 
         // rewrite the rnn_last_output operator that comes right after the rnn
-        // operator. Intuitively, we can do a slice on the input to get
+        // operator. Intuitively, we can do a slice on its input to get
         // the last output, but that value is already computed when rewriting
        // the rnn operator, so we can just use it as the output here
         if(ins->name() == "rnn_last_output")
...
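As a quick sanity check on the new operator's shape rule, here is a minimal standalone sketch (not MIGraphX code; `last_output_dims` is a hypothetical helper mirroring `gru_last_output::compute_shape` above). The gru's first output packs all hidden states with dims `{seq_len, num_directions, batch_size, hidden_size}`; erasing the leading `seq_len` dimension gives the shape of the second output, the last hidden state:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring gru_last_output::compute_shape:
// drop the leading seq_len dimension of the gru's first output.
std::vector<std::size_t> last_output_dims(std::vector<std::size_t> dims)
{
    assert(!dims.empty());
    dims.erase(dims.begin());
    return dims;
}

int main()
{
    // seq_len = 7, num_directions = 2 (bidirectional), batch_size = 3, hidden_size = 4
    std::vector<std::size_t> y_dims{7, 2, 3, 4};
    auto y_h_dims = last_output_dims(y_dims);
    assert((y_h_dims == std::vector<std::size_t>{2, 3, 4}));
    return 0;
}
```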