Commit 60d8b962 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix errors for rnn optimization.

parent dd26f1aa
......@@ -513,19 +513,25 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
instruction_ref wbz{}, rbz{};
instruction_ref wbr{}, rbr{};
instruction_ref wbh{}, rbh{};
instruction_ref bwbz{}, brbz{};
instruction_ref bwbr{}, brbr{};
instruction_ref bwbh{}, brbh{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
bwbz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbz);
bwbr = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbr);
bwbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);
rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brbz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbz);
brbr = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbr);
brbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);
}
for(long i = 0; i < seq_len; i++)
......@@ -539,30 +545,30 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
if(bias != prog.end())
{
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz, wbz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz, rbz);
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz, bwbz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz, brbz);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr, wbr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr, rbr);
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr, bwbr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr, brbr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h{};
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, wbh);
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, bwbh);
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh, rbh);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh, brbh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh, rbh);
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh, brbh);
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
}
......
......@@ -566,7 +566,7 @@ TEST_CASE(gemm_test)
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto alpha = 2.f;
auto beta = 2.0f;
auto beta = 1.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, t0, t1);
auto prog = migraphx::parse_onnx("gemm_test.onnx");
......
......@@ -554,175 +554,6 @@ TEST_CASE(gemm)
}
}
// 3 input arguments
TEST_CASE(gemm)
{
    // Shape checks for migraphx::op::dot when it is given three inputs:
    // two float operands (A, B) plus a third float operand C whose shape varies.
    // Accepted combinations go through expect_shape together with the expected
    // output shape; rejected combinations go through throws_shape.
    // NOTE(review): both helpers are defined outside this block; their exact
    // semantics are assumed from their names - confirm against the test harness.
    using shape_t  = migraphx::shape;
    const auto f32 = shape_t::float_type;

    // Thin wrappers so every case below reads as a single line.
    const auto accepted =
        [](const shape_t& expected, const shape_t& a, const shape_t& b, const shape_t& c) {
            expect_shape(expected, migraphx::op::dot{}, a, b, c);
        };
    const auto rejected = [](const shape_t& a, const shape_t& b, const shape_t& c) {
        throws_shape(migraphx::op::dot{}, a, b, c);
    };

    // ---- rank-2 operands: A is {4, 5}, B is {5, 8}, expected output {4, 8} ----
    const shape_t a2d{f32, {4, 5}};
    const shape_t b2d{f32, {5, 8}};
    const shape_t out2d{f32, {4, 8}};

    accepted(out2d, a2d, b2d, shape_t{f32});         // C has no dimensions
    accepted(out2d, a2d, b2d, shape_t{f32, {1}});    // C is {1}
    accepted(out2d, a2d, b2d, shape_t{f32, {1, 1}}); // C is {1, 1}
    accepted(out2d, a2d, b2d, shape_t{f32, {8}});    // C is {8}
    accepted(out2d, a2d, b2d, shape_t{f32, {4, 1}}); // C is {4, 1}
    rejected(a2d, b2d, shape_t{f32, {4}});           // C is {4}
    accepted(out2d, a2d, b2d, shape_t{f32, {4, 8}}); // C equals the output shape
    accepted(out2d, a2d, b2d, shape_t{f32, {1, 8}}); // C is {1, 8}
    rejected(a2d, b2d, shape_t{f32, {4, 7}});        // C is {4, 7}
    rejected(a2d, b2d, shape_t{f32, {3, 8}});        // C is {3, 8}
    rejected(a2d, b2d, shape_t{f32, {3, 7}});        // C is {3, 7}

    // ---- rank-3 operands with a leading dimension of 1 ----
    const shape_t a3d{f32, {1, 4, 5}};
    const shape_t b3d{f32, {1, 5, 8}};

    // Only a C that equals the {1, 4, 8} output shape is accepted here.
    accepted(shape_t{f32, {1, 4, 8}}, a3d, b3d, shape_t{f32, {1, 4, 8}});
    rejected(a3d, b3d, shape_t{f32, {4, 8}});    // C is rank-2
    rejected(a3d, b3d, shape_t{f32});            // C has no dimensions
    rejected(a3d, b3d, shape_t{f32, {1, 8}});    // C is {1, 8}
    rejected(a3d, b3d, shape_t{f32, {1, 4, 1}}); // C is {1, 4, 1}

    // ---- rank-3 operands where A has a leading dimension of 2 ----
    const shape_t a_lead2{f32, {2, 4, 5}};

    rejected(a_lead2, b3d, shape_t{f32, {1, 4, 8}});
    // NOTE(review): this call repeats the previous one exactly, as in the
    // original test; a different combination may have been intended.
    rejected(a_lead2, b3d, shape_t{f32, {1, 4, 8}});
    // A {2, 4, 5} with B {2, 5, 8} and a C that has no dimensions.
    rejected(a_lead2, shape_t{f32, {2, 5, 8}}, shape_t{f32});
}
TEST_CASE(rnn)
{
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment