Commit 93db029b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

new CPU implementation of the gather operator.

parent 69c19fcc
...@@ -766,7 +766,7 @@ struct gather ...@@ -766,7 +766,7 @@ struct gather
} }
// for scalar output // for scalar output
if(lens.empty()) if(lens.size() == 0)
{ {
return {type, {1}, {0}}; return {type, {1}, {0}};
} }
...@@ -778,41 +778,39 @@ struct gather ...@@ -778,41 +778,39 @@ struct gather
{ {
argument result{output_shape}; argument result{output_shape};
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
auto data_shape = args[0].get_shape(); int axis_index = (axis < 0) ? (args[0].get_shape().lens().size() + axis) : axis;
int axis_index = (axis < 0) ? (data_shape.lens().size() + axis) : axis;
// max dimension in axis // max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
std::vector<std::size_t> vec_indices; args[1].visit([&](auto indices) {
args[1].visit( if (output_shape.scalar())
[&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
if(output_shape.scalar())
{ {
output[0] = data[vec_indices.front()]; output[0] = data[indices.front()];
} }
else else
{ {
shape_for_each(output_shape, [&](const auto& out_idx) { shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx; auto data_idx = out_idx;
if(args[1].get_shape().scalar()) std::size_t index{};
if (!indices.get_shape().scalar())
{ {
data_idx.insert(data_idx.begin() + axis_index, vec_indices.front()); auto start_it = data_idx.begin() + axis_index;
auto end_it = data_idx.begin() + axis_index + indices.get_shape().lens().size();
std::vector<std::size_t> ind_idx(start_it, end_it);
data_idx.erase(start_it, end_it);
index = indices(ind_idx.begin(), ind_idx.end());
} }
else else
{ {
args[1].visit([&](auto ind) { index = indices.front();
auto start_it = data_idx.begin() + axis_index;
auto end_it =
data_idx.end() + axis_index + args[1].get_shape().lens().size();
std::vector<std::size_t> ind_idx(start_it, end_it);
auto ind_it = data_idx.erase(start_it, end_it);
data_idx.insert(ind_it, ind(ind_idx.begin(), ind_idx.end()));
});
} }
output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(), data_idx.end()); data_idx.insert(data_idx.begin() + axis_index, index);
output(out_idx.begin(), out_idx.end()) =
data(data_idx.begin(), data_idx.end());
}); });
} }
}); });
});
return result; return result;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment