gather.cpp 1.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <cstdint>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

Paul's avatar
Paul committed
13
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis)
14
{
Shucai Xiao's avatar
Shucai Xiao committed
15
16
17
18
19
    auto axis_index    = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
    auto& input_shape  = arg1.get_shape();
    auto lens          = input_shape.lens();
    auto axis_dim_size = lens[axis_index];
    lens[axis_index]   = arg2.get_shape().elements();
Paul's avatar
Paul committed
20
    shape out_comp_shape{result.get_shape().type(), lens};
Paul's avatar
Paul committed
21
    std::size_t nelements = result.get_shape().elements();
Paul's avatar
Paul committed
22

Paul's avatar
Paul committed
23
24
25
26
    visit_all(result, arg1)([&](auto output, auto input_v) {
        hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
            arg2.visit([&](auto indices) {
                const auto* indices_ptr = device_cast(indices.data());
Paul's avatar
Paul committed
27
                auto* output_ptr        = device_cast(output.data());
28
                gs_launch(stream, nelements, 256)([=](auto i) __device__ {
Paul's avatar
Paul committed
29
                    auto idx        = out_comp.multi(i);
Shucai Xiao's avatar
Shucai Xiao committed
30
31
32
                    auto in_index   = indices_ptr[idx[axis_index]];
                    in_index        = (in_index < 0) ? in_index + axis_dim_size : in_index;
                    idx[axis_index] = in_index;
Paul's avatar
Paul committed
33
                    output_ptr[i]   = input[idx];
34
                });
35
            });
36
37
38
        });
    });

Paul's avatar
Paul committed
39
    return result;
40
41
42
43
44
45
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx