Commit a87890be authored by Shucai Xiao's avatar Shucai Xiao
Browse files

commit gru changes.

parent 6b1e5e63
......@@ -1173,6 +1173,20 @@ struct rnn
}
};
struct rnn_last_output
{
std::string name() const { return "rnn_last_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct gru
{
enum gru_direction_t
......@@ -1217,20 +1231,6 @@ struct gru
}
};
struct rnn_last_output
{
std::string name() const { return "rnn_last_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct gru_last_output
{
std::string name() const { return "gru_last_output"; }
......
......@@ -832,8 +832,8 @@ struct onnx_parser
{
// 4 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provides,
// repeat 1 four times. If 2 actv functins are provides,
// use the algorithm that: if 1 actv function is provided,
// repeat 1 four times. If 2 actv functins are provided,
// assume forward and reverse use the same pair of actv
// functions. For the case of 3 actv functions provided,
// assume the 3rd one is repeated once and used by the
......@@ -869,12 +869,11 @@ struct onnx_parser
}
});
std::vector<operation> vec_actv_funcs;
for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
vec_actv_funcs.push_back(map_actv_funcs[name]);
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
return map_actv_funcs[name];
});
// To be added later
float clip = 0.0;
if(contains(attributes, "clip"))
{
......@@ -887,18 +886,22 @@ struct onnx_parser
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
}
std::vector<instruction_ref> result;
// append undefined opeator to make 6 arguments
if (args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 6 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
std::move(args));
result.push_back(hidden_states);
// second output for last gru output
auto last_output = prog.add_instruction(op::gru_last_output{}, hidden_states);
result.push_back(last_output);
return result;
return {hidden_states, last_output};
}
void parse_from(std::istream& is)
......
......@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void rewrite_gru::apply(program& prog) const
{
instruction_ref last_output = prog.end();
std::unordered_map<instruction_ref, instruction_ref> map_last_output;
for(auto ins : iterator_for(prog))
{
if(ins->name() == "gru")
......@@ -22,9 +22,9 @@ void rewrite_gru::apply(program& prog) const
shape seq_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
std::size_t batchs = seq_shape.lens()[1];
std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batchs, hidden_size}};
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<char> data(ih_shape.bytes(), 0);
auto gru_op = any_cast<op::gru>(ins->get_operator());
......@@ -42,7 +42,7 @@ void rewrite_gru::apply(program& prog) const
// bias
instruction_ref bias_forward, bias_reverse;
bias_forward = bias_reverse = prog.end();
if(args.size() >= 4)
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
......@@ -50,12 +50,10 @@ void rewrite_gru::apply(program& prog) const
// intial hidden state
instruction_ref ih_forward, ih_reverse;
if(args.size() == 6 ||
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{
auto arg_ih = (args.size() == 6) ? args[5] : args[4];
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
}
else
{
......@@ -87,7 +85,7 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(3));
last_output =
auto last_output =
prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
// add the dimension of num_direction
......@@ -95,7 +93,8 @@ void rewrite_gru::apply(program& prog) const
ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
// concat the forward and reverse output
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
auto hidden_state = prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
map_last_output[hidden_state] = last_output;
}
else
{
......@@ -106,17 +105,16 @@ void rewrite_gru::apply(program& prog) const
// bias
instruction_ref bias = prog.end();
if(args.size() >= 4)
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{
bias = args[3];
}
// intial hidden state
instruction_ref ih;
if(args.size() == 6 ||
(args.size() == 5 && args[4]->get_shape().lens().size() == 3))
if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{
ih = args.size() == 6 ? args[5] : args[4];
ih = args[5];
}
else
{
......@@ -135,10 +133,11 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1));
last_output = ret[1];
auto last_output = ret[1];
// add the dimension of num_direction
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
auto hidden_state = prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
map_last_output[hidden_state] = last_output;
}
}
......@@ -148,11 +147,10 @@ void rewrite_gru::apply(program& prog) const
// so we can just use it as the output here
if(ins->name() == "gru_last_output")
{
if(last_output != prog.end())
{
prog.replace_instruction(ins, op::identity{}, last_output);
last_output = prog.end();
}
auto inputs = ins->inputs();
assert(inputs.size() == 1);
assert(map_last_output.count(inputs[0]) > 0);
prog.replace_instruction(ins, map_last_output[inputs[0]]);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment