Commit 249f5024 authored by Shucai Xiao
Browse files

Optimize the LSTM operator rewrite.

parent a5941134
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
...@@ -101,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -101,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
t = as.type_enum(); t = as.type_enum();
n = sizeof(as()); n = sizeof(as());
} }
}); });
if (n == 0)
{
MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format);
}
auto strides = info.strides; auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t { std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
return n > 0 ? i / n : 0; return n > 0 ? i / n : 0;
...@@ -134,6 +140,7 @@ PYBIND11_MODULE(migraphx, m) ...@@ -134,6 +140,7 @@ PYBIND11_MODULE(migraphx, m)
.def("__init__", .def("__init__",
[](migraphx::argument& x, py::buffer b) { [](migraphx::argument& x, py::buffer b) {
py::buffer_info info = b.request(); py::buffer_info info = b.request();
auto s = to_shape(info);
new(&x) migraphx::argument(to_shape(info), info.ptr); new(&x) migraphx::argument(to_shape(info), info.ptr);
}) })
.def("get_shape", &migraphx::argument::get_shape) .def("get_shape", &migraphx::argument::get_shape)
......
...@@ -903,35 +903,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -903,35 +903,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
migraphx::shape r_shape = r->get_shape(); migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]); long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]); long hs = static_cast<long>(r_shape.lens()[2]);
auto bs = ih->get_shape().lens()[1];
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
// w matrix // w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto wi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw); auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
auto wo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw); // r matrix, squeeze and transpose
auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
auto wf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
auto wc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sw);
auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
// r matrix
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto ri = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr); auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
auto ro = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
auto rf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
auto rc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sr);
auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
// initial hidden state // initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
...@@ -941,37 +922,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -941,37 +922,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ic_lens = sic->get_shape().lens(); auto ic_lens = sic->get_shape().lens();
// bias // bias
instruction_ref wbi_brcst{}; instruction_ref wb{};
instruction_ref rbi_brcst{}; instruction_ref rb{};
instruction_ref wbo_brcst{};
instruction_ref rbo_brcst{};
instruction_ref wbf_brcst{};
instruction_ref rbf_brcst{};
instruction_ref wbc_brcst{};
instruction_ref rbc_brcst{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
auto rbi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias); auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
wbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, wbi);
rbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, rbi);
auto wbo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto rbo = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
wbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, wbo);
rbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, rbo);
auto wbf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias); wb = prog.insert_instruction(ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wb);
auto rbf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias); rb = prog.insert_instruction(ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_rb);
wbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, wbf);
rbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, rbf);
auto wbc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
wbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, wbc);
rbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, rbc);
} }
// peep hole // peep hole
...@@ -997,30 +958,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -997,30 +958,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref xt_wi{}; instruction_ref xt_sih{};
instruction_ref ht_ri{}; if (bias != prog.end())
instruction_ref xt_wf{};
instruction_ref ht_rf{};
if(bias != prog.end())
{ {
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw, wb);
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi, wbi_brcst); auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr, rb);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri, rbi_brcst); xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf, wbf_brcst);
ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf, rbf_brcst);
} }
else else
{ {
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi); auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri); auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
} }
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf); auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
auto ft_before_actv = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
auto ct_before_actv = prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
if(pph != prog.end()) if(pph != prog.end())
{ {
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic); auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
...@@ -1031,21 +987,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1031,21 +987,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
} }
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv); auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv); auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
instruction_ref xt_wc{};
instruction_ref ht_rc{};
if(bias != prog.end())
{
xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc, wbc_brcst);
ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc, rbc_brcst);
}
else
{
xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
}
auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv); auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct // equation Ct = ft (.) Ct-1 + it (.) ct
...@@ -1054,20 +995,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1054,20 +995,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct); auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
last_cell_output = cellt; last_cell_output = cellt;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
instruction_ref xt_wo{};
instruction_ref ht_ro{};
if(bias != prog.end())
{
xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo, wbo_brcst);
ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro, rbo_brcst);
}
else
{
xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
}
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
if(pph != prog.end()) if(pph != prog.end())
{ {
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt); auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment