Commit 09cee914 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

simplify the gpu gather implementation

parent b466ceb9
...@@ -23,14 +23,6 @@ argument gather(hipStream_t stream, ...@@ -23,14 +23,6 @@ argument gather(hipStream_t stream,
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
auto* out_ptr = device_cast(output.data()); auto* out_ptr = device_cast(output.data());
const auto* in_ptr = device_cast(input.data()); const auto* in_ptr = device_cast(input.data());
if(output_shape.scalar())
{
gs_launch(stream, 1)(
[=](auto i) { out_ptr[i] = in_ptr[static_cast<int>(indices_ptr[0])]; });
}
else
{
// if indices are a scalar, output has one dim smaller than input
auto& input_shape = args[0].get_shape(); auto& input_shape = args[0].get_shape();
auto lens = input_shape.lens(); auto lens = input_shape.lens();
lens[axis_index] = args[1].get_shape().elements(); lens[axis_index] = args[1].get_shape().elements();
...@@ -44,7 +36,6 @@ argument gather(hipStream_t stream, ...@@ -44,7 +36,6 @@ argument gather(hipStream_t stream,
out_ptr[ii] = in_ptr[desc_input.linear(in_idx)]; out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
}); });
}); });
}
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment