Commit 0427ce2b authored by Shucai Xiao

clang format.

parent a2ea4ecd
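
The hunks below are whitespace-only formatting changes: "if (x)" becomes "if(x)", consecutive assignments get their "=" signs aligned, and statements longer than the column limit are wrapped. A minimal .clang-format sketch that would produce this behavior (illustrative; the repository's actual configuration may differ):

    # Hypothetical .clang-format sketch; the repo's real config may differ.
    BasedOnStyle: LLVM
    IndentWidth: 4
    ColumnLimit: 100                    # wrap long calls such as prog.add_instruction(...)
    SpaceBeforeParens: Never            # "if (x)" -> "if(x)"
    AlignConsecutiveAssignments: true   # align '=' in adjacent declarations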
@@ -1062,28 +1062,28 @@ struct rnn
         bidirectional,
     };
-    std::size_t hidden_size = 1;
-    operation actv_func = tanh{};
+    std::size_t hidden_size   = 1;
+    operation actv_func       = tanh{};
     rnn_direction_t direction = forward;
-    float clip = 0.0f;
+    float clip                = 0.0f;
     std::string name() const { return "rnn"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        auto in_dims = inputs[0].lens();
+        auto in_dims     = inputs[0].lens();
         auto hidden_dims = inputs[1].lens();
-        if (hidden_size != hidden_dims[1])
+        if(hidden_size != hidden_dims[1])
         {
             MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
         }
         std::size_t num_directions = 1;
-        if (direction == rnn_direction_t::bidirectional)
+        if(direction == rnn_direction_t::bidirectional)
         {
             num_directions = 2;
         }
-        if (num_directions != hidden_dims[0])
+        if(num_directions != hidden_dims[0])
         {
             MIGRAPHX_THROW("RNN: num_direction does not match the direction attribute");
         }
@@ -1096,7 +1096,6 @@ struct rnn
     }
 };
 } // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
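
The checks in compute_shape above encode the ONNX RNN shape contract: inputs[1] is the weight tensor, whose first dimension is the number of directions and whose second is the hidden size, and both must agree with the operator's attributes. A standalone sketch of the same validation (names and signature are illustrative, not migraphx API):

    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // Mirrors the checks in op::rnn::compute_shape above; wgt_lens is
    // assumed to be {num_directions, hidden_size, input_size}.
    void check_rnn_dims(std::size_t hidden_size,
                        bool bidirectional,
                        const std::vector<std::size_t>& wgt_lens)
    {
        if(hidden_size != wgt_lens[1])
            throw std::runtime_error("RNN: hidden size mismatch in attribute and input");
        std::size_t num_directions = bidirectional ? 2 : 1;
        if(num_directions != wgt_lens[0])
            throw std::runtime_error("RNN: num_direction does not match the direction attribute");
    }

    // check_rnn_dims(4, true, {2, 4, 8});  // ok: bidirectional, hidden_size 4
    // check_rnn_dims(4, true, {1, 4, 8});  // throws: direction says 2, weights say 1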
...
@@ -7,7 +7,6 @@
 #include <migraphx/operators.hpp>
 #include <migraphx/config.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -661,7 +661,7 @@ struct onnx_parser
         actv_func_map.insert(std::make_pair("relu", op::relu{}));
         actv_func_map.insert(std::make_pair("sigmoid", op::sigmoid{}));
-        if (actv_func_map.count(activation_func) == 0)
+        if(actv_func_map.count(activation_func) == 0)
         {
             MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
         }
@@ -690,7 +690,8 @@ struct onnx_parser
             clip = parse_value(attributes.at("clip")).at<float>();
         }
-        return prog.add_instruction(op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
+        return prog.add_instruction(
+            op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
     }
     void parse_from(std::istream& is)
...
@@ -18,22 +18,21 @@ void rewrite_rnn::apply(program& prog) const
         }
         // could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
         // the 5th one is undefined and ignored by protobuf, so
         // we need to process up to 5 inputs
         auto args = ins->inputs();
-        shape seq_shape = args[0]->get_shape();
-        shape wgt_shape = args[1]->get_shape();
+        shape seq_shape         = args[0]->get_shape();
+        shape wgt_shape         = args[1]->get_shape();
         std::size_t hidden_size = wgt_shape.lens()[1];
-        std::size_t batch_size = seq_shape.lens()[1];
-        shape::type_t type = seq_shape.type();
+        std::size_t batch_size  = seq_shape.lens()[1];
+        shape::type_t type      = seq_shape.type();
         migraphx::shape s{type, {batch_size, hidden_size}};
         std::vector<char> data(s.bytes(), 0);
-        auto rnn_op = any_cast<op::rnn>(ins->get_operator());
+        auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
         op::rnn::rnn_direction_t dicrt = rnn_op.direction;
         if(dicrt == op::rnn::rnn_direction_t::bidirectional)
         {
             std::vector<int64_t> perm{1, 0};
             // process input weight matrix
@@ -65,17 +64,19 @@ void rewrite_rnn::apply(program& prog) const
             long h_size = static_cast<long>(hidden_size);
             auto b_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
             b_forward = prog.insert_instruction(ins, op::squeeze{{0}}, b_forward);
-            auto wbf = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_forward);
-            auto rbf = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
-            auto bf = prog.insert_instruction(ins, op::add{}, wbf, rbf);
+            auto wbf = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_forward);
+            auto rbf =
+                prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
+            auto bf  = prog.insert_instruction(ins, op::add{}, wbf, rbf);
             bias_forward = prog.insert_instruction(ins, op::broadcast{1, s}, bf);
             // backward
             auto b_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
             b_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, b_reverse);
-            auto wbr = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_reverse);
-            auto rbr = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
-            auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
+            auto wbr = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_reverse);
+            auto rbr =
+                prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
+            auto br  = prog.insert_instruction(ins, op::add{}, wbr, rbr);
             bias_reverse = prog.insert_instruction(ins, op::broadcast{1, s}, br);
         }
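
The slices above pull the packed ONNX bias apart: elements [0, h_size) are Wb and [h_size, 2 * h_size) are Rb. Since the RNN cell always uses their sum, the rewrite adds them once per direction and broadcasts the result to the {batch_size, hidden_size} literal shape s. A standalone sketch of the same arithmetic (an assumed helper, not migraphx code):

    #include <cstddef>
    #include <vector>

    // Fold the packed ONNX bias [Wb, Rb] for one direction into Wb + Rb.
    std::vector<float> fold_rnn_bias(const std::vector<float>& packed, std::size_t hidden_size)
    {
        std::vector<float> b(hidden_size);
        for(std::size_t i = 0; i < hidden_size; i++)
            b[i] = packed[i] + packed[hidden_size + i]; // Wb[i] + Rb[i]
        return b; // then broadcast over the batch dimension, as op::broadcast{1, s} does above
    }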
@@ -144,9 +145,9 @@ void rewrite_rnn::apply(program& prog) const
             long h_size = static_cast<long>(hidden_size);
             auto bwr = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
             auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, bwr);
-            auto rb = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, bwr);
-            auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
-            bias = prog.insert_instruction(ins, op::broadcast{1, s}, b);
+            auto rb = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, bwr);
+            auto b  = prog.insert_instruction(ins, op::add{}, wb, rb);
+            bias    = prog.insert_instruction(ins, op::broadcast{1, s}, b);
         }
         // process initial hidden state
@@ -159,7 +160,8 @@ void rewrite_rnn::apply(program& prog) const
         {
             ih = prog.add_literal(migraphx::literal{s, data});
         }
-        auto ret = rnn_oper(is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);
+        auto ret = rnn_oper(
+            is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);
         // add the dimension of num_direction
         prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
@@ -168,14 +170,14 @@ void rewrite_rnn::apply(program& prog) const
 }
 std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
-                                                  program& prog,
-                                                  instruction_ref ins,
-                                                  instruction_ref input,
-                                                  instruction_ref wx,
-                                                  instruction_ref wh,
-                                                  instruction_ref ih,
-                                                  instruction_ref bias,
-                                                  operation& actv_func) const
+                                                   program& prog,
+                                                   instruction_ref ins,
+                                                   instruction_ref input,
+                                                   instruction_ref wx,
+                                                   instruction_ref wh,
+                                                   instruction_ref ih,
+                                                   instruction_ref bias,
+                                                   operation& actv_func) const
 {
     instruction_ref hidden_out, final_out;
     migraphx::shape input_shape = input->get_shape();
@@ -183,8 +185,8 @@ std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
     long seq_index = is_forward ? 0 : seq_len - 1;
     for(std::size_t i = 0; i < seq_len; i++)
     {
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
-        xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
+        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
         auto x_w = prog.insert_instruction(ins, op::dot{}, xt, wx);
         auto h_r = prog.insert_instruction(ins, op::dot{}, ih, wh);
         auto x_h = prog.insert_instruction(ins, op::add{}, x_w, h_r);
@@ -208,14 +210,14 @@ std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
         if(is_forward)
         {
             hidden_out = (seq_index == 0)
-                         ? output
-                         : prog.insert_instruction(ins, op::concat{0}, hidden_out, output);
+                             ? output
+                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, output);
         }
         else
         {
             hidden_out = (seq_index == seq_len - 1)
-                         ? output
-                         : prog.insert_instruction(ins, op::concat{0}, output, hidden_out);
+                             ? output
+                             : prog.insert_instruction(ins, op::concat{0}, output, hidden_out);
         }
         seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
     }
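
For reference, each loop iteration above computes one step of the ONNX RNN recurrence, with wx and wh being the already-transposed input and recurrence weights prepared earlier:

    h_t = f(x_t * wx + h_(t-1) * wh + bias)

where f is actv_func (tanh by default); the elided lines between the dot products and the concat presumably add the folded bias, apply optional clipping, and run the activation. The forward pass starts at seq_index 0 and appends each step's output to hidden_out, while the reverse pass starts at seq_len - 1 and prepends, so hidden_out always ends up ordered by ascending time step.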
...