"include/ck/utility/magic_division.hpp" did not exist on "3bf52e60c5374c9a63256dff5e3442a4046c81dc"
Commit 18fc2362 authored by Shucai Xiao
Browse files

Optimize the rewrite of the GRU operator to reduce the number of matrix multiplication calls.

parent b8090620
......@@ -489,58 +489,40 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
long hs = static_cast<long>(r_shape.lens()[2]);
migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<int> data(s.elements(), 1);
std::vector<float> data(s.elements(), 1.0f);
auto l1 = prog.add_literal(migraphx::literal{s, data});
// weight matrix
// squeeze the w matrix to 2 dimensions and transpose it
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
// slice r into two parts: zr and h
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
auto rzr = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
// initial states
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
auto sih_lens = sih->get_shape().lens();
size_t bs = ih->get_shape().lens()[1];
// bias
instruction_ref bwbz{};
instruction_ref brbz{};
instruction_ref bwbr{};
instruction_ref brbr{};
instruction_ref bwbh{};
instruction_ref brbh{};
instruction_ref bwb{};
instruction_ref brb_zr{};
instruction_ref brb_h{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
bwbz = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wbz);
bwbr = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wbr);
bwbh = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brbz = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rbz);
brbr = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rbr);
brbh = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rbh);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
auto rb_h = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brb_zr = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
}
for(long i = 0; i < seq_len; i++)
......@@ -549,73 +531,64 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref zt{};
instruction_ref ht{};
instruction_ref xt_w{};
instruction_ref ih1_rzr{};
if(bias != prog.end())
{
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz, bwbz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz, brbz);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr, bwbr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr, brbr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h{};
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, bwbh);
if(linear_before_reset == 0)
xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw, bwb);
ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr, brb_zr);
}
else
{
xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
}
auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);
auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
instruction_ref hr_h{};
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
if (bias != prog.end())
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh, brbh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh, brbh);
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
}
ht = prog.insert_instruction(ins, actv_func2, xht_h);
}
else
{
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h;
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
if(linear_before_reset == 0)
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
instruction_ref ht1_rh{};
if (bias != prog.end())
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
}
ht = prog.insert_instruction(ins, actv_func2, xht_h);
hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
}
auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
......
......@@ -2650,7 +2650,7 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::gru{hidden_size,
migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment