// concat.cpp
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace migraphx {
Paul's avatar
Paul committed
7
inline namespace MIGRAPHX_INLINE_NS {
8
9
10
namespace gpu {
namespace device {

Paul's avatar
Paul committed
11
12
argument concat(hipStream_t stream,
                const migraphx::shape&,
Paul's avatar
Paul committed
13
                std::vector<migraphx::argument> args,
wsttiger's avatar
wsttiger committed
14
                std::vector<std::size_t> offsets)
15
{
Paul's avatar
Paul committed
16
    auto ninputs = args.size() - 1;
Paul's avatar
Paul committed
17
    for(std::size_t j = 0; j < ninputs; j++)
18
    {
19
20
21
22
23
24
        auto&& arg        = args[j];
        auto offset       = offsets[j];
        auto byte_offset  = offset * arg.get_shape().type_size();
        auto output_shape = shape{
            arg.get_shape().type(), arg.get_shape().lens(), args.back().get_shape().strides()};
        auto output = argument{output_shape, args.back().data() + byte_offset};
25
        contiguous(stream, output, arg);
26
27
28
29
30
31
    }
    return args.back();
}

} // namespace device
} // namespace gpu
Paul's avatar
Paul committed
32
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
33
} // namespace migraphx