Commit 03b59435 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine the implementation of gru operator.

parent fb481fed
...@@ -85,16 +85,26 @@ void rewrite_gru::apply(program& prog) const ...@@ -85,16 +85,26 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(2), gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(3)); gru_op.actv_funcs.at(3));
auto last_output = auto concat_output =
prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]); prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// add the dimension of num_direction
ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]); // The following logic is to ensure the last instruction rewritten
ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]); // from gru operator is a concat
instruction_ref hidden_state{};
// concat the forward and reverse output if (ret_forward[0] == prog.end())
auto hidden_state = {
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); hidden_state = prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
hidden_state = prog.replace_instruction(
ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
map_last_output[hidden_state] = last_output; map_last_output[hidden_state] = last_output;
} }
else else
...@@ -134,10 +144,19 @@ void rewrite_gru::apply(program& prog) const ...@@ -134,10 +144,19 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1)); gru_op.actv_funcs.at(1));
auto last_output = ret[1]; auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// add the dimension of num_direction instruction_ref hidden_state{};
auto hidden_state = prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]); if (ret[0] == prog.end())
{
hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[1];
hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
map_last_output[hidden_state] = last_output; map_last_output[hidden_state] = last_output;
} }
} }
...@@ -168,7 +187,7 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -168,7 +187,7 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
operation& actv_func1, operation& actv_func1,
operation& actv_func2) const operation& actv_func2) const
{ {
instruction_ref hidden_out, last_out; instruction_ref hidden_states = prog.end(), last_output;
long seq_len = static_cast<long>(input->get_shape().lens()[0]); long seq_len = static_cast<long>(input->get_shape().lens()[0]);
long hs = static_cast<long>(r->get_shape().lens()[2]); long hs = static_cast<long>(r->get_shape().lens()[2]);
...@@ -227,9 +246,9 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -227,9 +246,9 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh); brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
} }
long seq_index = is_forward ? 0 : seq_len - 1;
for(long i = 0; i < seq_len; i++) for(long i = 0; i < seq_len; i++)
{ {
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input); auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
...@@ -289,28 +308,26 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -289,28 +308,26 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht); auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih); auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih);
sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1); sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, sih); last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
if (i < seq_len - 1)
{
if(is_forward) if(is_forward)
{ {
hidden_out = (seq_index == 0) hidden_states = (seq_index == 0)
? last_out ? last_output
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out); : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
} }
else else
{ {
hidden_out = (seq_index == seq_len - 1) hidden_states = (seq_index == seq_len - 1)
? last_out ? last_output
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out); : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
}
} }
seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
} }
std::vector<instruction_ref> out_args; return {hidden_states, last_output};
out_args.push_back(hidden_out);
out_args.push_back(last_out);
return out_args;
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment