Commit dea2cc08 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 8f074e4e
...@@ -651,14 +651,15 @@ struct gather ...@@ -651,14 +651,15 @@ struct gather
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis; int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type(); auto type = inputs[0].type();
lens[axis_index] = inputs[1].elements(); lens[axis_index] = inputs[1].elements();
return {type, lens}; return {type, lens};
} }
template <class T> template <class T>
void compute_index(const T& out_idx, const int axis_index, void compute_index(const T& out_idx,
const int axis_index,
const std::vector<std::size_t>& vec_indices, const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim, const std::size_t max_dim,
T& in_idx) const T& in_idx) const
......
...@@ -27,9 +27,9 @@ argument gather(hipStream_t stream, ...@@ -27,9 +27,9 @@ argument gather(hipStream_t stream,
hip_tensor_descriptor<ndim> desc_input(input.get_shape()); hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape()); hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i); auto lens = desc_output.multi(i);
lens[axis_index] = indices_ptr[lens[axis_index]]; lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)]; outptr[i] = inptr[desc_input.linear(lens)];
}); });
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment