Commit b23aec08 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 2d7f3523
...@@ -50,11 +50,12 @@ void rewrite_gru::apply(program& prog) const ...@@ -50,11 +50,12 @@ void rewrite_gru::apply(program& prog) const
// intial hidden state // intial hidden state
instruction_ref ih_forward, ih_reverse; instruction_ref ih_forward, ih_reverse;
if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3)) if(args.size() == 6 ||
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
{ {
auto arg_ih = (args.size() == 6) ? args[5] : args[4]; auto arg_ih = (args.size() == 6) ? args[5] : args[4];
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih); ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih); ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
} }
else else
{ {
...@@ -86,7 +87,8 @@ void rewrite_gru::apply(program& prog) const ...@@ -86,7 +87,8 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(2), gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(3)); gru_op.actv_funcs.at(3));
last_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]); last_output =
prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
// add the dimension of num_direction // add the dimension of num_direction
ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]); ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
...@@ -111,9 +113,10 @@ void rewrite_gru::apply(program& prog) const ...@@ -111,9 +113,10 @@ void rewrite_gru::apply(program& prog) const
// intial hidden state // intial hidden state
instruction_ref ih; instruction_ref ih;
if(args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3)) if(args.size() == 6 ||
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
{ {
ih = args.size() == 6 ? args[5]: args[4]; ih = args.size() == 6 ? args[5] : args[4];
} }
else else
{ {
...@@ -143,9 +146,9 @@ void rewrite_gru::apply(program& prog) const ...@@ -143,9 +146,9 @@ void rewrite_gru::apply(program& prog) const
// operator. Intuitively, we can do a slice on its input to get // operator. Intuitively, we can do a slice on its input to get
// the last output, but it is already existed in the rnn operator, // the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here // so we can just use it as the output here
if (ins->name() == "gru_last_output") if(ins->name() == "gru_last_output")
{ {
if (last_output != prog.end()) if(last_output != prog.end())
{ {
prog.replace_instruction(ins, op::identity{}, last_output); prog.replace_instruction(ins, op::identity{}, last_output);
last_output = prog.end(); last_output = prog.end();
...@@ -167,8 +170,8 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -167,8 +170,8 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
operation& actv_func2) const operation& actv_func2) const
{ {
instruction_ref hidden_out, last_out; instruction_ref hidden_out, last_out;
long seq_len = static_cast<long>(input->get_shape().lens()[0]); long seq_len = static_cast<long>(input->get_shape().lens()[0]);
long hs = static_cast<long>(r->get_shape().lens()[2]); long hs = static_cast<long>(r->get_shape().lens()[2]);
migraphx::shape s(input->get_shape().type(), migraphx::shape s(input->get_shape().type(),
{input->get_shape().lens()[1], static_cast<std::size_t>(hs)}); {input->get_shape().lens()[1], static_cast<std::size_t>(hs)});
...@@ -177,24 +180,24 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -177,24 +180,24 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
// weight matrix // weight matrix
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw); auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz); auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw); auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr); auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw); auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh); auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr); auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz); auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr); auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr); auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr); auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh); auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
// initial states // initial states
...@@ -205,24 +208,24 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -205,24 +208,24 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias); auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh); brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias); auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias); auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh); brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz); auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz); brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr); auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br); brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh); auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh); brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
} }
long seq_index = is_forward ? 0 : seq_len - 1; long seq_index = is_forward ? 0 : seq_len - 1;
...@@ -232,8 +235,8 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -232,8 +235,8 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz) // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz); auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz); auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz); auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
if(bias != prog.end()) if(bias != prog.end())
{ {
...@@ -242,8 +245,8 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -242,8 +245,8 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
auto zt = prog.insert_instruction(ins, actv_func1, xht_z); auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr); auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr); auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr); auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
if(bias != prog.end()) if(bias != prog.end())
{ {
...@@ -257,8 +260,8 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -257,8 +260,8 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh); auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih); auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh); auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh); xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end()) if(bias != prog.end())
{ {
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh); xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
...@@ -267,14 +270,14 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -267,14 +270,14 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
else else
{ {
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh); auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh); auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
if(bias != prog.end()) if(bias != prog.end())
{ {
ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh); ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
} }
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh); auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh); xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end()) if(bias != prog.end())
{ {
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh); xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
...@@ -283,11 +286,11 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -283,11 +286,11 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
auto ht = prog.insert_instruction(ins, actv_func2, xht_h); auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1 // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt); auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht); auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih); auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih);
sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1); sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, sih); last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, sih);
if(is_forward) if(is_forward)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment