Commit 0427ce2b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format.

parent a2ea4ecd
......@@ -1072,18 +1072,18 @@ struct rnn
{
auto in_dims = inputs[0].lens();
auto hidden_dims = inputs[1].lens();
if (hidden_size != hidden_dims[1])
if(hidden_size != hidden_dims[1])
{
MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
}
std::size_t num_directions = 1;
if (direction == rnn_direction_t::bidirectional)
if(direction == rnn_direction_t::bidirectional)
{
num_directions = 2;
}
if (num_directions != hidden_dims[0])
if(num_directions != hidden_dims[0])
{
MIGRAPHX_THROW("RNN: num_direction does not match the direction attribute");
}
......@@ -1096,7 +1096,6 @@ struct rnn
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -7,7 +7,6 @@
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -661,7 +661,7 @@ struct onnx_parser
actv_func_map.insert(std::make_pair("relu", op::relu{}));
actv_func_map.insert(std::make_pair("sigmoid", op::sigmoid{}));
if (actv_func_map.count(activation_func) == 0)
if(actv_func_map.count(activation_func) == 0)
{
MIGRAPHX_THROW("RNN: activation function " + activation_func + " not supported");
}
......@@ -690,7 +690,8 @@ struct onnx_parser
clip = parse_value(attributes.at("clip")).at<float>();
}
return prog.add_instruction(op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
return prog.add_instruction(
op::rnn{hidden_size, actv_func_map[activation_func], dirct, clip}, std::move(args));
}
void parse_from(std::istream& is)
......
......@@ -30,7 +30,6 @@ void rewrite_rnn::apply(program& prog) const
migraphx::shape s{type, {batch_size, hidden_size}};
std::vector<char> data(s.bytes(), 0);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction;
if(dicrt == op::rnn::rnn_direction_t::bidirectional)
......@@ -66,7 +65,8 @@ void rewrite_rnn::apply(program& prog) const
auto b_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
b_forward = prog.insert_instruction(ins, op::squeeze{{0}}, b_forward);
auto wbf = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_forward);
auto rbf = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
auto rbf =
prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_forward);
auto bf = prog.insert_instruction(ins, op::add{}, wbf, rbf);
bias_forward = prog.insert_instruction(ins, op::broadcast{1, s}, bf);
......@@ -74,7 +74,8 @@ void rewrite_rnn::apply(program& prog) const
auto b_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
b_reverse = prog.insert_instruction(ins, op::squeeze{{0}}, b_reverse);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, b_reverse);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
auto rbr =
prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, b_reverse);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
bias_reverse = prog.insert_instruction(ins, op::broadcast{1, s}, br);
}
......@@ -159,7 +160,8 @@ void rewrite_rnn::apply(program& prog) const
{
ih = prog.add_literal(migraphx::literal{s, data});
}
auto ret = rnn_oper(is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);
auto ret = rnn_oper(
is_forward, prog, ins, args[0], trans_xw, trans_hw, ih, bias, rnn_op.actv_func);
// add the dimension of num_direction
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment