Commit 89b80be6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'gather_operator' into seq2seq_example

parents 82bd8e2e 33b6bcb6
......@@ -766,7 +766,7 @@ struct gather
}
// for scalar output
if(lens.size() == 0)
if(lens.empty())
{
return {type, {1}, {0}};
}
......@@ -774,21 +774,27 @@ struct gather
return {type, lens};
}
// template <class T>
// void compute_index(const T& out_idx,
// const int axis_index,
// const std::vector<std::size_t>& vec_indices,
// const std::size_t max_dim,
// T& in_idx) const
// {
// in_idx = out_idx;
// std::size_t idx = vec_indices.at(out_idx[axis_index]);
// if(idx >= max_dim)
// {
// MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
// }
// in_idx[axis_index] = idx;
// }
template <typename V, typename T>
T compute_data_index(const V& indices, const int axis_index, const T& out_idx) const
{
auto data_idx = out_idx;
std::size_t index{};
if(!indices.get_shape().scalar())
{
auto start_it = data_idx.begin() + axis_index;
auto end_it = data_idx.begin() + axis_index + indices.get_shape().lens().size();
std::vector<std::size_t> ind_idx(start_it, end_it);
data_idx.erase(start_it, end_it);
index = indices(ind_idx.begin(), ind_idx.end());
}
else
{
index = indices.front();
}
data_idx.insert(data_idx.begin() + axis_index, index);
return data_idx;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
......@@ -797,37 +803,16 @@ struct gather
int axis_index = (axis < 0) ? (args[0].get_shape().lens().size() + axis) : axis;
// max dimension in axis
// std::size_t max_dim = args[0].get_shape().lens()[axis_index];
// std::vector<std::size_t> vec_indices;
// args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
if(indices.get_shape().scalar())
if(output_shape.scalar())
{
if(output_shape.scalar())
{
output[0] = data[indices.front()];
}
else
{
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx;
data_idx.insert(data_idx.begin() + axis_index, indices.front());
output(out_idx.begin(), out_idx.end()) =
data(data_idx.begin(), data_idx.end());
});
}
output[0] = data[indices.front()];
}
else
{
auto ind_lens = indices.get_shape().lens();
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx;
auto start_it = data_idx.begin() + axis_index;
auto end_it = data_idx.begin() + axis_index + ind_lens.size();
std::vector<std::size_t> ind_idx(start_it, end_it);
data_idx.erase(start_it, end_it);
data_idx.insert(start_it, indices(ind_idx.begin(), ind_idx.end()));
auto data_idx = compute_data_index(indices, axis_index, out_idx);
output(out_idx.begin(), out_idx.end()) =
data(data_idx.begin(), data_idx.end());
});
......
......@@ -164,6 +164,48 @@ TEST_CASE(gather_test)
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data{};
std::vector<float> golden = {0.5f, 3.5f, 6.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
// scalar index
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
std::vector<int> indices{0};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data{};
std::vector<float> golden = {0.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
}
TEST_CASE(squeeze_test)
......
......@@ -1068,6 +1068,54 @@ struct test_gather_neg_axis
}
};
struct test_gather_scalar_output
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_scalar_index
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}, {0}};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_1d_index
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {1}};
std::vector<int> indices{1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
void manual_identity()
{
migraphx::program p;
......@@ -2904,6 +2952,9 @@ int main()
verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
verify_program<test_gather_scalar_output>();
verify_program<test_gather_scalar_index>();
verify_program<test_gather_1d_index>();
verify_program<test_rnn_forward>();
verify_program<test_rnn_forward10>();
verify_program<test_rnn_reverse>();
......
......@@ -251,6 +251,56 @@ TEST_CASE(gather)
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape indices{migraphx::shape::int32_type, {1}, {0}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1}, {0}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment