Commit a26a67e9 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

another implementation of the gather operator

parent 834255a3
......@@ -792,15 +792,23 @@ struct gather
}
else
{
auto out_lens = data_shape.lens();
out_lens[axis_index] = vec_indices.size();
migraphx::shape out_comp_shape{output_shape.type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) {
shape_for_each(output_shape, [&](const auto& out_idx) {
auto data_idx = out_idx;
data_idx[axis_index] = vec_indices[data_idx[axis_index]];
// output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(),
// data_idx.end());
output[out_comp_shape.index(out_idx)] = data(data_idx.begin(), data_idx.end());
if (args[1].get_shape().scalar())
{
data_idx.insert(data_idx.begin() + axis_index, vec_indices.front());
}
else
{
args[1].visit([&](auto ind) {
auto start_it = data_idx.begin() + axis_index;
auto end_it = data_idx.end() + axis_index + args[1].get_shape().lens().size();
std::vector<std::size_t> ind_idx(start_it, end_it);
auto ind_it = data_idx.erase(start_it, end_it);
data_idx.insert(ind_it, ind(ind_idx.begin(), ind_idx.end()));
});
}
output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(), data_idx.end());
});
}
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment