Commit 67491293 authored by Shucai Xiao's avatar Shucai Xiao

add the gru operator

parent 69102b29
@@ -1067,8 +1067,8 @@ struct rnn
         bidirectional,
     };
-    std::size_t hidden_size = 1;
-    operation actv_func = tanh{};
+    std::size_t hidden_size = 1;
+    operation actv_func{tanh{}};
     rnn_direction_t direction = forward;
     float clip = 0.0f;
@@ -1076,14 +1076,14 @@ struct rnn
     shape compute_shape(std::vector<shape> inputs) const
     {
         auto in_dims = inputs[0].lens();
-        auto hidden_dims = inputs[1].lens();
-        if(hidden_size != hidden_dims[1])
+        auto hidden_dims = inputs[2].lens();
+        if(hidden_size != hidden_dims[2])
         {
             MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
         }
         std::size_t num_directions = 1;
-        if(direction == rnn_direction_t::bidirectional)
+        if(direction == bidirectional)
         {
             num_directions = 2;
         }
@@ -1101,6 +1101,50 @@
     }
 };
 
+struct gru
+{
+    enum gru_direction_t
+    {
+        forward,
+        reverse,
+        bidirectional,
+    };
+    std::size_t hidden_size = 1;
+    std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
+    gru_direction_t direction = forward;
+    float clip = 0.0f;
+    int linear_before_reset = 0;
+
+    std::string name() const { return "gru"; }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        auto in_dims = inputs[0].lens();
+        auto hidden_dims = inputs[2].lens();
+        if(hidden_size != hidden_dims[2])
+        {
+            MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
+        }
+        std::size_t num_directions = 1;
+        if(direction == bidirectional)
+        {
+            num_directions = 2;
+        }
+        if(num_directions != hidden_dims[0])
+        {
+            MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
+        }
+        std::vector<std::size_t> out_dims(in_dims);
+        out_dims.insert(out_dims.begin() + 1, num_directions);
+        out_dims.back() = hidden_size;
+        return {inputs[0].type(), out_dims};
+    }
+};
+
 } // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
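For reference, the output-shape rule used by the new gru::compute_shape (and the corrected rnn version above) is: start from the input dims {seq_length, batch_size, input_size}, insert num_directions after the sequence dimension, and replace the last dimension with hidden_size. Below is a minimal standalone sketch of that arithmetic; it is not part of this commit, and the example dimensions are hypothetical.

// Standalone illustration of the GRU/RNN output-shape computation shown in the diff.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> gru_output_dims(const std::vector<std::size_t>& in_dims,
                                         std::size_t num_directions,
                                         std::size_t hidden_size)
{
    // Input X has dims {seq_length, batch_size, input_size}.
    std::vector<std::size_t> out_dims(in_dims);
    // Insert the direction count after the sequence dimension ...
    out_dims.insert(out_dims.begin() + 1, num_directions);
    // ... and replace input_size with hidden_size.
    out_dims.back() = hidden_size;
    // Result: {seq_length, num_directions, batch_size, hidden_size}.
    return out_dims;
}

int main()
{
    // Hypothetical bidirectional GRU: seq_length=5, batch_size=3, input_size=8, hidden_size=4.
    auto dims = gru_output_dims({5, 3, 8}, 2, 4);
    for(auto d : dims)
        std::cout << d << ' '; // prints: 5 2 3 4
    std::cout << '\n';
}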