Commit dea2cc08 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 8f074e4e
......@@ -651,14 +651,15 @@ struct gather
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type();
auto type = inputs[0].type();
lens[axis_index] = inputs[1].elements();
return {type, lens};
}
template <class T>
void compute_index(const T& out_idx, const int axis_index,
void compute_index(const T& out_idx,
const int axis_index,
const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim,
T& in_idx) const
......
......@@ -27,9 +27,9 @@ argument gather(hipStream_t stream,
hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i);
auto lens = desc_output.multi(i);
lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)];
outptr[i] = inptr[desc_input.linear(lens)];
});
});
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment