Commit bddd8454 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

more test examples and code cleanup.

parent 62eea2df
......@@ -154,13 +154,19 @@ void rewrite_gru::apply(program& prog) const
// replace the corresponding gru_last_output instruction
// with the last_output, if gru_last_output exists
auto last_output_it =
std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) {
// while loop to handle case of multiple gru_last_output operators
auto last_output_it = ins->outputs().begin();
while (last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [] (auto i) {
return i->name() == "gru_last_output";
});
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
}
}
......
......@@ -10,7 +10,6 @@ inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(program& prog) const
{
std::unordered_map<instruction_ref, instruction_ref> map_last_output;
for(auto ins : iterator_for(prog))
{
// rewrite rnn operator
......@@ -32,6 +31,7 @@ void rewrite_rnn::apply(program& prog) const
auto actv_funcs = compute_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn::bidirectional)
{
// input weight matrix
......@@ -87,7 +87,7 @@ void rewrite_rnn::apply(program& prog) const
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
......@@ -107,7 +107,6 @@ void rewrite_rnn::apply(program& prog) const
hidden_output = prog.replace_instruction(
ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
map_last_output[hidden_output] = last_output;
}
else
{
......@@ -138,7 +137,7 @@ void rewrite_rnn::apply(program& prog) const
auto ret =
rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a
// concat instruction
......@@ -155,30 +154,23 @@ void rewrite_rnn::apply(program& prog) const
hidden_output =
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
// auto last_it = std::find_if();
// if(last_it != ins->outputs().end())
// {
// }
map_last_output[hidden_output] = last_output;
}
}
// rewrite the rnn_last_output operator that right after the rnn
// operator. Intuitively, we can do a slice on its input to get
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
if(ins->name() == "rnn_last_output")
{
auto inputs = ins->inputs();
assert(inputs.size() == 1);
auto arg = inputs[0];
if(map_last_output.count(arg) == 0)
// search its output to find if there are rnn_last_output operator
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while (last_output_it != ins->outputs().end())
{
MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
}
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [] (auto i) {
return i->name() == "rnn_last_output";
});
prog.replace_instruction(ins, map_last_output[arg]);
if(last_output_it != ins->outputs().end())
{
prog.replace_instruction(*last_output_it, last_output);
last_output_it++;
}
}
}
}
}
......
......@@ -138,6 +138,43 @@ TEST_CASE(rnn_forward)
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// multiple rnn_last_output operators
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs =
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{});
auto last_output = p.eval({});
std::vector<float> last_output_data;
last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
std::vector<float> last_output_data_gold{0.03445704,
0.19167931,
-0.3946827,
-0.30889652,
-0.22276389,
0.44193283,
-0.16477929,
-0.11893477};
EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
}
// 3 args
{
migraphx::program p;
......@@ -617,6 +654,48 @@ TEST_CASE(gru_forward)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// two gru_last_output operators after gru
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto concat_hs =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
clip,
1},
seq,
w,
r,
bias,
und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{-0.3969709,
0.43360898,
0.35775262,
0.23280787,
-0.52179873,
-0.21944991,
0.4535257,
-0.13735442,
0.51757574,
0.50380427};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
// last output for output, linear_before_reset = 0
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment