Commit f8c319e3 authored by Shucai Xiao

add more code for lstm operator

parent 398c0157
@@ -1267,7 +1267,7 @@ struct lstm
     std::size_t hidden_size = 1;
     std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
-    gru_direction_t direction = forward;
+    lstm_direction_t direction = forward;
     float clip = 0.0f;
     int input_forget = 0;
...
@@ -45,6 +45,19 @@ struct rewrite_rnn
                            const operation& actv_func2) const;
    std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators
void apply_lstm(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const;
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
};
} // namespace MIGRAPHX_INLINE_NS
...
@@ -900,6 +900,159 @@ struct onnx_parser
    return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
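// args[2] is the recurrence weight tensor R; per the ONNX LSTM spec
// its shape is [num_directions, 4*hidden_size, hidden_size]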
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
}
}
// parse the direction attribute
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::lstm::lstm_direction_t dirct = op::lstm::forward;
if(direction == "bidirectional")
{
dirct = op::lstm::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::lstm::reverse;
}
else if(direction == "forward")
{
dirct = op::lstm::forward;
}
else
{
MIGRAPHX_THROW("LSTM: incorrect direction attribute");
}
std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
if(contains(attributes, "activations"))
{
    auto names = attributes.at("activations").strings();
    vec_names.assign(names.begin(), names.end());
}
// need 6 activation functions for the bidirectional case
if(dirct == op::lstm::bidirectional)
{
    // 6 activation functions are used in the bidirectional
    // scenario, but the ONNX operator spec does not define how
    // fewer than 6 should be expanded. We use this algorithm:
    // if 1 actv function is provided, repeat it six times.
    // If 2 actv functions are provided, repeat the 2nd once to
    // get 3, then repeat all three for the reverse direction.
    // If 3 actv funcs are provided, repeat all three once.
    // The same algorithm covers 4, 5, and 6 provided actv
    // functions. This may need to change later.
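    // e.g. activations = {"sigmoid", "tanh"} expands to
    // {"sigmoid", "tanh", "tanh", "sigmoid", "tanh", "tanh"}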
switch(vec_names.size())
{
case 1:
vec_names.insert(vec_names.end(), 5, vec_names.back());
break;
case 2:
{
    // repeat the 2nd actv func once, then repeat all three once.
    // copy through a temporary: inserting a vector's own range
    // into itself is undefined behavior
    vec_names.push_back(vec_names.back());
    std::vector<std::string> tmp(vec_names);
    vec_names.insert(vec_names.end(), tmp.begin(), tmp.end());
    break;
}
case 3:
{
    // repeat all three actv funcs once
    std::vector<std::string> tmp(vec_names);
    vec_names.insert(vec_names.end(), tmp.begin(), tmp.end());
    break;
}
case 4:
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 5:
vec_names.push_back(vec_names.back());
break;
default:
break;
}
}
else
{
switch(vec_names.size())
{
case 1:
vec_names.insert(vec_names.end(), 2, vec_names.back());
break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
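// e.g. {"sigmoid", "tanh"} becomes {"sigmoid", "tanh", "tanh"}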
vec_names.push_back(vec_names.back());
break;
default:
break;
}
}
std::for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
    if(map_actv_funcs.count(name) == 0)
    {
        MIGRAPHX_THROW("LSTM: activation function " + name + " not supported");
    }
});
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
    return map_actv_funcs.at(name);
});
float clip = 0.0f;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int input_forget = 0;
if(contains(attributes, "input_forget"))
{
input_forget = parse_value(attributes.at("input_forget")).at<int>();
}
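// the ONNX LSTM operator takes up to 8 inputs:
// X, W, R, B, sequence_lens, initial_h, initial_c, P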
// append undefined operators so that there are always 8 arguments
if(args.size() < 8)
{
    auto ins = prog.add_instruction(op::undefined{});
    args.insert(args.end(), 8 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget},
std::move(args));
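// hidden_states corresponds to ONNX output Y with shape
// [seq_length, num_directions, batch_size, hidden_size]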
// second output for last lstm output
auto last_output = prog.add_instruction(op::lstm_last_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
return {hidden_states, last_output, last_cell_output};
}
void parse_from(std::istream& is)
{
    onnx::ModelProto model;
...
@@ -668,5 +668,84 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
}
}
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
    // the actual rewrite is still to be added; for now we only
    // collect the inputs of the lstm instruction
}
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2,
const operation& actv_func3) const
{
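    // still a stub. The cell computation is expected to follow the
    // ONNX LSTM equations (f = actv_func1, g = actv_func2, h = actv_func3):
    //   it = f(Xt*Wi^T + Ht-1*Ri^T + Pi (.) Ct-1 + Wbi + Rbi)
    //   ft = f(Xt*Wf^T + Ht-1*Rf^T + Pf (.) Ct-1 + Wbf + Rbf)
    //   ct = g(Xt*Wc^T + Ht-1*Rc^T + Wbc + Rbc)
    //   Ct = ft (.) Ct-1 + it (.) ct
    //   ot = f(Xt*Wo^T + Ht-1*Ro^T + Po (.) Ct + Wbo + Rbo)
    //   Ht = ot (.) h(Ct)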
return {};
}
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
auto lstm_op = any_cast<op::lstm>(ins->get_operator());
    // before rewriting the lstm operator, we need to ensure
    // there are 3 actv funcs (6 for the bidirectional case),
    // even if the user did not specify any. If fewer are
    // present, apply the same algorithm used in parse_lstm
    const auto& actv_funcs = lstm_op.actv_funcs;
std::size_t num_actv_funcs = actv_funcs.size();
if(lstm_op.direction == op::lstm::bidirectional)
{
switch(num_actv_funcs)
{
case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{},
op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1:
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 2:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
case 3:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2),
actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2)};
case 4:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2),
actv_funcs.at(3), actv_funcs.at(3), actv_funcs.at(3)};
case 5:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2),
actv_funcs.at(3), actv_funcs.at(4), actv_funcs.at(4)};
default:
return actv_funcs;
}
}
else
{
switch(num_actv_funcs)
{
case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 1:
return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
case 2:
return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};
default:
return actv_funcs;
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx