Commit 8f074e4e authored by Shucai Xiao
Browse files

Change the axis value in gather to not be mutable and add corresponding tests,...

Change the axis value in gather to not be mutable and add corresponding tests, according to Paul's comments.
parent 341974b6
...@@ -635,7 +635,7 @@ struct as_shape ...@@ -635,7 +635,7 @@ struct as_shape
struct gather struct gather
{ {
mutable int axis = 0; int axis = 0;
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
...@@ -649,43 +649,43 @@ struct gather ...@@ -649,43 +649,43 @@ struct gather
} }
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
if(axis < 0) int axis_index = (axis < 0) ? (n_dim + axis) : axis;
{
axis += n_dim;
}
auto type = inputs[0].type(); auto type = inputs[0].type();
lens[axis] = inputs[1].elements(); lens[axis_index] = inputs[1].elements();
return {type, lens}; return {type, lens};
} }
template <class T> template <class T>
void compute_index(const T& out_idx, void compute_index(const T& out_idx, const int axis_index,
const std::vector<std::size_t>& vec_indices, const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim, const std::size_t max_dim,
T& in_idx) const T& in_idx) const
{ {
in_idx = out_idx; in_idx = out_idx;
std::size_t idx = vec_indices.at(out_idx[axis]); std::size_t idx = vec_indices.at(out_idx[axis_index]);
if(idx >= max_dim) if(idx >= max_dim)
{ {
MIGRAPHX_THROW("Gather: indices are out of range in input tensor"); MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
} }
in_idx[axis] = idx; in_idx[axis_index] = idx;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
// max dimension in axis // max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis]; std::size_t max_dim = args[0].get_shape().lens()[axis_index];
std::vector<std::size_t> vec_indices; std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); }); args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx; std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, vec_indices, max_dim, in_idx); this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end()); output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
}); });
}); });
......
...@@ -14,8 +14,9 @@ namespace device { ...@@ -14,8 +14,9 @@ namespace device {
argument gather(hipStream_t stream, argument gather(hipStream_t stream,
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
std::vector<migraphx::argument> args, std::vector<migraphx::argument> args,
std::size_t axis) int axis)
{ {
int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) { visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
...@@ -27,7 +28,7 @@ argument gather(hipStream_t stream, ...@@ -27,7 +28,7 @@ argument gather(hipStream_t stream,
hip_tensor_descriptor<ndim> desc_output(output.get_shape()); hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i); auto lens = desc_output.multi(i);
lens[axis] = indices_ptr[lens[axis]]; lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)]; outptr[i] = inptr[desc_input.linear(lens)];
}); });
}); });
......
...@@ -13,7 +13,7 @@ namespace device { ...@@ -13,7 +13,7 @@ namespace device {
argument gather(hipStream_t stream, argument gather(hipStream_t stream,
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
std::vector<migraphx::argument> args, std::vector<migraphx::argument> args,
std::size_t axis); int axis);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -950,6 +950,22 @@ struct test_gather ...@@ -950,6 +950,22 @@ struct test_gather
} }
}; };
// Program fixture consumed by verify_program: gathers from a 3x3 float
// parameter using a literal 2x2 int32 index tensor, passing the axis as -1 so
// that the negative-axis handling (counting dimensions from the back) is
// exercised.
struct test_gather_neg_axis
{
    migraphx::program create_program() const
    {
        migraphx::program prog;

        // Shapes of the run-time data tensor and of the constant index tensor.
        migraphx::shape data_shape{migraphx::shape::float_type, {3, 3}};
        migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 2}};

        // Index values; each one is below 3, the extent of the gathered axis.
        std::vector<int> index_values{1, 2, 2, 1};

        // The parameter is added before the literal to keep instruction order.
        auto data    = prog.add_parameter("data", data_shape);
        auto indices = prog.add_literal(migraphx::literal{indices_shape, index_values});

        // -1 selects the last dimension of the data tensor.
        int axis = -1;
        prog.add_instruction(migraphx::op::gather{axis}, data, indices);

        return prog;
    }
};
void manual_identity() void manual_identity()
{ {
migraphx::program p; migraphx::program p;
...@@ -1090,4 +1106,6 @@ int main() ...@@ -1090,4 +1106,6 @@ int main()
verify_program<test_conv_bn_relu_pooling>(); verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>(); verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>(); verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment