#include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { argument gather(hipStream_t stream, const migraphx::shape& output_shape, std::vector args, std::size_t axis) { visit_all(args.back(), args[0])([&](auto output, auto input) { std::size_t nelements = output_shape.elements(); args[1].visit([&](auto indices) { visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { const auto* indices_ptr = device_cast(indices.data()); auto* outptr = device_cast(output.data()); const auto* inptr = device_cast(input.data()); hip_tensor_descriptor desc_input(input.get_shape()); hip_tensor_descriptor desc_output(output.get_shape()); gs_launch(stream, nelements)([=](auto i) { auto lens = desc_output.multi(i); lens[axis] = indices_ptr[lens[axis]]; outptr[i] = inptr[desc_input.linear(lens)]; }); }); }); }); return args.back(); } } // namespace device } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx