Commit 20d8803c authored by Ted Themistokleous
Browse files

Add definition for global_stride similar to removed device function

parent 86ff3faf
...@@ -59,7 +59,11 @@ __device__ void gather(const T& data_t, const U& indices_t, const V& output_t, S ...@@ -59,7 +59,11 @@ __device__ void gather(const T& data_t, const U& indices_t, const V& output_t, S
auto* output_ptr = output_t.data(); auto* output_ptr = output_t.data();
ind.global_stride(output_shape.elements(), [&](auto i) { ind.global_stride(output_shape.elements(), [&](auto i) {
auto idx = output_shape.multi(i);
auto in_index = indices_ptr[idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[axis] = in_index;
output_ptr[i] = indices_t[idx];
}); });
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment