"src/vscode:/vscode.git/clone" did not exist on "d628942b6656176d4d6b3c16405e4f640d62cf29"
Commit 1596cf1f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 6d0742b6
...@@ -741,12 +741,12 @@ struct onnx_parser ...@@ -741,12 +741,12 @@ struct onnx_parser
act_funcs[1] = attributes.at("activations").strings(1); act_funcs[1] = attributes.at("activations").strings(1);
} }
if (act_funcs.size() != 2) if(act_funcs.size() != 2)
{ {
MIGRAPHX_THROW("GRU: wrong activation function attribute"); MIGRAPHX_THROW("GRU: wrong activation function attribute");
} }
for (std::size_t i = 0; i < act_funcs.size(); ++i) for(std::size_t i = 0; i < act_funcs.size(); ++i)
{ {
if(actv_funcs.count(act_funcs.at(i)) == 0) if(actv_funcs.count(act_funcs.at(i)) == 0)
{ {
...@@ -762,14 +762,17 @@ struct onnx_parser ...@@ -762,14 +762,17 @@ struct onnx_parser
} }
int linear_before_reset = 0; int linear_before_reset = 0;
if (contains(attributes, "linear_before_reset")) if(contains(attributes, "linear_before_reset"))
{ {
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>(); linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
} }
return prog.add_instruction(op::gru{hidden_size, return prog.add_instruction(
op::gru{hidden_size,
{actv_funcs[act_funcs.at(0)], actv_funcs[act_funcs.at(1)]}, {actv_funcs[act_funcs.at(0)], actv_funcs[act_funcs.at(1)]},
dirct, clip, linear_before_reset}, dirct,
clip,
linear_before_reset},
std::move(args)); std::move(args));
} }
......
...@@ -104,7 +104,8 @@ void rewrite_gru::apply(program& prog) const ...@@ -104,7 +104,8 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(2), gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(3)); gru_op.actv_funcs.at(3));
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]); // auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
// ret_reverse[1]);
// add the dimension of num_direction // add the dimension of num_direction
ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]); ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
...@@ -138,8 +139,17 @@ void rewrite_gru::apply(program& prog) const ...@@ -138,8 +139,17 @@ void rewrite_gru::apply(program& prog) const
ih = prog.add_literal(migraphx::literal{s, data}); ih = prog.add_literal(migraphx::literal{s, data});
} }
auto ret = gru_oper( auto ret = gru_oper(is_forward,
is_forward, prog, ins, args[0], w, r, ih, bias, gru_op.linear_before_reset, gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(1)); prog,
ins,
args[0],
w,
r,
ih,
bias,
gru_op.linear_before_reset,
gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(1));
// add the dimension of num_direction // add the dimension of num_direction
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]); prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
...@@ -185,16 +195,16 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -185,16 +195,16 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
// bias // bias
instruction_ref br_bz, br_br, br_wbh, br_rbh, br_bh; instruction_ref br_bz, br_br, br_wbh, br_rbh, br_bh;
if (bias != prog.end()) if(bias != prog.end())
{ {
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, bias); auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, bias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2*hs}}, bias); auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, bias);
wbh = prog.insert_instruction(ins, op::slice{{0}, {2*hs}, {3*hs}}, bias); wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, bias);
br_wbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, wbh); br_wbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3*hs}, {4*hs}}, bias); auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, bias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4*hs}, {5*hs}}, bias); auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, bias);
rbh = prog.insert_instruction(ins, op::slice{{0}, {5*hs}, {6*hs}}, bias); rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, bias);
br_rbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, rbh); br_rbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz); auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
...@@ -212,7 +222,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -212,7 +222,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
auto xwzt = prog.insert_instruction(ins, op::dot{}, xt, twz); auto xwzt = prog.insert_instruction(ins, op::dot{}, xt, twz);
auto hrzt = prog.insert_instruction(ins, op::dot{}, ih, trz); auto hrzt = prog.insert_instruction(ins, op::dot{}, ih, trz);
auto xwhr_zt = prog.insert_instruction(ins, op::add{}, xwzt, hrzt); auto xwhr_zt = prog.insert_instruction(ins, op::add{}, xwzt, hrzt);
if (bias != prog.end()) if(bias != prog.end())
{ {
xwhr_zt = prog.insert_instruction(ins, op::add{}, xwhr_zt, br_bz); xwhr_zt = prog.insert_instruction(ins, op::add{}, xwhr_zt, br_bz);
} }
...@@ -222,21 +232,21 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -222,21 +232,21 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
auto xwrt = prog.insert_instruction(ins, op::dot{}, xt, twr); auto xwrt = prog.insert_instruction(ins, op::dot{}, xt, twr);
auto hrrt = prog.insert_instruction(ins, op::dot{}, xt, trr); auto hrrt = prog.insert_instruction(ins, op::dot{}, xt, trr);
auto xwhr_rt = prog.insert_instruction(ins, op::add{}, xwrt, hrrt); auto xwhr_rt = prog.insert_instruction(ins, op::add{}, xwrt, hrrt);
if (bias != prog.end()) if(bias != prog.end())
{ {
xwhr_rt = prog.insert_instruction(ins, op::add{}, xwhr_rt, br_br); xwhr_rt = prog.insert_instruction(ins, op::add{}, xwhr_rt, br_br);
} }
auto rt = prog.insert_instruction(ins, actv_func1, xwhr_rt); auto rt = prog.insert_instruction(ins, actv_func1, xwhr_rt);
instruction_ref xwhh_rt; instruction_ref xwhh_rt;
if (linear_before_reset == 0) if(linear_before_reset == 0)
{ {
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh); auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
auto rt_ht = prog.insert_instruction(ins, op::mul{}, rt, ih); auto rt_ht = prog.insert_instruction(ins, op::mul{}, rt, ih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht, trh); auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht, trh);
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rt); xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rt);
if (bias != prog.end()) if(bias != prog.end())
{ {
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_bh); xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_bh);
} }
...@@ -246,13 +256,13 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -246,13 +256,13 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh); auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, twh); auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, twh);
if (bias != prog.end()) if(bias != prog.end())
{ {
ih_rht = prog.insert_instruction(ins, op::add{}, ih_rht, br_rbh); ih_rht = prog.insert_instruction(ins, op::add{}, ih_rht, br_rbh);
} }
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ih_rht); auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ih_rht);
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rh); xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
if (bias != prog.end()) if(bias != prog.end())
{ {
xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_wbh); xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_wbh);
} }
...@@ -268,9 +278,8 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -268,9 +278,8 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
if(is_forward) if(is_forward)
{ {
hidden_out = (seq_index == 0) hidden_out =
? ih (seq_index == 0) ? ih : prog.insert_instruction(ins, op::concat{0}, hidden_out, ih);
: prog.insert_instruction(ins, op::concat{0}, hidden_out, ih);
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment