Commit 40babfd5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format.

parent 03b59435
......@@ -92,11 +92,12 @@ void rewrite_gru::apply(program& prog) const
// The following logic is to ensure the last instruction rewritten
// from gru operator is a concat
instruction_ref hidden_state{};
if (ret_forward[0] == prog.end())
if(ret_forward[0] == prog.end())
{
hidden_state = prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
hidden_state = prog.replace_instruction(
ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
......@@ -104,7 +105,7 @@ void rewrite_gru::apply(program& prog) const
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
hidden_state = prog.replace_instruction(
ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
map_last_output[hidden_state] = last_output;
}
else
......@@ -147,7 +148,7 @@ void rewrite_gru::apply(program& prog) const
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
instruction_ref hidden_state{};
if (ret[0] == prog.end())
if(ret[0] == prog.end())
{
hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
......@@ -155,7 +156,8 @@ void rewrite_gru::apply(program& prog) const
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[1];
hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
hidden_state =
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
map_last_output[hidden_state] = last_output;
}
......@@ -188,8 +190,8 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
operation& actv_func2) const
{
instruction_ref hidden_states = prog.end(), last_output;
long seq_len = static_cast<long>(input->get_shape().lens()[0]);
long hs = static_cast<long>(r->get_shape().lens()[2]);
long seq_len = static_cast<long>(input->get_shape().lens()[0]);
long hs = static_cast<long>(r->get_shape().lens()[2]);
migraphx::shape s(input->get_shape().type(),
{input->get_shape().lens()[1], static_cast<std::size_t>(hs)});
......@@ -310,19 +312,21 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
if (i < seq_len - 1)
if(i < seq_len - 1)
{
if(is_forward)
{
hidden_states = (seq_index == 0)
? last_output
: prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
hidden_states =
(seq_index == 0)
? last_output
: prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
}
else
{
hidden_states = (seq_index == seq_len - 1)
? last_output
: prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
hidden_states =
(seq_index == seq_len - 1)
? last_output
: prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment