Commit 233f3bcc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code backup

parent 1133d782
......@@ -52,7 +52,7 @@ struct rewrite_rnn
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
int input_forget,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const;
......
......@@ -1026,7 +1026,7 @@ struct onnx_parser
if(args.size() < 8)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 6 - args.size(), ins);
args.insert(args.end(), 8 - args.size(), ins);
}
// first output for concatenation of hidden states
......
......@@ -674,13 +674,130 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
assert(ins->name() == "lstm");
auto args = ins->inputs();
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[2];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
std::vector<float> ihc_data(ih_shape.elements(), 0.0);
migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
std::vector<float> ppl_data(pph_shape.elements(), 0.0);
auto &actv_funcs = lstm_actv_funcs(ins);
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
op::lstm::lstm_direction_t dirct = lstm_op.direction;
instruction_ref last_output{};
instruction_ref last_cell_output{};
if (dirct == op::lstm::bidirectional)
{
// input weight matrix
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if (args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
}
// process intial hidden state, it is the 6th argument
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() >= 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process initial cell value
instruction_ref ic_forward{};
instruction_ref ic_reverse{};
if (args.size() >= 7 && args[6]->name() != "undefined")
{
ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
}
else
{
ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph_forward{};
instruction_ref pph_reverse{};
if (args.size() == 8 && args[7]->name() != "undefined")
{
pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
}
else
{
pph_forward = prog.add_literal(migraphx::literal{pph_shape, pph_data});
pph_reverse = prog.add_literal(migraphx::literal{pph_shape, pph_data});
}
auto ret_forward = lstm_cell(true, prog, ins,
{args[0], w_forward, r_forward, bias_forward,
ih_forward, ic_forward, pph_forward},
lstm_op.input_forget,
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
auto ret_reverse = lstm_cell(false, prog, ins,
{args[0], w_reverse, r_reverse, bias_reverse,
ih_reverse, ic_reverse, pph_reverse},
lstm_op.input_forget,
actv_funcs.at(3),
actv_funcs.at(4),
actv_funcs.at(5));
auto concat_output = prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// last cell output
auto concat_cell_output = prog.insert_instruction(ins, op::concat{1}, ret_forward[2], ret_reverse[2]);
last_cell_output = prog.insert_instruction(ins, squeeze{{0}}, concat_cell_output);
// the following logic is to ensure the last instruction is a concat
if (ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
}
}
else
{
}
}
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
int input_forget,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment