concat.cpp 1.24 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <migraph/gpu/device/tensor.hpp>
#include <migraph/gpu/device/launch.hpp>
#include <stdexcept>

namespace migraph {
namespace gpu {
namespace device {

Paul's avatar
Paul committed
11
12
/// Writes every input argument into the output argument at a per-input offset.
///
/// The last element of `args` is used as the output; every element before it
/// is an input. For each input, `gs_launch(stream, nelements)` is invoked with
/// a functor that stores `input[i]` at the output position computed from the
/// input and output tensor descriptors, relative to `output.data() + offsets[l]`.
///
/// @param stream        HIP stream forwarded to `gs_launch` for each input.
/// @param output_shape  Only its rank (`lens().size()`) is read; it selects the
///                      dimensionality of the tensor descriptors.
/// @param args          Input arguments followed by the output argument (last).
/// @param offsets       `offsets[l]` is added to the output data pointer before
///                      input `l` is written; needs one entry per input.
/// @return `args.back()`, i.e. the output argument.
/// @throws std::invalid_argument if `args` is empty or `offsets` has fewer
///         entries than there are inputs.
argument concat(hipStream_t stream,
                const migraph::shape& output_shape,
                std::vector<migraph::argument> args,
                std::vector<std::size_t> offsets)
{
    // Check before computing `args.size() - 1`: on an empty vector the unsigned
    // subtraction wraps around and `args.back()` is undefined behaviour.
    if(args.empty())
        throw std::invalid_argument("concat: args must contain at least the output argument");

    const std::size_t ninputs = args.size() - 1;
    // Every loop iteration reads offsets[l], so a short vector would be read
    // out of bounds.
    if(offsets.size() < ninputs)
        throw std::invalid_argument("concat: offsets must have one entry per input argument");

    for(std::size_t l = 0; l < ninputs; l++)
    {
        auto& input_arg             = args[l];
        const std::size_t nelements = input_arg.get_shape().elements();
        const std::size_t offset    = offsets[l];
        visit_all(args.back(), input_arg)([&](auto output, auto input) {
            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
                auto* outptr      = output.data() + offset;
                const auto* inptr = input.data();
                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
                // The functor captures the pointers and descriptors by value.
                // `multi(i)` presumably expands the flat input index into a
                // multi-dimensional index and `linear(...)` presumably maps it
                // to a position in the output layout - confirm in tensor.hpp.
                gs_launch(stream, nelements)(
                    [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
            });
        });
    }
    return args.back();
}

} // namespace device
} // namespace gpu
} // namespace migraph