// gather.cpp
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Device-side gather.
//
// Copies elements of the input tensor (args[0]) into the output buffer
// (args.back()), choosing the coordinate along `axis` from the indices
// tensor (args[1]).
//
// Parameters:
//   stream        HIP stream handed to gs_launch for every kernel launch.
//   output_shape  Shape of the result; supplies the element count and rank
//                 used to decode each flat output position.
//   args          args[0] = input data, args[1] = indices,
//                 args.back() = output buffer.
//   axis          Input dimension to gather along. A negative value counts
//                 back from the last dimension of the input.
//
// Returns args.back(), the output argument.
//
// NOTE(review): indices are used as read from args[1]; no bounds check or
// negative-index wrap-around is done here — confirm callers guarantee
// in-range values.
argument gather(hipStream_t stream,
                const migraphx::shape& output_shape,
                std::vector<migraphx::argument> args,
                int axis)
{
    // `axis` names a dimension of the *input* (axis_index is only ever used
    // to index in_idx below), so a negative value has to be normalised with
    // the input rank. The output rank matches the input rank only when the
    // indices tensor is 1-D, so using it here picked the wrong dimension for
    // scalar or multi-dimensional indices. The cast keeps the arithmetic
    // signed instead of mixing int with std::size_t.
    const int n_input_dims = static_cast<int>(args[0].get_shape().lens().size());
    const int axis_index   = (axis < 0) ? (axis + n_input_dims) : axis;

    visit_all(args.back(), args[0])([&](auto output, auto input) {
        std::size_t nelements = output_shape.elements();
        args[1].visit([&](auto indices) {
            const auto* indices_ptr = device_cast(indices.data());
            auto* outptr            = device_cast(output.data());
            const auto* inptr       = device_cast(input.data());
            if(output_shape.scalar())
            {
                // Scalar result: one element, selected by the first index.
                gs_launch(stream, 1)([=](auto i) { outptr[i] = inptr[indices_ptr[0]]; });
            }
            else
            {
                // visit_tensor_size supplies each rank as a value usable as a
                // template argument for hip_tensor_descriptor.
                visit_tensor_size(output_shape.lens().size(), [&](auto n_out_dim) {
                    visit_tensor_size(args[0].get_shape().lens().size(), [&](auto n_in_dim) {
                        hip_tensor_descriptor<n_in_dim> desc_input(input.get_shape());
                        hip_tensor_descriptor<n_out_dim> desc_output(output.get_shape());
                        if(args[1].get_shape().scalar())
                        {
                            // Scalar index: the gathered axis does not appear
                            // in the output, so output coordinates after the
                            // axis are shifted down by one relative to the
                            // input coordinates.
                            gs_launch(stream, nelements)([=](auto ii) {
                                auto out_idx = desc_output.multi(ii);
                                auto in_idx  = desc_input.multi(0);
                                // Dimensions before the axis map one-to-one.
                                for(int i = 0; i < axis_index; ++i)
                                {
                                    in_idx[i] = out_idx[i];
                                }
                                in_idx[axis_index] = indices_ptr[0];
                                for(int i = axis_index + 1; i < n_in_dim; ++i)
                                {
                                    in_idx[i] = out_idx[i - 1];
                                }
                                outptr[ii] = inptr[desc_input.linear(in_idx)];
                            });
                        }
                        else
                        {
                            // General case: the output coordinate is laid out as
                            // [input dims before axis | indices dims | input dims after axis].
                            visit_tensor_size(
                                args[1].get_shape().lens().size(), [&](auto n_ind_dim) {
                                    hip_tensor_descriptor<n_ind_dim> desc_ind(args[1].get_shape());
                                    gs_launch(stream, nelements)([=](auto ii) {
                                        auto out_idx = desc_output.multi(ii);
                                        auto in_idx  = desc_input.multi(0);
                                        // Dimensions before the axis map one-to-one.
                                        for(int i = 0; i < axis_index; ++i)
                                        {
                                            in_idx[i] = out_idx[i];
                                        }
                                        // The next n_ind_dim output coordinates
                                        // address the element of the indices tensor.
                                        auto ind_idx = desc_ind.multi(0);
                                        for(int i = 0; i < n_ind_dim; ++i)
                                        {
                                            ind_idx[i] = out_idx[i + axis_index];
                                        }
                                        in_idx[axis_index] = indices_ptr[desc_ind.linear(ind_idx)];
                                        // The remaining output coordinates map to the
                                        // input dimensions after the axis.
                                        for(int i = axis_index + 1; i < n_in_dim; ++i)
                                        {
                                            in_idx[i] = out_idx[i + n_ind_dim - 1];
                                        }
                                        outptr[ii] = inptr[desc_input.linear(in_idx)];
                                    });
                                });
                        }
                    });
                });
            }
        });
    });

    return args.back();
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx