"src/include/rtg/tensor_view.hpp" did not exist on "592dd27326a42edd996beb50cbde58fe8c24ed51"
Commit e29e613f authored by Paul
Browse files

Fix gather args

parent aa9863b6
...@@ -12,21 +12,22 @@ namespace gpu { ...@@ -12,21 +12,22 @@ namespace gpu {
namespace device { namespace device {
argument gather(hipStream_t stream, argument gather(hipStream_t stream,
const migraphx::shape& output_shape, argument result,
std::vector<migraphx::argument> args, argument arg1,
argument arg2,
int axis) int axis)
{ {
auto axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis; auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) { auto& input_shape = arg1.get_shape();
std::size_t nelements = output_shape.elements(); auto lens = input_shape.lens();
args[1].visit([&](auto indices) { lens[axis_index] = arg2.get_shape().elements();
std::size_t nelements = result.get_shape().elements();
visit_all(result, arg1)([&](auto output, auto input) {
arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
auto* out_ptr = device_cast(output.data()); auto* out_ptr = device_cast(output.data());
const auto* in_ptr = device_cast(input.data()); const auto* in_ptr = device_cast(input.data());
auto& input_shape = args[0].get_shape(); migraphx::shape out_comp_shape{result.get_shape().type(), lens};
auto lens = input_shape.lens();
lens[axis_index] = args[1].get_shape().elements();
migraphx::shape out_comp_shape{output_shape.type(), lens};
visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) { visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
hip_tensor_descriptor<n_out_dim> desc_input(input_shape); hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape); hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
...@@ -39,7 +40,7 @@ argument gather(hipStream_t stream, ...@@ -39,7 +40,7 @@ argument gather(hipStream_t stream,
}); });
}); });
return args.back(); return result;
} }
} // namespace device } // namespace device
......
...@@ -16,7 +16,7 @@ argument hip_gather::compute(context& ctx, ...@@ -16,7 +16,7 @@ argument hip_gather::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
return device::gather(ctx.get_stream().get(), output_shape, args, op.axis); return device::gather(ctx.get_stream().get(), args.back(), args[0], args[1], op.axis);
} }
} // namespace gpu } // namespace gpu
......
...@@ -11,8 +11,9 @@ namespace gpu { ...@@ -11,8 +11,9 @@ namespace gpu {
namespace device { namespace device {
argument gather(hipStream_t stream, argument gather(hipStream_t stream,
const migraphx::shape& output_shape, argument result,
std::vector<migraphx::argument> args, argument arg1,
argument arg2,
int axis); int axis);
} // namespace device } // namespace device
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment