Commit 904a913b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

first version that gru operator generate correct output.

parents 85e8901b 341974b6
...@@ -647,15 +647,16 @@ struct gather ...@@ -647,15 +647,16 @@ struct gather
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
if(axis >= lens.size() || axis < -lens.size()) int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
{ {
MIGRAPHX_THROW("Gather, axis is out of range."); MIGRAPHX_THROW("Gather: axis is out of range.");
} }
// negative value means counting dimensions from back // negative axis means counting dimensions from back
if(axis < 0) if(axis < 0)
{ {
axis += lens.size(); axis += n_dim;
} }
auto type = inputs[0].type(); auto type = inputs[0].type();
......
...@@ -173,8 +173,9 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -173,8 +173,9 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
long hs = static_cast<long>(r->get_shape().lens()[1]); long hs = static_cast<long>(r->get_shape().lens()[1]);
long seq_index = is_forward ? 0 : seq_len - 1; long seq_index = is_forward ? 0 : seq_len - 1;
migraphx::shape s(input->get_shape().type(), {1}); migraphx::shape s(input->get_shape().type(), {input->get_shape().lens()[1], static_cast<std::size_t>(hs)});
auto l1 = prog.add_literal(migraphx::literal{s, {1}}); std::vector<int> data(s.elements(), 1);
auto l1 = prog.add_literal(migraphx::literal{s, data});
// weight matrix // weight matrix
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
...@@ -210,7 +211,8 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -210,7 +211,8 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
br_bz = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bz); br_bz = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr); auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
br_br = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, br); br_br = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, br);
br_bh = prog.insert_instruction(ins, op::add{}, br_wbh, br_rbh); auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
br_bh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bh);
} }
for(long i = 0; i < seq_len; i++) for(long i = 0; i < seq_len; i++)
...@@ -229,7 +231,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -229,7 +231,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xwrt = prog.insert_instruction(ins, op::dot{}, xt, twr); auto xwrt = prog.insert_instruction(ins, op::dot{}, xt, twr);
auto hrrt = prog.insert_instruction(ins, op::dot{}, xt, trr); auto hrrt = prog.insert_instruction(ins, op::dot{}, ih, trr);
auto xwhr_rt = prog.insert_instruction(ins, op::add{}, xwrt, hrrt); auto xwhr_rt = prog.insert_instruction(ins, op::add{}, xwrt, hrrt);
if(bias != prog.end()) if(bias != prog.end())
{ {
...@@ -254,7 +256,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -254,7 +256,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
{ {
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh); auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, twh); auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, trh);
if(bias != prog.end()) if(bias != prog.end())
{ {
ih_rht = prog.insert_instruction(ins, op::add{}, ih_rht, br_rbh); ih_rht = prog.insert_instruction(ins, op::add{}, ih_rht, br_rbh);
......
...@@ -112,8 +112,8 @@ TEST_CASE(gather_test) ...@@ -112,8 +112,8 @@ TEST_CASE(gather_test)
auto a0 = p.add_literal(migraphx::literal{s, data}); auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2}; std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -132,8 +132,28 @@ TEST_CASE(gather_test) ...@@ -132,8 +132,28 @@ TEST_CASE(gather_test)
auto a0 = p.add_literal(migraphx::literal{s, data}); auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2}; std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 1; int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -942,9 +942,9 @@ struct test_gather ...@@ -942,9 +942,9 @@ struct test_gather
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1}; std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s); auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
......
...@@ -217,7 +217,7 @@ TEST_CASE(gather) ...@@ -217,7 +217,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 1; int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
...@@ -227,7 +227,7 @@ TEST_CASE(gather) ...@@ -227,7 +227,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 4; int axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices); throws_shape(migraphx::op::gather{axis}, input, indices);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment