gather.cpp 2.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

#include <cstddef>
#include <cstdint>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Gathers elements of a data tensor along one axis, using a tensor of indices,
// and writes the result into the output buffer. The work is enqueued on
// `stream` through gs_launch.
//
// stream        HIP stream handed to gs_launch.
// output_shape  Shape of the result; its element count is the number of work
//               items launched in the non-scalar case.
// args          args[0] is the data tensor, args[1] the indices tensor and
//               args.back() the buffer the result is written into.
// axis          Axis of args[0] to gather along. A negative value is offset by
//               the rank of args[0], i.e. it counts from the last dimension.
//
// Returns args.back().
//
// Index values stored in args[1] may be negative; they are then offset by the
// extent of the gathered axis (or, in the scalar-output case, by the element
// count of the input) before being used. No range check is performed on
// either `axis` or the index values.
argument gather(hipStream_t stream,
                const migraphx::shape& output_shape,
                std::vector<migraphx::argument> args,
                int axis)
{
    const auto& input_shape = args[0].get_shape();

    // Normalise in signed arithmetic; adding the unsigned rank directly to a
    // negative int would depend on wraparound and an implicit narrowing.
    const auto input_rank = static_cast<int>(input_shape.lens().size());
    const int axis_index  = (axis < 0) ? (axis + input_rank) : axis;

    visit_all(args.back(), args[0])([&](auto output, auto input) {
        const std::size_t nelements = output_shape.elements();
        args[1].visit([&](auto indices) {
            const auto* indices_ptr = device_cast(indices.data());
            auto* out_ptr           = device_cast(output.data());
            const auto* in_ptr      = device_cast(input.data());
            if(output_shape.scalar())
            {
                // Scalar output: the single index value is used directly as an
                // element offset into the input.
                const auto input_elements = static_cast<std::int64_t>(input_shape.elements());
                gs_launch(stream, 1)([=](auto i) {
                    auto idx = static_cast<std::int64_t>(indices_ptr[0]);
                    if(idx < 0)
                        idx += input_elements;
                    out_ptr[i] = in_ptr[idx];
                });
            }
            else
            {
                // if indices are a scalar, output has one dim smaller than input,
                // so the output is described here with the input's rank: the
                // input dimensions with the gathered axis resized to the number
                // of indices.
                auto comp_lens          = input_shape.lens();
                const auto axis_len     = static_cast<std::int64_t>(comp_lens[axis_index]);
                comp_lens[axis_index]   = args[1].get_shape().elements();
                migraphx::shape out_comp_shape{output_shape.type(), comp_lens};
                visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
                    hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
                    hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
                    gs_launch(stream, nelements)([=](auto ii) {
                        // Multi-index of this output element; its axis entry
                        // selects which index value to look up.
                        auto in_idx = desc_output.multi(ii);
                        auto idx    = static_cast<std::int64_t>(indices_ptr[in_idx[axis_index]]);
                        if(idx < 0)
                            idx += axis_len;
                        // Replace the axis entry with the looked-up index to
                        // obtain the source position in the input.
                        in_idx[axis_index] = static_cast<std::size_t>(idx);
                        out_ptr[ii]        = in_ptr[desc_input.linear(in_idx)];
                    });
                });
            }
        });
    });

    return args.back();
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx