// gather.cpp
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

/// Gather entries of `arg1` along one axis, at the positions listed in `arg2`,
/// and store them in `result`.
///
/// For every linear output position `ii` in `[0, result.get_shape().elements())`
/// the kernel decomposes `ii` into a multi-index over a "computation shape"
/// (the input lens with the gather axis replaced by the number of indices),
/// replaces the coordinate on the gather axis with the value looked up in
/// `arg2`, and copies the addressed input element to `out_ptr[ii]`.
///
/// @param stream HIP stream handed to `gs_launch`.
/// @param result Output argument; written through its data pointer and
///               returned. NOTE(review): assumed to be pre-allocated by the
///               caller with the gathered shape - confirm against callers.
/// @param arg1   Data argument that elements are read from.
/// @param arg2   Index argument; its values select positions on the gather
///               axis. A negative value `i` is interpreted as
///               `i + (length of the gather axis)`.
/// @param axis   Axis of `arg1` to gather along; a negative value counts from
///               the last dimension (`axis + rank`).
/// @return `result`.
argument gather(hipStream_t stream,
                argument result,
                argument arg1,
                argument arg2,
                int axis)
{
    const auto& input_shape = arg1.get_shape();
    auto lens               = input_shape.lens();

    // Normalise the axis in signed arithmetic first; mixing `int` with the
    // unsigned rank inside the conditional would only work via wraparound.
    const int rank        = static_cast<int>(lens.size());
    const auto axis_index = static_cast<std::size_t>((axis < 0) ? (axis + rank) : axis);

    // Remember the original extent of the gather axis (needed to wrap
    // negative indices) before overwriting it with the number of indices.
    const auto axis_len = lens[axis_index];
    lens[axis_index]    = arg2.get_shape().elements();

    // Shape used only to turn a linear output position into a multi-index.
    // It does not depend on the visited element types, so build it once here.
    const migraphx::shape out_comp_shape{result.get_shape().type(), lens};
    const std::size_t nelements = result.get_shape().elements();

    visit_all(result, arg1)([&](auto output, auto input) {
        arg2.visit([&](auto indices) {
            const auto* indices_ptr = device_cast(indices.data());
            auto* out_ptr           = device_cast(output.data());
            const auto* in_ptr      = device_cast(input.data());

            visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
                hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
                hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);

                // Everything is captured by value: the callable is handed to
                // gs_launch, so it must not reference host-side locals.
                gs_launch(stream, nelements)([=](auto ii) {
                    auto idx = desc_output.multi(ii);

                    // Look up the requested position on the gather axis and
                    // wrap negative values around the original axis length.
                    const auto raw  = indices_ptr[idx[axis_index]];
                    idx[axis_index] = (raw < 0) ? (raw + axis_len) : raw;

                    out_ptr[ii] = in_ptr[desc_input.linear(idx)];
                });
            });
        });
    });

    return result;
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx