Commit 84ecee26 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 249f5024
...@@ -104,7 +104,7 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -104,7 +104,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
} }
}); });
if (n == 0) if(n == 0)
{ {
MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format); MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format);
} }
...@@ -140,7 +140,7 @@ PYBIND11_MODULE(migraphx, m) ...@@ -140,7 +140,7 @@ PYBIND11_MODULE(migraphx, m)
.def("__init__", .def("__init__",
[](migraphx::argument& x, py::buffer b) { [](migraphx::argument& x, py::buffer b) {
py::buffer_info info = b.request(); py::buffer_info info = b.request();
auto s = to_shape(info); auto s = to_shape(info);
new(&x) migraphx::argument(to_shape(info), info.ptr); new(&x) migraphx::argument(to_shape(info), info.ptr);
}) })
.def("get_shape", &migraphx::argument::get_shape) .def("get_shape", &migraphx::argument::get_shape)
......
...@@ -903,17 +903,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -903,17 +903,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
migraphx::shape r_shape = r->get_shape(); migraphx::shape r_shape = r->get_shape();
long seq_len = static_cast<long>(seq_shape.lens()[0]); long seq_len = static_cast<long>(seq_shape.lens()[0]);
long hs = static_cast<long>(r_shape.lens()[2]); long hs = static_cast<long>(r_shape.lens()[2]);
auto bs = ih->get_shape().lens()[1]; auto bs = ih->get_shape().lens()[1];
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose // w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw); auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
// r matrix, squeeze and transpose // r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr); auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
// initial hidden state // initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
...@@ -931,8 +931,10 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -931,8 +931,10 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias); auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias); auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
wb = prog.insert_instruction(ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wb); wb = prog.insert_instruction(
rb = prog.insert_instruction(ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_rb); ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wb);
rb = prog.insert_instruction(
ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_rb);
} }
// peep hole // peep hole
...@@ -959,23 +961,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -959,23 +961,25 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref xt_sih{}; instruction_ref xt_sih{};
if (bias != prog.end()) if(bias != prog.end())
{ {
auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw, wb); auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw, wb);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr, rb); auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr, rb);
xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr); xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
} }
else else
{ {
auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw); auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr); auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr); xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
} }
auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih); auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih); auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
auto ft_before_actv = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih); auto ft_before_actv =
auto ct_before_actv = prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih); prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
auto ct_before_actv =
prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
if(pph != prog.end()) if(pph != prog.end())
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment