concat.cpp 1.3 KB
Newer Older
1
2
3
4
5
6
7
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <migraph/gpu/device/tensor.hpp>
#include <migraph/gpu/device/launch.hpp>
#include <cassert>

namespace migraph {
8
namespace MIGRAPH_INLINE_NS {
9
10
11
namespace gpu {
namespace device {

Paul's avatar
Paul committed
12
13
argument concat(hipStream_t stream,
                const migraph::shape& output_shape,
wsttiger's avatar
wsttiger committed
14
15
                std::vector<migraph::argument> args,
                std::vector<std::size_t> offsets)
16
{
wsttiger's avatar
wsttiger committed
17
    for(std::size_t l = 0; l < args.size() - 1; l++)
18
    {
wsttiger's avatar
wsttiger committed
19
        auto argl             = args[l];
20
21
        std::size_t nelements = argl.get_shape().elements();
        visit_all(args.back(), argl)([&](auto output, auto input) {
wsttiger's avatar
wsttiger committed
22
            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
23
24
25
26
                auto* outptr      = output.data() + offsets[l];
                const auto* inptr = input.data();
                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
Paul's avatar
Paul committed
27
                gs_launch(stream, nelements)(
wsttiger's avatar
wsttiger committed
28
                    [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
29
30
31
32
33
34
35
36
            });
        });
    }
    return args.back();
}

} // namespace device
} // namespace gpu
37
} // namespace MIGRAPH_INLINE_NS
38
} // namespace migraph