Commit 904a913b authored by Shucai Xiao
Browse files

First version in which the GRU operator generates correct output.

parents 85e8901b 341974b6
......@@ -647,15 +647,16 @@ struct gather
{
check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens();
if(axis >= lens.size() || axis < -lens.size())
int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
{
MIGRAPHX_THROW("Gather, axis is out of range.");
MIGRAPHX_THROW("Gather: axis is out of range.");
}
// negative value means counting dimensions from back
// negative axis means counting dimensions from back
if(axis < 0)
{
axis += lens.size();
axis += n_dim;
}
auto type = inputs[0].type();
......
......@@ -173,8 +173,9 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
long hs = static_cast<long>(r->get_shape().lens()[1]);
long seq_index = is_forward ? 0 : seq_len - 1;
migraphx::shape s(input->get_shape().type(), {1});
auto l1 = prog.add_literal(migraphx::literal{s, {1}});
migraphx::shape s(input->get_shape().type(), {input->get_shape().lens()[1], static_cast<std::size_t>(hs)});
std::vector<int> data(s.elements(), 1);
auto l1 = prog.add_literal(migraphx::literal{s, data});
// weight matrix
std::vector<int64_t> perm{1, 0};
......@@ -210,7 +211,8 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
br_bz = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
br_br = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, br);
br_bh = prog.insert_instruction(ins, op::add{}, br_wbh, br_rbh);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
br_bh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bh);
}
for(long i = 0; i < seq_len; i++)
......@@ -229,7 +231,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xwrt = prog.insert_instruction(ins, op::dot{}, xt, twr);
auto hrrt = prog.insert_instruction(ins, op::dot{}, xt, trr);
auto hrrt = prog.insert_instruction(ins, op::dot{}, ih, trr);
auto xwhr_rt = prog.insert_instruction(ins, op::add{}, xwrt, hrrt);
if(bias != prog.end())
{
......@@ -254,7 +256,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, twh);
auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, trh);
if(bias != prog.end())
{
ih_rht = prog.insert_instruction(ins, op::add{}, ih_rht, br_rbh);
......
......@@ -112,8 +112,8 @@ TEST_CASE(gather_test)
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......@@ -132,8 +132,28 @@ TEST_CASE(gather_test)
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 1;
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......
......@@ -942,9 +942,9 @@ struct test_gather
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
......
......@@ -217,7 +217,7 @@ TEST_CASE(gather)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 1;
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis},
input,
......@@ -227,7 +227,7 @@ TEST_CASE(gather)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 4;
int axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment