concat.cpp 1.12 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
6

Paul's avatar
Paul committed
7
namespace migraphx {
Paul's avatar
Paul committed
8
inline namespace MIGRAPHX_INLINE_NS {
9
10
11
namespace gpu {
namespace device {

Paul's avatar
Paul committed
12
13
argument concat(hipStream_t stream,
                const migraphx::shape&,
Paul's avatar
Paul committed
14
                std::vector<migraphx::argument> args,
wsttiger's avatar
wsttiger committed
15
                std::vector<std::size_t> offsets)
16
{
Paul's avatar
Paul committed
17
    auto ninputs = args.size() - 1;
Paul's avatar
Paul committed
18
    for(std::size_t j = 0; j < ninputs; j++)
19
    {
Paul's avatar
Paul committed
20
        auto&& arg            = args[j];
Paul's avatar
Paul committed
21
        std::size_t nelements = arg.get_shape().elements();
Paul's avatar
Paul committed
22
        auto offset           = offsets[j];
Paul's avatar
Paul committed
23
24
        hip_visit_all(args.back(), arg)([&](auto output, auto input) {
            gs_launch(stream, nelements)([=](auto i) {
Paul's avatar
Paul committed
25
                auto idx                    = output.get_shape().index(input.get_shape().multi(i));
Paul's avatar
Paul committed
26
                output.data()[idx + offset] = input.data()[i];
27
28
29
30
31
32
33
34
            });
        });
    }
    return args.back();
}

} // namespace device
} // namespace gpu
Paul's avatar
Paul committed
35
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
36
} // namespace migraphx