Commit bddd8454 authored by Shucai Xiao
Browse files

more test examples and code cleanup.

parent 62eea2df
...@@ -154,13 +154,19 @@ void rewrite_gru::apply(program& prog) const ...@@ -154,13 +154,19 @@ void rewrite_gru::apply(program& prog) const
// replace the corresponding gru_last_output instruction // replace the corresponding gru_last_output instruction
// with the last_output, if gru_last_output exists // with the last_output, if gru_last_output exists
auto last_output_it = // while loop to handle case of multiple gru_last_output operators
std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) { auto last_output_it = ins->outputs().begin();
while (last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [] (auto i) {
return i->name() == "gru_last_output"; return i->name() == "gru_last_output";
}); });
if(last_output_it != ins->outputs().end())
{ if(last_output_it != ins->outputs().end())
prog.replace_instruction(*last_output_it, last_output); {
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
} }
} }
} }
......
...@@ -10,7 +10,6 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,6 @@ inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const void rewrite_rnn::apply(program& prog) const
{ {
std::unordered_map<instruction_ref, instruction_ref> map_last_output;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
// rewrite rnn operator // rewrite rnn operator
...@@ -32,6 +31,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -32,6 +31,7 @@ void rewrite_rnn::apply(program& prog) const
auto actv_funcs = compute_actv_funcs(ins); auto actv_funcs = compute_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator()); auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction; op::rnn::rnn_direction_t dicrt = rnn_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn::bidirectional) if(dicrt == op::rnn::bidirectional)
{ {
// input weight matrix // input weight matrix
...@@ -87,7 +87,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -87,7 +87,7 @@ void rewrite_rnn::apply(program& prog) const
auto concat_output = auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from // The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction // rnn operator is a concat instruction
...@@ -107,7 +107,6 @@ void rewrite_rnn::apply(program& prog) const ...@@ -107,7 +107,6 @@ void rewrite_rnn::apply(program& prog) const
hidden_output = prog.replace_instruction( hidden_output = prog.replace_instruction(
ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
} }
map_last_output[hidden_output] = last_output;
} }
else else
{ {
...@@ -138,7 +137,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -138,7 +137,7 @@ void rewrite_rnn::apply(program& prog) const
auto ret = auto ret =
rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0)); rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a // following logic is to ensure the last instruction is a
// concat instruction // concat instruction
...@@ -155,30 +154,23 @@ void rewrite_rnn::apply(program& prog) const ...@@ -155,30 +154,23 @@ void rewrite_rnn::apply(program& prog) const
hidden_output = hidden_output =
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
} }
// auto last_it = std::find_if();
// if(last_it != ins->outputs().end())
// {
// }
map_last_output[hidden_output] = last_output;
} }
}
// rewrite the rnn_last_output operator that right after the rnn // search its output to find if there are rnn_last_output operator
// operator. Intuitively, we can do a slice on its input to get // while loop to handle case of multiple rnn_last_output operators
// the last output, but it is already existed in the rnn operator, auto last_output_it = ins->outputs().begin();
// so we can just use it as the output here while (last_output_it != ins->outputs().end())
if(ins->name() == "rnn_last_output")
{
auto inputs = ins->inputs();
assert(inputs.size() == 1);
auto arg = inputs[0];
if(map_last_output.count(arg) == 0)
{ {
MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input"); last_output_it = std::find_if(last_output_it, ins->outputs().end(), [] (auto i) {
} return i->name() == "rnn_last_output";
});
prog.replace_instruction(ins, map_last_output[arg]); if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
} }
} }
} }
......
...@@ -138,6 +138,43 @@ TEST_CASE(rnn_forward) ...@@ -138,6 +138,43 @@ TEST_CASE(rnn_forward)
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
} }
// multiple rnn_last_output operators
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// 3 args // 3 args
{ {
migraphx::program p; migraphx::program p;
...@@ -617,6 +654,48 @@ TEST_CASE(gru_forward) ...@@ -617,6 +654,48 @@ TEST_CASE(gru_forward)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
// two gru_last_output operators after gru
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.3969709,
0.43360898,
0.35775262,
0.23280787,
-0.52179873,
-0.21944991,
0.4535257,
-0.13735442,
0.51757574,
0.50380427};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// last output for output, linear_before_reset = 0 // last output for output, linear_before_reset = 0
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment