Commit 9b512b56 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code backup

parent 11a2dd87
......@@ -924,8 +924,69 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref last_cell_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r->get_shape().lens()[2]);
long hs = static_cast<long>(r_shape.lens()[2]);
std::vector<int64_t> perm{1, 0};
// w matrix
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
auto wo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2*hs}}, sw);
auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
auto wf = prog.insert_instruction(ins, op::slice{{0}, {2*hs}, {3*hs}}, sw);
auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
auto wc = prog.insert_instruction(ins, op::slice{{0}, {3*hs}, {4*hs}}, sw);
auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
// r matrix
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto ri = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
auto ro = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2*hs}}, sr);
auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
auto rf = prog.insert_instruction(ins, op::slice{{0}, {2*hs}, {3*hs}}, sr);
auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
auto rc = prog.insert_instruction(ins, op::slice{{0}, {3*hs}, {4*hs}}, sr);
auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// initial cell state
auto sic = prog.insert_instruction(ins, op::sequeeze{{0}}, ic);
auto ic_shape = sic->get_shape();
// peep hole
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_bcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2*hs}}, spph);
ppho_bcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
pphf = prog.insert_instruction(ins, op::slice{{0}, {2*hs}, {3*hs}}, spph);
pphf_bcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
for (long i = 0; i < seq_len; ++i)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_bcst, sic);
}
}
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment