Commit c6f80ac7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 18fc2362
...@@ -494,20 +494,20 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -494,20 +494,20 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
// w matrix squeeze to 2-dim and do a transpose // w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw); auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
// r slide to two part, zr and h // r slide to two part, zr and h
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rzr = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr); auto rzr = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr); auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr); auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh); auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
// initial states // initial states
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
size_t bs = ih->get_shape().lens()[1]; size_t bs = ih->get_shape().lens()[1];
// bias // bias
instruction_ref bwb{}; instruction_ref bwb{};
...@@ -516,13 +516,14 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -516,13 +516,14 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias); auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb); bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias); auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
auto rb_h = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); auto rb_h = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brb_zr = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr); brb_zr = prog.insert_instruction(
brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h); ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
} }
for(long i = 0; i < seq_len; i++) for(long i = 0; i < seq_len; i++)
...@@ -535,12 +536,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -535,12 +536,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
instruction_ref ih1_rzr{}; instruction_ref ih1_rzr{};
if(bias != prog.end()) if(bias != prog.end())
{ {
xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw, bwb); xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw, bwb);
ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr, brb_zr); ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr, brb_zr);
} }
else else
{ {
xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw); xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr); ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
} }
...@@ -552,42 +553,42 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -552,42 +553,42 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr); auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z); auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z); auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r); auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r); auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
instruction_ref hr_h{}; instruction_ref hr_h{};
if(linear_before_reset == 0) if(linear_before_reset == 0)
{ {
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih); auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
if (bias != prog.end()) if(bias != prog.end())
{ {
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h); hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
} }
else else
{ {
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh); hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
} }
} }
else else
{ {
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
instruction_ref ht1_rh{}; instruction_ref ht1_rh{};
if (bias != prog.end()) if(bias != prog.end())
{ {
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h); ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
} }
else else
{ {
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh); ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
} }
hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh); hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
} }
auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h); auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h); auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1 // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt); auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
......
...@@ -2650,10 +2650,11 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last> ...@@ -2650,10 +2650,11 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction( auto output = p.add_instruction(
migraphx::op::lstm{hidden_size, migraphx::op::lstm{
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, hidden_size,
migraphx::op::rnn_direction::forward, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
clip}, migraphx::op::rnn_direction::forward,
clip},
seq, seq,
w, w,
r, r,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment