"lmdeploy/vscode:/vscode.git/clone" did not exist on "955c019c8c148d3dfe1a861dbc346e96a323ce55"
Commit b7f5e9bd authored by Shucai Xiao
Browse files

optimized lstm_rewrite

parent 4702c17e
...@@ -738,18 +738,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -738,18 +738,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
// process weight of the peephole // process weight of the peephole
instruction_ref pph_forward{}; instruction_ref pph_forward = prog.end();
instruction_ref pph_reverse{}; instruction_ref pph_reverse = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 && args[7]->name() != "undefined")
{ {
pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]); pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]); pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
} }
else
{
pph_forward = prog.add_literal(migraphx::literal{pph_shape, pph_data});
pph_reverse = prog.add_literal(migraphx::literal{pph_shape, pph_data});
}
auto ret_forward = lstm_cell( auto ret_forward = lstm_cell(
true, true,
...@@ -830,15 +825,11 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -830,15 +825,11 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
// process weight of the peephole // process weight of the peephole
instruction_ref pph{}; instruction_ref pph = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 && args[7]->name() != "undefined")
{ {
pph = args[7]; pph = args[7];
} }
else
{
pph = prog.add_literal(migraphx::literal{pph_shape, pph_data});
}
auto ret = lstm_cell(is_forward, auto ret = lstm_cell(is_forward,
prog, prog,
...@@ -991,18 +982,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -991,18 +982,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
} }
// peep hole // peep hole
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); instruction_ref pphi_brcst{};
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); instruction_ref ppho_brcst{};
auto pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi); instruction_ref pphf_brcst{};
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
auto ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); if (pph != prog.end())
auto pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf); {
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst); auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
}
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
{ {
...@@ -1013,9 +1011,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1013,9 +1011,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi); auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri); auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri); auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct); if (pph != prog.end()) {
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
}
if(bias != prog.end()) if(bias != prog.end())
{ {
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst); it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
...@@ -1025,9 +1025,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1025,9 +1025,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf); auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf); auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf); auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct); if (pph != prog.end())
{
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
}
if(bias != prog.end()) if(bias != prog.end())
{ {
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst); ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
...@@ -1053,9 +1056,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1053,9 +1056,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo); auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro); auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro); auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt); if (pph != prog.end())
{
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
}
if(bias != prog.end()) if(bias != prog.end())
{ {
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst); ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment