Commit 2d7f3523 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Rewrite the GRU operator to support two outputs.

parent 1fbe8c48
...@@ -1167,6 +1167,20 @@ struct rnn_last_output ...@@ -1167,6 +1167,20 @@ struct rnn_last_output
} }
}; };
// Operator that extracts the last time step from the full GRU hidden-state
// sequence. Its single input is the concatenated per-step output of op::gru;
// the result shape simply drops that leading time/sequence dimension.
struct gru_last_output
{
    std::string name() const { return "gru_last_output"; }

    // Validates that exactly one input is supplied and derives the output
    // shape: same element type, with dimension 0 removed and the remaining
    // dimensions kept in order.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        const auto& in = inputs.front();
        auto out_lens  = in.lens();
        // Drop the leading (sequence) dimension; what is left is the
        // shape of a single time step.
        out_lens.erase(out_lens.begin());
        return {in.type(), out_lens};
    }
};
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -13,7 +13,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -13,7 +13,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
/** /**
* Rewrite rnn to gemm and add. * Rewrite gru to gemm, mul, and add.
*/ */
struct rewrite_gru struct rewrite_gru
{ {
...@@ -21,14 +21,14 @@ struct rewrite_gru ...@@ -21,14 +21,14 @@ struct rewrite_gru
void apply(program& prog) const; void apply(program& prog) const;
private: private:
std::vector<instruction_ref> gru_oper(bool is_forward, std::vector<instruction_ref> gru_cell(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref input, instruction_ref input,
instruction_ref wx, instruction_ref wx,
instruction_ref wh, instruction_ref wh,
instruction_ref ih,
instruction_ref bias, instruction_ref bias,
instruction_ref ih,
int linear_before_reset, int linear_before_reset,
operation& actv_func1, operation& actv_func1,
operation& actv_func2) const; operation& actv_func2) const;
......
...@@ -732,14 +732,14 @@ struct onnx_parser ...@@ -732,14 +732,14 @@ struct onnx_parser
std::move(args)); std::move(args));
result.push_back(hidden_states); result.push_back(hidden_states);
// second out for the last hidden state // second output for the last hidden state
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states); auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
result.push_back(last_output); result.push_back(last_output);
return result; return result;
} }
instruction_ref std::vector<instruction_ref>
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
migraphx::shape input_shape = args[0]->get_shape(); migraphx::shape input_shape = args[0]->get_shape();
...@@ -842,9 +842,18 @@ struct onnx_parser ...@@ -842,9 +842,18 @@ struct onnx_parser
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>(); linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
} }
return prog.add_instruction( std::vector<instruction_ref> result;
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset}, op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
std::move(args)); std::move(args));
result.push_back(hidden_states);
// second output for last gru output
auto last_output = prog.add_instruction(op::gru_last_output{}, hidden_states);
result.push_back(last_output);
return result;
} }
void parse_from(std::istream& is) void parse_from(std::istream& is)
......
This diff is collapsed.
...@@ -26,7 +26,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -26,7 +26,7 @@ void rewrite_rnn::apply(program& prog) const
std::size_t hidden_size = args[1]->get_shape().lens()[1]; std::size_t hidden_size = args[1]->get_shape().lens()[1];
std::size_t batch_size = seq_shape.lens()[1]; std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type(); shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {batch_size, hidden_size}}; migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<char> data(ih_shape.bytes(), 0); std::vector<char> data(ih_shape.bytes(), 0);
auto rnn_op = any_cast<op::rnn>(ins->get_operator()); auto rnn_op = any_cast<op::rnn>(ins->get_operator());
...@@ -133,7 +133,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -133,7 +133,7 @@ void rewrite_rnn::apply(program& prog) const
} }
// rewrite the rnn_last_output operator that right after the rnn // rewrite the rnn_last_output operator that right after the rnn
// operator. Intuitively, we can do a slice on the input to get // operator. Intuitively, we can do a slice on its input to get
// the last output, but it is already existed in the rnn operator, // the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here // so we can just use it as the output here
if(ins->name() == "rnn_last_output") if(ins->name() == "rnn_last_output")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment