Commit a7a3f867 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix comments.

parent 343a5774
...@@ -632,11 +632,7 @@ struct reshape ...@@ -632,11 +632,7 @@ struct reshape
rdims[i] = missing_dim; rdims[i] = missing_dim;
} }
} }
// if(dims.back() == -1)
//{
// rdims.pop_back();
// std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
//}
shape s{inputs.front().type(), rdims}; shape s{inputs.front().type(), rdims};
if(s.elements() != inputs.front().elements()) if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("Wrong number of elements for reshape"); MIGRAPHX_THROW("Wrong number of elements for reshape");
......
...@@ -739,8 +739,9 @@ struct onnx_parser ...@@ -739,8 +739,9 @@ struct onnx_parser
} }
}); });
// bidirectional should have two activation functions // bidirectional case should have two activation functions.
// if only one actv function is provides, we use it in both // one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction // forward and reverse direction
if(dirct == op::rnn::bidirectional) if(dirct == op::rnn::bidirectional)
{ {
...@@ -750,9 +751,9 @@ struct onnx_parser ...@@ -750,9 +751,9 @@ struct onnx_parser
} }
} }
std::vector<operation> vec_actv_funcs; std::vector<operation> vec_actv_funcs(vec_names.size());
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) { std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
vec_actv_funcs.push_back(map_actv_funcs[fn]); return map_actv_funcs[fn];
}); });
// To be added later // To be added later
......
...@@ -138,12 +138,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re ...@@ -138,12 +138,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
if(ins == std::prev(this->end())) if(ins == std::prev(this->end()))
{ {
// additional check to ensure the ins to be replaced is either return replace_instruction(ins, op::identity{}, rep);
// the rnn_last_output, gru_last_output, or lstm_last_output
if(ins->name() == "rnn_last_output")
{
return replace_instruction(ins, op::identity{}, rep);
}
} }
// TODO: Should it be an error if the output is empty? // TODO: Should it be an error if the output is empty?
......
...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const void rewrite_rnn::apply(program& prog) const
{ {
instruction_ref last_output = prog.end(); std::unordered_map<instruction_ref, instruction_ref> map_last_output;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
// rewrite rnn operator // rewrite rnn operator
...@@ -87,14 +87,15 @@ void rewrite_rnn::apply(program& prog) const ...@@ -87,14 +87,15 @@ void rewrite_rnn::apply(program& prog) const
auto concat_output = auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from // The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction // rnn operator is a concat instruction
// sequence len is 1 // sequence len is 1
instruction_ref hidden_output{};
if(ret_forward[0] == prog.end()) if(ret_forward[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); hidden_output = prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
} }
else else
{ {
...@@ -102,8 +103,9 @@ void rewrite_rnn::apply(program& prog) const ...@@ -102,8 +103,9 @@ void rewrite_rnn::apply(program& prog) const
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] = ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); hidden_output = prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
} }
map_last_output[hidden_output] = last_output;
} }
else else
{ {
...@@ -135,21 +137,23 @@ void rewrite_rnn::apply(program& prog) const ...@@ -135,21 +137,23 @@ void rewrite_rnn::apply(program& prog) const
auto ret = rnn_cell( auto ret = rnn_cell(
is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0)); is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a // following logic is to ensure the last instruction is a
// concat instruction // concat instruction
// sequence len is 1 // sequence len is 1
instruction_ref hidden_output{};
if(ret[0] == prog.end()) if(ret[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{0}, ret[1]); hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
} }
else else
{ {
auto concat_arg0 = is_forward ? ret[0] : ret[1]; auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0]; auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); hidden_output = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
} }
map_last_output[hidden_output] = last_output;
} }
} }
...@@ -159,16 +163,15 @@ void rewrite_rnn::apply(program& prog) const ...@@ -159,16 +163,15 @@ void rewrite_rnn::apply(program& prog) const
// so we can just use it as the output here // so we can just use it as the output here
if(ins->name() == "rnn_last_output") if(ins->name() == "rnn_last_output")
{ {
// if rnn operator is executed, the last_output != prog.end() auto inputs = ins->inputs();
if(last_output != prog.end()) assert(inputs.size() == 1);
auto arg = inputs[0];
if (map_last_output.count(arg) == 0)
{ {
prog.replace_instruction(ins, last_output); MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
last_output = prog.end();
}
else
{
MIGRAPHX_THROW("RNN_LAST_OUTPUT: must put after rnn operator");
} }
prog.replace_instruction(ins, map_last_output[arg]);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment