Commit 0427ce2b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format.

parent a2ea4ecd
...@@ -1072,18 +1072,18 @@ struct rnn ...@@ -1072,18 +1072,18 @@ struct rnn
{ {
auto in_dims = inputs[0].lens(); auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[1].lens(); auto hidden_dims = inputs[1].lens();
if (hidden_size != hidden_dims[1]) if(hidden_size != hidden_dims[1])
{ {
MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input"); MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
} }
std::size_t num_directions = 1; std::size_t num_directions = 1;
if (direction == rnn_direction_t::bidirectional) if(direction == rnn_direction_t::bidirectional)
{ {
num_directions = 2; num_directions = 2;
} }
if (num_directions != hidden_dims[0]) if(num_directions != hidden_dims[0])
{ {
MIGRAPHX_THROW("RNN: num_direction does not match the direction attribute"); MIGRAPHX_THROW("RNN: num_direction does not match the direction attribute");
} }
...@@ -1096,7 +1096,6 @@ struct rnn ...@@ -1096,7 +1096,6 @@ struct rnn
} }
}; };
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -661,7 +661,7 @@ struct onnx_parser ...@@ -661,7 +661,7 @@ struct onnx_parser
actv_func_map.insert(std::make_pair("relu", op::relu{})); actv_func_map.insert(std::make_pair("relu", op::relu{}));
actv_func_map.insert(std::make_pair("sigmoid", op::sigmoid{})); actv_func_map.insert(std::make_pair("sigmoid", op::sigmoid{}));
if (actv_func_map.count(activation_func) == 0) if(actv_func_map.count(activation_func) == 0)
{ {
MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported"); MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
} }
...@@ -690,7 +690,8 @@ struct onnx_parser ...@@ -690,7 +690,8 @@ struct onnx_parser
clip = parse_value(attributes.at("clip")).at<float>(); clip = parse_value(attributes.at("clip")).at<float>();
} }
return prog.add_instruction(op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args)); return prog.add_instruction(
op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
} }
void parse_from(std::istream& is) void parse_from(std::istream& is)
......
...@@ -30,7 +30,6 @@ void rewrite_rnn::apply(program& prog) const ...@@ -30,7 +30,6 @@ void rewrite_rnn::apply(program& prog) const
migraphx::shape s{type, {batch_size, hidden_size}}; migraphx::shape s{type, {batch_size, hidden_size}};
std::vector<char> data(s.bytes(), 0); std::vector<char> data(s.bytes(), 0);
auto rnn_op = any_cast<op::rnn>(ins->get_operator()); auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction; op::rnn::rnn_direction_t dicrt = rnn_op.direction;
if(dicrt == op::rnn::rnn_direction_t::bidirectional) if(dicrt == op::rnn::rnn_direction_t::bidirectional)
...@@ -66,7 +65,8 @@ void rewrite_rnn::apply(program& prog) const ...@@ -66,7 +65,8 @@ void rewrite_rnn::apply(program& prog) const
auto b_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); auto b_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
b_forward = prog.insert_instruction(ins, op::squeeze{{0}}, b_forward); b_forward = prog.insert_instruction(ins, op::squeeze{{0}}, b_forward);
auto wbf = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_forward); auto wbf = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_forward);
auto rbf = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward); auto rbf =
prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
auto bf = prog.insert_instruction(ins, op::add{}, wbf, rbf); auto bf = prog.insert_instruction(ins, op::add{}, wbf, rbf);
bias_forward = prog.insert_instruction(ins, op::broadcast{1, s}, bf); bias_forward = prog.insert_instruction(ins, op::broadcast{1, s}, bf);
...@@ -74,7 +74,8 @@ void rewrite_rnn::apply(program& prog) const ...@@ -74,7 +74,8 @@ void rewrite_rnn::apply(program& prog) const
auto b_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); auto b_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
b_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, b_reverse); b_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, b_reverse);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_reverse); auto wbr = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_reverse);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse); auto rbr =
prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr); auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
bias_reverse = prog.insert_instruction(ins, op::broadcast{1, s}, br); bias_reverse = prog.insert_instruction(ins, op::broadcast{1, s}, br);
} }
...@@ -159,7 +160,8 @@ void rewrite_rnn::apply(program& prog) const ...@@ -159,7 +160,8 @@ void rewrite_rnn::apply(program& prog) const
{ {
ih = prog.add_literal(migraphx::literal{s, data}); ih = prog.add_literal(migraphx::literal{s, data});
} }
auto ret = rnn_oper(is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func); auto ret = rnn_oper(
is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);
// add the dimension of num_direction // add the dimension of num_direction
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]); prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment