Commit a7a3f867 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix comments.

parent 343a5774
......@@ -632,11 +632,7 @@ struct reshape
rdims[i] = missing_dim;
}
}
// if(dims.back() == -1)
//{
// rdims.pop_back();
// std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
//}
shape s{inputs.front().type(), rdims};
if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("Wrong number of elements for reshape");
......
......@@ -739,8 +739,9 @@ struct onnx_parser
}
});
// bidirectional should have two activation functions
// if only one actv function is provides, we use it in both
// bidirectional case should have two activation functions.
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
if(dirct == op::rnn::bidirectional)
{
......@@ -750,9 +751,9 @@ struct onnx_parser
}
}
std::vector<operation> vec_actv_funcs;
for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
vec_actv_funcs.push_back(map_actv_funcs[fn]);
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
return map_actv_funcs[fn];
});
// To be added later
......
......@@ -137,14 +137,9 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert(ins != rep);
if(ins == std::prev(this->end()))
{
// additional check to ensure the ins to be replaced is either
// the rnn_last_output, gru_last_output, or lstm_last_output
if(ins->name() == "rnn_last_output")
{
return replace_instruction(ins, op::identity{}, rep);
}
}
// TODO: Should it be an error if the output is empty?
if(ins->outputs().empty())
......
......@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const
{
instruction_ref last_output = prog.end();
std::unordered_map<instruction_ref, instruction_ref> map_last_output;
for(auto ins : iterator_for(prog))
{
// rewrite rnn operator
......@@ -87,14 +87,15 @@ void rewrite_rnn::apply(program& prog) const
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
instruction_ref hidden_output{};
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
hidden_output = prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
......@@ -102,8 +103,9 @@ void rewrite_rnn::apply(program& prog) const
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
hidden_output = prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
map_last_output[hidden_output] = last_output;
}
else
{
......@@ -135,21 +137,23 @@ void rewrite_rnn::apply(program& prog) const
auto ret = rnn_cell(
is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
instruction_ref hidden_output{};
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
hidden_output = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
map_last_output[hidden_output] = last_output;
}
}
......@@ -159,16 +163,15 @@ void rewrite_rnn::apply(program& prog) const
// so we can just use it as the output here
if(ins->name() == "rnn_last_output")
{
// if rnn operator is executed, the last_output != prog.end()
if(last_output != prog.end())
auto inputs = ins->inputs();
assert(inputs.size() == 1);
auto arg = inputs[0];
if (map_last_output.count(arg) == 0)
{
prog.replace_instruction(ins, last_output);
last_output = prog.end();
}
else
{
MIGRAPHX_THROW("RNN_LAST_OUTPUT: must put after rnn operator");
MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
}
prog.replace_instruction(ins, map_last_output[arg]);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment