concat.cpp 3.34 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
6

Paul's avatar
Paul committed
7
namespace migraphx {
Paul's avatar
Paul committed
8
inline namespace MIGRAPHX_INLINE_NS {
9
10
11
namespace gpu {
namespace device {

Paul's avatar
Paul committed
12
#if 1
Paul's avatar
Paul committed
13
// Concatenates the input tensors into the output tensor using one kernel launch
// that covers all inputs.
//
// stream      - HIP stream the kernel is launched on.
// (unnamed)   - output shape; not used here, the shape is read from the output argument.
// args_vec    - the input arguments followed by the output argument as the LAST element.
// offsets_vec - offsets_vec[j] is the element offset added to the output index for
//               input j. At most `limit` entries are accepted; entries beyond the
//               number of inputs are ignored.
//
// Returns the output argument (args_vec.back()).
// Throws (via MIGRAPHX_THROW) when there are too many offsets, when args_vec is
// empty, or when there are fewer offsets than inputs.
argument concat(hipStream_t stream,
                const migraphx::shape&,
                std::vector<migraphx::argument> args_vec,
                std::vector<std::size_t> offsets_vec)
{
    static constexpr const std::size_t limit = 6;
    if(offsets_vec.size() > limit)
        MIGRAPHX_THROW("Too many arguments to concat");
    // The output is always the last argument. Without it std::prev(end()) and back()
    // below would be undefined behaviour.
    if(args_vec.empty())
        MIGRAPHX_THROW("concat requires at least an output argument");
    // Every input must have an offset, otherwise the kernel would index offsets[j]
    // without a matching entry. Together with the limit check above this also keeps
    // args_vec within the limit + 1 capacity used by hip_visit_all below.
    if(offsets_vec.size() < args_vec.size() - 1)
        MIGRAPHX_THROW("Too few offsets for concat");

    // Launch size: the largest element count among the inputs (the trailing output
    // argument is excluded from the search range). With no inputs the search range is
    // empty and max_element returns the iterator to the output argument itself.
    std::size_t nelements =
        std::max_element(args_vec.begin(),
                         std::prev(args_vec.end()),
                         by(std::less<>{}, [&](auto&& x) { return x.get_shape().elements(); }))
            ->get_shape()
            .elements();
    auto offsets = to_hip_vector<limit>(offsets_vec);
    hip_visit_all<limit + 1>(args_vec)([&](auto args) {
        auto output  = args.back();
        auto ninputs = args.size() - 1;
        gs_launch(stream, nelements)([=](auto i) {
            // Index i is applied to every input; inputs with fewer than i + 1
            // elements are skipped for this index.
            for(std::size_t j = 0; j < ninputs; j++)
            {
                auto&& arg = args[j];
                if(i >= arg.size())
                    continue;
                // Translate the input's linear index into the output's layout, then
                // shift by this input's offset within the output.
                auto idx = output.get_shape().index(arg.get_shape().multi(i));
                output.data()[idx + offsets[j]] = arg.data()[i];
            }
        });
    });
    return args_vec.back();
}

#else

// Alternative concat implementation (compiled only when the #if above is disabled):
// launches one kernel per input instead of a single kernel for all inputs.
//
// stream    - HIP stream the kernels are launched on.
// (unnamed) - output shape; not used here, the shape is read from the output argument.
// args      - the input arguments followed by the output argument as the LAST element.
// offsets   - offsets[j] is the element offset added to the output index for input j;
//             entries beyond the number of inputs are ignored.
//
// Returns the output argument (args.back()).
// Throws (via MIGRAPHX_THROW) when args is empty or there are fewer offsets than inputs.
argument concat(hipStream_t stream,
                const migraphx::shape&,
                std::vector<migraphx::argument> args,
                std::vector<std::size_t> offsets)
{
    // Without the trailing output argument, size() - 1 would wrap around and the loop
    // below would index an empty vector.
    if(args.empty())
        MIGRAPHX_THROW("concat requires at least an output argument");
    auto ninputs = args.size() - 1;
    // offsets[j] is read for every input, so each input needs an entry.
    if(offsets.size() < ninputs)
        MIGRAPHX_THROW("Too few offsets for concat");
    for(std::size_t j = 0; j < ninputs; j++)
    {
        auto&& arg            = args[j];
        std::size_t nelements = arg.get_shape().elements();
        auto offset           = offsets[j];
        hip_visit_all(args.back(), arg)([&](auto output, auto input) {
            // The kernel lambda captures by value; `offset` and the views are copied.
            gs_launch(stream, nelements)([=](auto i) {
                // Translate the input's linear index into the output's layout, then
                // shift by this input's offset within the output.
                auto idx                    = output.get_shape().index(input.get_shape().multi(i));
                output.data()[idx + offset] = input.data()[i];
            });
        });
    }
    return args.back();
}

Paul's avatar
Paul committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// argument concat(hipStream_t stream,
//                 const migraphx::shape& output_shape,
//                 std::vector<migraphx::argument> args,
//                 std::vector<std::size_t> offsets)
// {
//     for(std::size_t l = 0; l < args.size() - 1; l++)
//     {
//         auto argl             = args[l];
//         std::size_t nelements = argl.get_shape().elements();
//         visit_all(args.back(), argl)([&](auto output, auto input) {
//             visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
//                 auto* outptr      = output.data() + offsets[l];
//                 const auto* inptr = input.data();
//                 hip_tensor_descriptor<ndim> desc_input(input.get_shape());
//                 hip_tensor_descriptor<ndim> desc_output(output.get_shape());
//                 gs_launch(stream, nelements)(
//                     [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
//             });
//         });
//     }
//     return args.back();
// }
#endif
91
92
} // namespace device
} // namespace gpu
Paul's avatar
Paul committed
93
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
94
} // namespace migraphx