"docs/vscode:/vscode.git/clone" did not exist on "6f6faffdf4ada01ab930878c3e97616f5e9ec6ed"
Commit 67491293 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add the gru operator

parent 69102b29
...@@ -1067,8 +1067,8 @@ struct rnn ...@@ -1067,8 +1067,8 @@ struct rnn
bidirectional, bidirectional,
}; };
std::size_t hidden_size = 1; std::size_t hidden_size = 1;
operation actv_func = tanh{}; operation actv_func{tanh{}};
rnn_direction_t direction = forward; rnn_direction_t direction = forward;
float clip = 0.0f; float clip = 0.0f;
...@@ -1076,14 +1076,14 @@ struct rnn ...@@ -1076,14 +1076,14 @@ struct rnn
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto in_dims = inputs[0].lens(); auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[1].lens(); auto hidden_dims = inputs[2].lens();
if(hidden_size != hidden_dims[1]) if(hidden_size != hidden_dims[2])
{ {
MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input"); MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
} }
std::size_t num_directions = 1; std::size_t num_directions = 1;
if(direction == rnn_direction_t::bidirectional) if(direction == bidirectional)
{ {
num_directions = 2; num_directions = 2;
} }
...@@ -1101,6 +1101,50 @@ struct rnn ...@@ -1101,6 +1101,50 @@ struct rnn
} }
}; };
struct gru
{
enum gru_direction_t
{
forward,
reverse,
bidirectional,
};
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
gru_direction_t direction = forward;
float clip = 0.0f;
int linear_before_reset = 0;
std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[2].lens();
if(hidden_size != hidden_dims[2])
{
MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
}
std::size_t num_directions = 1;
if(direction == bidirectional)
{
num_directions = 2;
}
if(num_directions != hidden_dims[0])
{
MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
}
std::vector<std::size_t> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size;
return {inputs[0].type(), out_dims};
}
};
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment