Unverified Commit b5090737 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #158 from ROCmSoftwarePlatform/gather_operator

Gather operator
parents 2d80965f 58018632
......@@ -635,49 +635,58 @@ struct as_shape
struct gather
{
std::size_t axis = 0;
int axis = 0;
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens();
if(axis >= lens.size())
int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
{
MIGRAPHX_THROW("Gather, axis is out of range.");
MIGRAPHX_THROW("Gather: axis is out of range.");
}
auto type = inputs[0].type();
lens[axis] = inputs[1].elements();
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type();
lens[axis_index] = inputs[1].elements();
return {type, lens};
}
template <class T>
void compute_index(const T& out_idx,
const int axis_index,
const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim,
T& in_idx) const
{
in_idx = out_idx;
std::size_t idx = vec_indices.at(out_idx[axis]);
std::size_t idx = vec_indices.at(out_idx[axis_index]);
if(idx >= max_dim)
{
MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
}
in_idx[axis] = idx;
in_idx[axis_index] = idx;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
std::size_t max_dim = args[0].get_shape().lens()[axis_index];
std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, vec_indices, max_dim, in_idx);
this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
});
});
......
......@@ -370,7 +370,7 @@ struct onnx_parser
instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::size_t axis = 0;
int axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
......
......@@ -14,8 +14,9 @@ namespace device {
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis)
int axis)
{
int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
......@@ -26,9 +27,9 @@ argument gather(hipStream_t stream,
hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i);
lens[axis] = indices_ptr[lens[axis]];
outptr[i] = inptr[desc_input.linear(lens)];
auto lens = desc_output.multi(i);
lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)];
});
});
});
......
......@@ -13,7 +13,7 @@ namespace device {
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis);
int axis);
} // namespace device
} // namespace gpu
......
......@@ -112,8 +112,8 @@ TEST_CASE(gather_test)
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......@@ -132,8 +132,28 @@ TEST_CASE(gather_test)
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 1;
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
......
......@@ -942,9 +942,25 @@ struct test_gather
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0;
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_neg_axis
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
......@@ -1090,4 +1106,6 @@ int main()
verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
}
......@@ -415,9 +415,9 @@ TEST_CASE(shape_test)
TEST_CASE(gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
std::size_t axis = 1;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx");
......@@ -431,8 +431,8 @@ TEST_CASE(shape_gather_test)
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
std::size_t axis = 0;
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = migraphx::parse_onnx("shape_gather.onnx");
......
......@@ -217,7 +217,7 @@ TEST_CASE(gather)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 1;
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis},
input,
......@@ -227,7 +227,24 @@ TEST_CASE(gather)
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 4;
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -5;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment