Unverified Commit 15eb1987 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #273 from ROCmSoftwarePlatform/rnn_optimization

Rnn optimization
parents f93eeca3 67c6e634
Pipeline #672 failed with stages
in 0 seconds
......@@ -28,8 +28,10 @@ struct binary : op_name<Derived>
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto s1 = args[0].get_shape();
auto s2 = args[1].get_shape();
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().packed() and input2.get_shape().packed())
if(s1 == s2 and input1.get_shape().packed() and input2.get_shape().packed())
{
std::transform(input1.begin(),
input1.end(),
......
......@@ -52,6 +52,8 @@ inline std::string transform_string(std::string s, F f)
inline std::string to_upper(std::string s) { return transform_string(std::move(s), ::toupper); }
inline std::string to_lower(std::string s) { return transform_string(std::move(s), ::tolower); }
inline bool starts_with(const std::string& value, const std::string& prefix)
{
if(prefix.size() > value.size())
......
......@@ -100,6 +100,7 @@ struct onnx_parser
void init_actv_func()
{
// Support name format of all lower case or the first letter capital
map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
......@@ -871,7 +872,9 @@ struct onnx_parser
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
......@@ -962,7 +965,9 @@ struct onnx_parser
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}
// need 4 activation functions
......@@ -1089,7 +1094,9 @@ struct onnx_parser
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::copy(names.begin(), names.end(), vec_names.begin());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}
// need 6 activation functions for bidirectional directions
......
......@@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
......@@ -101,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
t = as.type_enum();
n = sizeof(as());
}
});
if(n == 0)
{
MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format);
}
auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
return n > 0 ? i / n : 0;
......
......@@ -204,17 +204,19 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
auto sih_lens = sih->get_shape().lens();
// bias
instruction_ref bb{};
if(bias != prog.end())
{
long hs = r->get_shape().lens()[2];
long hs = static_cast<long>(r->get_shape().lens()[2]);
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape().lens()}, b);
auto wrb = prog.insert_instruction(ins, op::add{}, wb, rb);
bb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
}
instruction_ref hidden_out = prog.end();
......@@ -228,20 +230,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
instruction_ref ht;
if(bias != prog.end())
{
ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
}
else
{
ht = xt_ht;
xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
}
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
// apply activation function
ht = prog.insert_instruction(ins, actv_func, ht);
sih = ht;
auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
sih = ht;
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions
......@@ -485,62 +482,41 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
long hs = static_cast<long>(r_shape.lens()[2]);
migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<int> data(s.elements(), 1);
std::vector<float> data(s.elements(), 1.0f);
auto l1 = prog.add_literal(migraphx::literal{s, data});
// weight matrix
// w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
// r slide to two part, zr and h
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rzr = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
// initial states
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
size_t bs = ih->get_shape().lens()[1];
// bias
instruction_ref brcst_bz{};
instruction_ref brcst_br{};
instruction_ref brcst_wbh{};
instruction_ref brcst_rbh{};
instruction_ref brcst_bh{};
instruction_ref bwb{};
instruction_ref brb_zr{};
instruction_ref brb_h{};
if(bias != prog.end())
{
auto broadcast_lens = sih->get_shape().lens();
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
brcst_bz = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
brcst_br = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, br);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bh);
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
auto rb_h = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brb_zr = prog.insert_instruction(
ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
}
for(long i = 0; i < seq_len; i++)
......@@ -549,56 +525,58 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
auto xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
if(bias != prog.end())
{
xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
xt_w = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
}
auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
if(bias != prog.end())
{
xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
}
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);
auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
instruction_ref xht_h;
auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
instruction_ref hr_h{};
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
}
else
{
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
instruction_ref ht1_rh{};
if(bias != prog.end())
{
ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
}
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
else
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
}
hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
}
auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
......@@ -913,35 +891,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]);
auto bs = ih->get_shape().lens()[1];
std::vector<int64_t> perm{1, 0};
// w matrix
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
auto wo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
auto wf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
// w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
auto wc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sw);
auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
// r matrix
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto ri = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
auto ro = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
auto rf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
auto rc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sr);
auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
// r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
......@@ -951,40 +910,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ic_lens = sic->get_shape().lens();
// bias
instruction_ref bi_brcst{};
instruction_ref bo_brcst{};
instruction_ref bf_brcst{};
instruction_ref bc_brcst{};
instruction_ref wrb{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi);
bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bi);
auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho);
bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bo);
auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf);
bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bf);
auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc);
bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bc);
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
wrb = prog.insert_instruction(
ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
}
// peep hole
instruction_ref pphi_brcst{};
instruction_ref ppho_brcst{};
instruction_ref pphf_brcst{};
if(pph != prog.end())
{
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
......@@ -1004,44 +946,31 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
if(pph != prog.end())
{
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
}
auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
auto xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
if(bias != prog.end())
{
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
}
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
auto ft_before_actv =
prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
auto ct_before_actv =
prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
if(pph != prog.end())
{
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
}
if(bias != prog.end())
{
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
}
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
auto xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
auto ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
if(bias != prog.end())
{
ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
}
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct
......@@ -1050,19 +979,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
last_cell_output = cellt;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
if(pph != prog.end())
{
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
}
if(bias != prog.end())
{
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
}
auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
// Ht = ot (.) h(Ct)
......
......@@ -2666,10 +2666,11 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto und = p.add_instruction(migraphx::op::undefined{});
auto output = p.add_instruction(
migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
r,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment