Commit 97568d53 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code backup

parent bc890c75
...@@ -677,12 +677,12 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -677,12 +677,12 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
std::size_t batch_size = seq_shape.lens()[1]; std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type(); shape::type_t type = seq_shape.type();
migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
std::vector<float> ihc_data(ih_shape.elements(), 0.0); std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
migraphx::shape pph_shape{type, {1, 3 * hidden_size}}; migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
std::vector<float> ppl_data(pph_shape.elements(), 0.0); std::vector<float> pph_data(pph_shape.elements(), 0.0);
auto& actv_funcs = lstm_actv_funcs(ins); auto actv_funcs = lstm_actv_funcs(ins);
auto lstm_op = any_cast<op::lstm>(ins->get_operator()); auto lstm_op = any_cast<op::lstm>(ins->get_operator());
op::lstm::lstm_direction_t dirct = lstm_op.direction; op::lstm::lstm_direction_t dirct = lstm_op.direction;
...@@ -777,7 +777,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -777,7 +777,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// last cell output // last cell output
auto concat_cell_output = auto concat_cell_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[2], ret_reverse[2]); prog.insert_instruction(ins, op::concat{1}, ret_forward[2], ret_reverse[2]);
last_cell_output = prog.insert_instruction(ins, squeeze{{0}}, concat_cell_output); last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_cell_output);
// the following logic is to ensure the last instruction is a concat // the following logic is to ensure the last instruction is a concat
if(ret_forward[0] == prog.end()) if(ret_forward[0] == prog.end())
...@@ -795,6 +795,106 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -795,6 +795,106 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
} }
else else
{ {
bool is_forward = (dirct == op::lstm::forward);
// weight matrices
auto w = args[1];
auto r = args[2];
// bias
instruction_ref bias = prog.end();
if (args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
}
// initial hidden state
instruction_ref ih{};
if (args.size() >= 6 && args[5]->name() != "undefined")
{
ih = args[5];
}
else
{
ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// initial cell value
instruction_ref ic{};
if (args.size() >= 7 && args[6]->name() != "undefined")
{
ic = args[6];
}
else
{
ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph{};
instruction_ref pph_reverse{};
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph = args[7];
}
else
{
pph = prog.add_literal(migraphx::literal{pph_shape, pph_data});
}
auto ret = lstm_cell(is_forward,
prog,
ins,
{args[0], w, r, bias, ih, ic, pph},
lstm_op.input_forget,
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[2]);
if (ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
// replace the corresponding lstm_last_output instruction
// with the last_output, and the lstm_last_cell_output with
// the last_cell_output. The while loop is to handle the case
// of multiple lstm_last_output and lstm_last_cell_output
// operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "lstm_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
auto last_cell_output_it = ins->outputs().begin();
while(last_cell_output_it != ins->outputs().end())
{
last_cell_output_it = std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "lstm_last_cell_output";
});
if(last_cell_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_cell_output_it, last_cell_output);
last_cell_output_it++;
}
} }
} }
...@@ -807,6 +907,18 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -807,6 +907,18 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
const operation& actv_func2, const operation& actv_func2,
const operation& actv_func3) const const operation& actv_func3) const
{ {
// must have 7 args in the input vector
assert(inputs.size() == 7);
auto seq = inputs.at(0);
auto w = inputs.at(1);
auto r = inputs.at(2);
auto bias = inputs.at(3);
auto ih = inputs.at(4);
auto ic = inputs.at(5);
auto pph = inputs.at(6);
instruction_ref
return {}; return {};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment