Commit 26ce057d authored by Shucai Xiao's avatar Shucai Xiao
Browse files

tmp implementation.

parent df0d0d9f
...@@ -766,7 +766,7 @@ struct gather ...@@ -766,7 +766,7 @@ struct gather
} }
// for scalar output // for scalar output
if(lens.size() == 0) if(lens.empty())
{ {
return {type, {1}, {0}}; return {type, {1}, {0}};
} }
...@@ -774,65 +774,33 @@ struct gather ...@@ -774,65 +774,33 @@ struct gather
return {type, lens}; return {type, lens};
} }
// template <class T>
// void compute_index(const T& out_idx,
// const int axis_index,
// const std::vector<std::size_t>& vec_indices,
// const std::size_t max_dim,
// T& in_idx) const
// {
// in_idx = out_idx;
// std::size_t idx = vec_indices.at(out_idx[axis_index]);
// if(idx >= max_dim)
// {
// MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
// }
// in_idx[axis_index] = idx;
// }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (args[0].get_shape().lens().size() + axis) : axis; auto data_shape = args[0].get_shape();
int axis_index = (axis < 0) ? (data_shape.lens().size() + axis) : axis;
// max dimension in axis // max dimension in axis
// std::size_t max_dim = args[0].get_shape().lens()[axis_index];
// std::vector<std::size_t> vec_indices;
// args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) { std::vector<std::size_t> vec_indices;
if(indices.get_shape().scalar()) args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
{ if (output_shape.scalar())
if(output_shape.scalar()) {
{ output[0] = data[vec_indices.front()];
output[0] = data[indices.front()]; }
} else
else {
{ auto out_lens = data_shape.lens();
shape_for_each(output.get_shape(), [&](const auto& out_idx) { out_lens[axis_index] = vec_indices.size();
auto data_idx = out_idx; migraphx::shape out_comp_shape{output_shape.type(), out_lens};
data_idx.insert(data_idx.begin() + axis_index, indices.front()); shape_for_each(out_comp_shape, [&](const auto& out_idx) {
output(out_idx.begin(), out_idx.end()) = auto data_idx = out_idx;
data(data_idx.begin(), data_idx.end()); data_idx[axis_index] = vec_indices[data_idx[axis_index]];
}); //output(out_idx.begin(), out_idx.end()) = data(data_idx.begin(), data_idx.end());
} output[out_comp_shape.index(out_idx)] = data(data_idx.begin(), data_idx.end());
} });
else }
{
auto ind_lens = indices.get_shape().lens();
shape_for_each(output.get_shape(), [&](const auto& out_idx) {
auto data_idx = out_idx;
auto start_it = data_idx.begin() + axis_index;
auto end_it = data_idx.begin() + axis_index + ind_lens.size();
std::vector<std::size_t> ind_idx(start_it, end_it);
data_idx.erase(start_it, end_it);
data_idx.insert(start_it, indices(ind_idx.begin(), ind_idx.end()));
output(out_idx.begin(), out_idx.end()) =
data(data_idx.begin(), data_idx.end());
});
}
});
}); });
return result; return result;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment