Commit 18fc2362 authored by Shucai Xiao
Browse files

Optimize the rewrite of the GRU operator to reduce the number of matrix multiplication calls.

parent b8090620
...@@ -489,58 +489,40 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -489,58 +489,40 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
long hs = static_cast<long>(r_shape.lens()[2]); long hs = static_cast<long>(r_shape.lens()[2]);
migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]}); migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<int> data(s.elements(), 1); std::vector<float> data(s.elements(), 1.0f);
auto l1 = prog.add_literal(migraphx::literal{s, data}); auto l1 = prog.add_literal(migraphx::literal{s, data});
// weight matrix // w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw); auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
// r is sliced into two parts: zr and h
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr); auto rzr = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz); auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr); auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh); auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
// initial states // initial states
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
auto sih_lens = sih->get_shape().lens(); size_t bs = ih->get_shape().lens()[1];
// bias // bias
instruction_ref bwbz{}; instruction_ref bwb{};
instruction_ref brbz{}; instruction_ref brb_zr{};
instruction_ref bwbr{}; instruction_ref brb_h{};
instruction_ref brbr{};
instruction_ref bwbh{};
instruction_ref brbh{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
bwbz = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wbz); auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
bwbr = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wbr); auto rb_h = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
bwbh = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wbh); brb_zr = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brbz = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rbz);
brbr = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rbr);
brbh = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rbh);
} }
for(long i = 0; i < seq_len; i++) for(long i = 0; i < seq_len; i++)
...@@ -549,73 +531,64 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -549,73 +531,64 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref zt{}; instruction_ref xt_w{};
instruction_ref ht{}; instruction_ref ih1_rzr{};
if(bias != prog.end()) if(bias != prog.end())
{ {
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz) xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw, bwb);
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz, bwbz); ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr, brb_zr);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz, brbz); }
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz); else
zt = prog.insert_instruction(ins, actv_func1, xht_z); {
xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr, bwbr); }
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr, brbr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr); auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
auto rt = prog.insert_instruction(ins, actv_func1, xht_r); auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);
instruction_ref xht_h{};
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, bwbh); auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
if(linear_before_reset == 0) auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
instruction_ref hr_h{};
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
if (bias != prog.end())
{ {
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh, brbh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
else else
{ {
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh, brbh);
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
ht = prog.insert_instruction(ins, actv_func2, xht_h);
} }
else else
{ {
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz) // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz); instruction_ref ht1_rh{};
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz); if (bias != prog.end())
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h;
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
if(linear_before_reset == 0)
{ {
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
else else
{ {
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
ht = prog.insert_instruction(ins, actv_func2, xht_h); hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
} }
auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1 // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt); auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht); auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
......
...@@ -2650,7 +2650,7 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last> ...@@ -2650,7 +2650,7 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction( auto output = p.add_instruction(
migraphx::op::gru{hidden_size, migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment