Commit 892d1100 authored by Shucai Xiao
Browse files

change implementation of the gather operator to include scalar index as input

parent 066c48e3
...@@ -794,7 +794,7 @@ struct gather ...@@ -794,7 +794,7 @@ struct gather
{ {
argument result{output_shape}; argument result{output_shape};
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis; int axis_index = (axis < 0) ? (args[0].get_shape().lens().size() + axis) : axis;
// max dimension in axis // max dimension in axis
// std::size_t max_dim = args[0].get_shape().lens()[axis_index]; // std::size_t max_dim = args[0].get_shape().lens()[axis_index];
......
...@@ -434,11 +434,11 @@ struct onnx_parser ...@@ -434,11 +434,11 @@ struct onnx_parser
const std::vector<instruction_ref>&) const std::vector<instruction_ref>&)
{ {
literal v = parse_value(attributes.at("value")); literal v = parse_value(attributes.at("value"));
migraphx::shape v_shape = v.get_shape(); auto dim_size = attributes.at("value").t().dims_size();
// for constant containing 1 element, consider it as a scalar // if dim_size is 0, it is a scalar
if(v_shape.elements() == 1) if(dim_size == 0)
{ {
migraphx::shape scalar_shape{v_shape.type(), {1}, {0}}; migraphx::shape scalar_shape{v.get_shape().type(), {1}, {0}};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()}); return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
} }
......
...@@ -16,67 +16,32 @@ argument gather(hipStream_t stream, ...@@ -16,67 +16,32 @@ argument gather(hipStream_t stream,
std::vector<migraphx::argument> args, std::vector<migraphx::argument> args,
int axis) int axis)
{ {
int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis; int axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) { visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data()); auto* out_ptr = device_cast(output.data());
const auto* inptr = device_cast(input.data()); const auto* in_ptr = device_cast(input.data());
if(output_shape.scalar()) if(output_shape.scalar())
{ {
gs_launch(stream, gs_launch(stream,
1)([=](auto i) { outptr[i] = inptr[static_cast<int>(indices_ptr[0])]; }); 1)([=](auto i) { out_ptr[i] = in_ptr[static_cast<int>(indices_ptr[0])]; });
} }
else else
{ {
visit_tensor_size(output_shape.lens().size(), [&](auto n_out_dim) { // if indices are a scalar, output has one dim smaller than input
visit_tensor_size(args[0].get_shape().lens().size(), [&](auto n_in_dim) { auto& input_shape = args[0].get_shape();
hip_tensor_descriptor<n_in_dim> desc_input(input.get_shape()); auto lens = input_shape.lens();
hip_tensor_descriptor<n_out_dim> desc_output(output.get_shape()); lens[axis_index] = args[1].get_shape().elements();
if(args[1].get_shape().scalar()) migraphx::shape out_comp_shape{output_shape.type(), lens};
{ visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
gs_launch(stream, nelements)([=](auto ii) { hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
auto out_idx = desc_output.multi(ii); hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
auto in_idx = desc_input.multi(0); gs_launch(stream, nelements)([=](auto ii) {
for(int i = 0; i < axis_index; ++i) auto in_idx = desc_output.multi(ii);
{ in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
in_idx[i] = out_idx[i]; out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
}
in_idx[axis_index] = indices_ptr[0];
for(int i = axis_index + 1; i < n_in_dim; ++i)
{
in_idx[i] = out_idx[i - 1];
}
outptr[ii] = inptr[desc_input.linear(in_idx)];
});
}
else
{
visit_tensor_size(
args[1].get_shape().lens().size(), [&](auto n_ind_dim) {
hip_tensor_descriptor<n_ind_dim> desc_ind(args[1].get_shape());
gs_launch(stream, nelements)([=](auto ii) {
auto out_idx = desc_output.multi(ii);
auto in_idx = desc_input.multi(0);
for(int i = 0; i < axis_index; ++i)
{
in_idx[i] = out_idx[i];
}
auto ind_idx = desc_ind.multi(0);
for(int i = 0; i < n_ind_dim; ++i)
{
ind_idx[i] = out_idx[i + axis_index];
}
in_idx[axis_index] = indices_ptr[desc_ind.linear(ind_idx)];
for(int i = axis_index + 1; i < n_in_dim; ++i)
{
in_idx[i] = out_idx[i + n_ind_dim - 1];
}
outptr[ii] = inptr[desc_input.linear(in_idx)];
});
});
}
}); });
}); });
} }
......
...@@ -235,7 +235,7 @@ TEST_CASE(gather) ...@@ -235,7 +235,7 @@ TEST_CASE(gather)
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1; int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
indices); indices);
...@@ -245,7 +245,7 @@ TEST_CASE(gather) ...@@ -245,7 +245,7 @@ TEST_CASE(gather)
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -4; int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 3, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
indices); indices);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment