"driver/conv_driver.cpp" did not exist on "8a4b59785b4f5ba48468d53618ca270c5da599a7"
Commit 60fb49dc authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add a clean implementation of the gather operator.

parent 11efa851
...@@ -766,7 +766,7 @@ struct gather ...@@ -766,7 +766,7 @@ struct gather
} }
// for scalar output // for scalar output
if(lens.size() == 0) if(lens.empty())
{ {
return {type, {1}, {0}}; return {type, {1}, {0}};
} }
...@@ -774,22 +774,9 @@ struct gather ...@@ -774,22 +774,9 @@ struct gather
return {type, lens}; return {type, lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const template<typename V, typename T>
T compute_data_index(const V &indices, const int axis_index, const T& out_idx) const
{ {
argument result{output_shape};
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (args[0].get_shape().lens().size() + axis) : axis;
// max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
if(output_shape.scalar())
{
output[0] = data[indices.front()];
}
else
{
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx; auto data_idx = out_idx;
std::size_t index{}; std::size_t index{};
if(!indices.get_shape().scalar()) if(!indices.get_shape().scalar())
...@@ -806,6 +793,27 @@ struct gather ...@@ -806,6 +793,27 @@ struct gather
index = indices.front(); index = indices.front();
} }
data_idx.insert(data_idx.begin() + axis_index, index); data_idx.insert(data_idx.begin() + axis_index, index);
return data_idx;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (args[0].get_shape().lens().size() + axis) : axis;
// max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
if(output_shape.scalar())
{
output[0] = data[indices.front()];
}
else
{
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = compute_data_index(indices, axis_index, out_idx);
output(out_idx.begin(), out_idx.end()) = output(out_idx.begin(), out_idx.end()) =
data(data_idx.begin(), data_idx.end()); data(data_idx.begin(), data_idx.end());
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment