"vscode:/vscode.git/clone" did not exist on "154799581fdb9a96eed8fd0759c544f1fb7a0081"
Commit 3b447c38 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add lstm operator

parent b8478c6f
...@@ -1256,6 +1256,78 @@ struct gru_last_output ...@@ -1256,6 +1256,78 @@ struct gru_last_output
} }
}; };
struct lstm
{
enum lstm_direction_t
{
forward,
reverse,
bidirectional,
};
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
gru_direction_t direction = forward;
float clip = 0.0f;
int input_forget = 0;
std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[2].lens();
if(hidden_size != hidden_dims[2])
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in attribute and input");
}
std::size_t num_directions = 1;
if(direction == bidirectional)
{
num_directions = 2;
}
if(num_directions != hidden_dims[0])
{
MIGRAPHX_THROW("LSTM: num_direction does not match the direction attribute");
}
std::vector<std::size_t> out_dims(in_dims);
out_dims.insert(out_dims.begin() + 1, num_directions);
out_dims.back() = hidden_size;
return {inputs[0].type(), out_dims};
}
};
struct lstm_last_output
{
std::string name() const { return "lstm_last_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct lstm_last_cell_output
{
std::string name() const { return "lstm_last_cell_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct undefined struct undefined
{ {
std::string name() const { return "undefined"; } std::string name() const { return "undefined"; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment