#include #include #include #include #include #include namespace migraph { namespace gpu { shape hip_concat::compute_shape(std::vector inputs) const { inputs.pop_back(); return op.compute_shape(inputs); } std::vector hip_concat::compute_offsets(const shape& output_shape, const std::vector args) const { std::vector offsets; std::vector offset(args[0].get_shape().lens().size(), 0); offset[op.axis] = 0; for(const auto& arg : args) { offsets.push_back(output_shape.index(offset)); offset[op.axis] += arg.get_shape().lens()[op.axis]; } return offsets; } argument hip_concat::compute(context&, const shape& output_shape, const std::vector& args) const { std::vector offsets = compute_offsets(output_shape, args); return device::concat(output_shape, args, offsets); } } // namespace gpu } // namespace migraph