Commit e29e613f authored by Paul's avatar Paul
Browse files

Fix gather args

parent aa9863b6
......@@ -12,21 +12,22 @@ namespace gpu {
namespace device {
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
argument result,
argument arg1,
argument arg2,
int axis)
{
auto axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) {
auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens();
lens[axis_index] = arg2.get_shape().elements();
std::size_t nelements = result.get_shape().elements();
visit_all(result, arg1)([&](auto output, auto input) {
arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data());
auto* out_ptr = device_cast(output.data());
const auto* in_ptr = device_cast(input.data());
auto& input_shape = args[0].get_shape();
auto lens = input_shape.lens();
lens[axis_index] = args[1].get_shape().elements();
migraphx::shape out_comp_shape{output_shape.type(), lens};
migraphx::shape out_comp_shape{result.get_shape().type(), lens};
visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
......@@ -39,7 +40,7 @@ argument gather(hipStream_t stream,
});
});
return args.back();
return result;
}
} // namespace device
......
......@@ -16,7 +16,7 @@ argument hip_gather::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
return device::gather(ctx.get_stream().get(), output_shape, args, op.axis);
return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
}
} // namespace gpu
......
......@@ -11,8 +11,9 @@ namespace gpu {
namespace device {
argument gather(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
argument result,
argument arg1,
argument arg2,
int axis);
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment