gather.cpp 1.62 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

/// Gathers elements of args[0] along one axis using the indices in args[1]
/// and writes the result into args.back().
///
/// For every linear position i of the output, the output multi-index is
/// computed, its coordinate on the gather axis is replaced by
/// indices[coordinate], and the input element at the resulting multi-index
/// is copied to output[i].
///
/// @param stream        HIP stream the kernel is launched on.
/// @param output_shape  Shape of the result; supplies the rank and the number
///                      of elements to process.
/// @param args          args[0] is the data, args[1] holds the indices and
///                      args.back() is the output buffer.
/// @param axis          Axis to gather along. A negative value has the rank
///                      of output_shape added to it.
/// @return args.back(), i.e. the output argument.
///
/// NOTE(review): the input descriptor is built with the same rank as the
/// output, so this appears to assume the indices do not change the rank --
/// confirm against the callers.
/// NOTE(review): index values are used as-is; no range check or negative
/// index handling happens in this function.
argument gather(hipStream_t stream,
                const migraphx::shape& output_shape,
                std::vector<migraphx::argument> args,
                int axis)
{
    // Normalise the axis in signed arithmetic so that no unsigned wrap-around
    // or implicit narrowing is involved.
    const int rank       = static_cast<int>(output_shape.lens().size());
    const int axis_index = (axis < 0) ? (axis + rank) : axis;

    visit_all(args.back(), args[0])([&](auto output, auto input) {
        const std::size_t nelements = output_shape.elements();
        args[1].visit([=](auto indices) {
            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
                const auto* indices_ptr = device_cast(indices.data());
                auto* outptr            = device_cast(output.data());
                const auto* inptr       = device_cast(input.data());
                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
                // The kernel functor must own copies of everything it uses:
                // capturing by reference would hand the device the addresses
                // of host stack variables, which may also be gone by the time
                // the stream-ordered kernel actually runs.
                gs_launch(stream, nelements)([=](auto i) {
                    auto idx        = desc_output.multi(i);
                    idx[axis_index] = indices_ptr[idx[axis_index]];
                    outptr[i]       = inptr[desc_input.linear(idx)];
                });
            });
        });
    });

    return args.back();
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx