Commit 03b59435 authored by Shucai Xiao
Browse files

Refine the implementation of the GRU operator.

parent fb481fed
......@@ -85,16 +85,26 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(3));
auto last_output =
prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
// add the dimension of num_direction
ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
// concat the forward and reverse output
auto hidden_state =
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten
// from gru operator is a concat
instruction_ref hidden_state{};
if (ret_forward[0] == prog.end())
{
hidden_state = prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
hidden_state = prog.replace_instruction(
ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
map_last_output[hidden_state] = last_output;
}
else
......@@ -134,10 +144,19 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1));
auto last_output = ret[1];
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// add the dimension of num_direction
auto hidden_state = prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
instruction_ref hidden_state{};
if (ret[0] == prog.end())
{
hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[1];
hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
map_last_output[hidden_state] = last_output;
}
}
......@@ -168,7 +187,7 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
operation& actv_func1,
operation& actv_func2) const
{
instruction_ref hidden_out, last_out;
instruction_ref hidden_states = prog.end(), last_output;
long seq_len = static_cast<long>(input->get_shape().lens()[0]);
long hs = static_cast<long>(r->get_shape().lens()[2]);
......@@ -227,9 +246,9 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
}
long seq_index = is_forward ? 0 : seq_len - 1;
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
......@@ -289,28 +308,26 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih);
sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, sih);
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
if(is_forward)
if (i < seq_len - 1)
{
hidden_out = (seq_index == 0)
? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
}
else
{
hidden_out = (seq_index == seq_len - 1)
? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
if(is_forward)
{
hidden_states = (seq_index == 0)
? last_output
: prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
}
else
{
hidden_states = (seq_index == seq_len - 1)
? last_output
: prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
}
}
seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
}
std::vector<instruction_ref> out_args;
out_args.push_back(hidden_out);
out_args.push_back(last_out);
return out_args;
return {hidden_states, last_output};
}
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment