Commit 5bd83328 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix bugs in testing lstm opeerators

parent 9c51e846
...@@ -984,22 +984,32 @@ struct onnx_parser ...@@ -984,22 +984,32 @@ struct onnx_parser
// provided. This may need change later // provided. This may need change later
switch(vec_names.size()) switch(vec_names.size())
{ {
case 1: vec_names.insert(vec_names.end(), 5, vec_names.back()); break; case 1:
vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0),
vec_names.at(0), vec_names.at(0), vec_names.at(0)};
break;
case 2: case 2:
// repeat the 2nd actv func once, then repeat all three another time // repeat the 2nd actv func once, then repeat all three another time
vec_names.push_back(vec_names.back()); vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1),
vec_names.insert(vec_names.end(), vec_names.begin(), vec_names.end()); vec_names.at(0), vec_names.at(1), vec_names.at(1)};
break; break;
case 3: case 3:
// repeat all three actv funcs once // repeat all three actv funcs once
vec_names.insert(vec_names.end(), vec_names.begin(), vec_names.end()); vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(2),
vec_names.at(0), vec_names.at(1), vec_names.at(2)};
break; break;
case 4: vec_names.insert(vec_names.end(), 2, vec_names.back()); break; case 4:
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(2),
vec_names.at(3), vec_names.at(3), vec_names.at(3)};
break;
case 5: vec_names.push_back(vec_names.back()); break; case 5:
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(2),
vec_names.at(3), vec_names.at(4), vec_names.at(4)};
break;
default: break; default: break;
} }
...@@ -1008,11 +1018,13 @@ struct onnx_parser ...@@ -1008,11 +1018,13 @@ struct onnx_parser
{ {
switch(vec_names.size()) switch(vec_names.size())
{ {
case 1: vec_names.insert(vec_names.end(), 2, vec_names.back()); break; case 1:
vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)};
break;
case 2: case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs // repeat the 2nd actv func once, so we have 3 actv funcs
vec_names.push_back(vec_names.back()); vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
break; break;
default: break; default: break;
......
...@@ -674,7 +674,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -674,7 +674,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
auto args = ins->inputs(); auto args = ins->inputs();
shape seq_shape = args[0]->get_shape(); shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[2]; std::size_t hidden_size = args[2]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1]; std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type(); shape::type_t type = seq_shape.type();
migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
...@@ -831,7 +831,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -831,7 +831,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// process weight of the peephole // process weight of the peephole
instruction_ref pph{}; instruction_ref pph{};
instruction_ref pph_reverse{};
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 && args[7]->name() != "undefined")
{ {
pph = args[7]; pph = args[7];
...@@ -995,12 +994,15 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -995,12 +994,15 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
auto pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi); auto pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
auto ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho); auto ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
auto pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf); auto pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
{ {
...@@ -1012,7 +1014,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1012,7 +1014,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi); auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri); auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic); auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri, pphi_ct); auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
if(bias != prog.end()) if(bias != prog.end())
{ {
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst); it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
...@@ -1023,7 +1026,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1023,7 +1026,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf); auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf); auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic); auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf, pphf_ct); auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
if(bias != prog.end()) if(bias != prog.end())
{ {
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst); ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
...@@ -1050,7 +1054,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1050,7 +1054,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo); auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro); auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt); auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro, ppho_cellt); auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
if(bias != prog.end()) if(bias != prog.end())
{ {
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst); ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment