concat.cpp 3.26 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
6

Paul's avatar
Paul committed
7
namespace migraphx {
Paul's avatar
Paul committed
8
inline namespace MIGRAPHX_INLINE_NS {
9
10
11
namespace gpu {
namespace device {

Paul's avatar
Paul committed
12
#if 1
Paul's avatar
Paul committed
13
// Copies every input argument into the output argument using a single launch.
//
// Parameters:
//   stream      - HIP stream handed to gs_launch.
//   (unnamed)   - shape argument; not used by this implementation.
//   args_vec    - the input arguments followed by the output argument. The LAST
//                 element is treated as the output, all elements before it as inputs.
//   offsets_vec - offsets_vec[j] is added to the output index computed for every
//                 element of input j. At most `limit` entries are accepted; entries
//                 beyond the number of inputs are not read by the kernel.
//
// Returns the last element of args_vec (the output argument).
//
// Throws (via MIGRAPHX_THROW) when offsets_vec has more than `limit` entries, when
// args_vec is empty, or when there are fewer offsets than inputs.
argument concat(hipStream_t stream,
                const migraphx::shape&,
                std::vector<migraphx::argument> args_vec,
                std::vector<std::size_t> offsets_vec)
{
    // Compile-time capacity used for the device-side offsets vector and (plus one
    // for the output) for hip_visit_all below.
    static constexpr std::size_t limit = 6;
    if(offsets_vec.size() > limit)
        MIGRAPHX_THROW("Too many arguments to concat");
    // std::prev(args_vec.end()) and args_vec.back() are undefined on an empty vector,
    // so reject that case before touching either.
    if(args_vec.empty())
        MIGRAPHX_THROW("Concat requires an output argument");
    // The kernel reads offsets[j] for every input j; make sure each one was supplied.
    // Extra trailing offsets are tolerated so existing callers keep working.
    if(offsets_vec.size() < args_vec.size() - 1)
        MIGRAPHX_THROW("Too few offsets for concat inputs");

    // Largest element count among the inputs (everything except the trailing output).
    // When there are no inputs the range is empty and max_element returns the end of
    // the range, which is the (dereferenceable) output element.
    auto inputs_end = std::prev(args_vec.end());
    auto largest    = std::max_element(args_vec.begin(),
                                    inputs_end,
                                    by(std::less<>{}, [&](auto&& x) {
                                        return x.get_shape().elements();
                                    }));
    std::size_t nelements = largest->get_shape().elements();

    auto offsets = to_hip_vector<limit>(offsets_vec);
    hip_visit_all<limit + 1>(args_vec)([&](auto args) {
        auto output  = args.back();
        auto ninputs = args.size() - 1;
        // NOTE(review): gs_launch presumably invokes the callback once per index i in
        // [0, nelements) on `stream` - confirm against launch.hpp.
        gs_launch(stream, nelements)([=](auto i) {
            for(std::size_t j = 0; j < ninputs; j++)
            {
                auto&& arg = args[j];
                // nelements is the maximum over all inputs, so smaller inputs must
                // skip indices past their own size.
                if(i >= arg.size())
                    continue;
                // Translate the input's linear index i into a multi-index, then into
                // the output's index space, and shift by this input's offset.
                auto idx = output.get_shape().index(arg.get_shape().multi(i));
                output.data()[idx + offsets[j]] = arg.data()[i];
            }
        });
    });
    return args_vec.back();
}

#else

// Alternative implementation (only compiled when the `#if 1` above is disabled):
// issues one launch per input instead of a single combined launch.
//
// The last element of `args` is treated as the output; every element before it is an
// input. offsets[k] is added to the output index computed for each element of input k.
// The shape parameter is unused. Returns the last element of `args`.
argument concat(hipStream_t stream,
                const migraphx::shape&,
                std::vector<migraphx::argument> args,
                std::vector<std::size_t> offsets)
{
    auto input_count = args.size() - 1;
    for(std::size_t k = 0; k < input_count; ++k)
    {
        auto&& source     = args[k];
        auto shift        = offsets[k];
        std::size_t count = source.get_shape().elements();
        hip_visit_all(args.back(), source)([&](auto dst, auto src) {
            gs_launch(stream, count)([=](auto pos) {
                // Map the input's linear position into the output's index space,
                // then apply this input's offset.
                auto mapped                = dst.get_shape().index(src.get_shape().multi(pos));
                dst.data()[mapped + shift] = src.data()[pos];
            });
        });
    }
    return args.back();
}

Paul's avatar
Paul committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// argument concat(hipStream_t stream,
//                 const migraphx::shape& output_shape,
//                 std::vector<migraphx::argument> args,
//                 std::vector<std::size_t> offsets)
// {
//     for(std::size_t l = 0; l < args.size() - 1; l++)
//     {
//         auto argl             = args[l];
//         std::size_t nelements = argl.get_shape().elements();
//         visit_all(args.back(), argl)([&](auto output, auto input) {
//             visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
//                 auto* outptr      = output.data() + offsets[l];
//                 const auto* inptr = input.data();
//                 hip_tensor_descriptor<ndim> desc_input(input.get_shape());
//                 hip_tensor_descriptor<ndim> desc_output(output.get_shape());
//                 gs_launch(stream, nelements)(
//                     [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
//             });
//         });
//     }
//     return args.back();
// }
#endif
86
87
} // namespace device
} // namespace gpu
Paul's avatar
Paul committed
88
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
89
} // namespace migraphx