// concat.cpp
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <migraph/gpu/device/tensor.hpp>
#include <migraph/gpu/device/launch.hpp>

namespace migraph {
namespace gpu {
namespace device {

Paul's avatar
Paul committed
11
argument concat(hipStream_t stream, const migraph::shape& output_shape,
wsttiger's avatar
wsttiger committed
12
13
                std::vector<migraph::argument> args,
                std::vector<std::size_t> offsets)
14
{
wsttiger's avatar
wsttiger committed
15
    for(std::size_t l = 0; l < args.size() - 1; l++)
16
    {
wsttiger's avatar
wsttiger committed
17
        auto argl             = args[l];
18
19
        std::size_t nelements = argl.get_shape().elements();
        visit_all(args.back(), argl)([&](auto output, auto input) {
wsttiger's avatar
wsttiger committed
20
            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
21
22
23
24
                auto* outptr      = output.data() + offsets[l];
                const auto* inptr = input.data();
                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
Paul's avatar
Paul committed
25
                gs_launch(stream, nelements)(
wsttiger's avatar
wsttiger committed
26
                    [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
27
28
29
30
31
32
33
34
35
            });
        });
    }
    return args.back();
}

} // namespace device
} // namespace gpu
} // namespace migraph