Commit 8f074e4e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the axis value in gather not be to mutable and add corresponding tests,...

change the axis value in gather not be to mutable and add corresponding tests, according to Paul's comments.
parent 341974b6
......@@ -635,7 +635,7 @@ struct as_shape
struct gather
{
mutable int axis = 0;
int axis = 0;
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
......@@ -649,43 +649,43 @@ struct gather
}
// negative axis means counting dimensions from back
if(axis < 0)
{
axis += n_dim;
}
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type();
lens[axis] = inputs[1].elements();
lens[axis_index] = inputs[1].elements();
return {type, lens};
}
template <class T>
void compute_index(const T& out_idx,
void compute_index(const T& out_idx, const int axis_index,
const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim,
T& in_idx) const
{
in_idx = out_idx;
std::size_t idx = vec_indices.at(out_idx[axis]);
std::size_t idx = vec_indices.at(out_idx[axis_index]);
if(idx >= max_dim)
{
MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
}
in_idx[axis] = idx;
in_idx[axis_index] = idx;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
std::size_t max_dim = args[0].get_shape().lens()[axis_index];
std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, vec_indices, max_dim, in_idx);
this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
});
});
......
......@@ -14,8 +14,9 @@ namespace device {
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis)
int axis)
{
int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
......@@ -27,7 +28,7 @@ argument gather(hipStream_t stream,
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i);
lens[axis] = indices_ptr[lens[axis]];
lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)];
});
});
......
......@@ -13,7 +13,7 @@ namespace device {
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
std::size_t axis);
int axis);
} // namespace device
} // namespace gpu
......
......@@ -950,6 +950,22 @@ struct test_gather
}
};
struct test_gather_neg_axis
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
void manual_identity()
{
migraphx::program p;
......@@ -1090,4 +1106,6 @@ int main()
verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment