Commit bd1b90a9 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Use indices instead of output for length adjustment

This is needed to get multi(i) below to work correctly when indexing.

Originally thought this was with the output_t.
parent 29060678
......@@ -39,7 +39,7 @@ __device__ void gather(const T& data_t, const U& indices_t, const V& output_t)
auto lengths = data_t.get_shape().lens;
auto axis_dim_size = lengths[axis];
lengths[axis] = output_t.get_shape().elements();
lengths[axis] = indices_t.get_shape().elements();
auto out_comp = make_shape(lengths, output_t.get_shape().strides);
out_comp.calculate_strides();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment