concat.cpp 1.27 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
6

Paul's avatar
Paul committed
7
namespace migraphx {
Paul's avatar
Paul committed
8
inline namespace MIGRAPHX_INLINE_NS {
9
10
11
namespace gpu {
namespace device {

Paul's avatar
Paul committed
12
13
// Writes every input tensor into the output tensor and returns the output.
//
// The last entry of `args` is used as the output; all entries before it are
// inputs. For input k, each element is stored in the output buffer at
// (output-shape linear index of the element's multi-index) + offsets[k].
// The unnamed shape parameter is not used by this function.
//
// NOTE(review): nothing here checks that `args` is non-empty or that
// `offsets` has an entry for every input - confirm callers guarantee both.
argument concat(hipStream_t stream,
                const migraphx::shape&,
                std::vector<migraphx::argument> args,
                std::vector<std::size_t> offsets)
{
    // Everything except the trailing output argument is an input.
    auto input_count = args.size() - 1;
    for(std::size_t k = 0; k < input_count; ++k)
    {
        auto&& src              = args[k];
        std::size_t elem_count  = src.get_shape().elements();
        const std::size_t shift = offsets[k];

        // Shape rebuilt from the input's type and lens only; it is used to turn
        // the flat launch index into a multi-index.
        shape src_shape{src.get_shape().type(), src.get_shape().lens()};

        hip_visit_all(args.back(), src, src_shape)([&](auto dst, auto in, auto in_shape) {
            // One launch per input, sized by that input's element count.
            gs_launch(stream, elem_count)([=](auto i) {
                auto coords = in_shape.multi(i);
                // Locate the same multi-index in the output, then shift by this
                // input's offset.
                dst.data()[dst.get_shape().index(coords) + shift] = in[coords];
            });
        });
    }
    return args.back();
}

} // namespace device
} // namespace gpu
Paul's avatar
Paul committed
37
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
38
} // namespace migraphx