Commit 0d13db6e authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Backup of trying to get gather working

Tried to get a proper templated shape of out_comp but right now this
seems to break as I can't just update the length of a shape and get a proper
output of the strides. Currently this breaks/asserts.

I think this is the cause of axis > 0 failing since we're not getting proper gathering
for the other axes as a result and get repeated rows with the wrong data.
parent 96ff9c10
......@@ -44,6 +44,15 @@ struct greater
}
};
struct multiplies
{
template<class T>
constexpr T operator()(const T &lhs, const T &rhs) const
{
return lhs * rhs;
}
};
template <class InputIt, class T, class BinaryOperation>
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
......@@ -54,6 +63,24 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
return init;
}
template<class InputIt, class OutputIt, class BinaryOperation>
constexpr OutputIt partial_sum(InputIt first, InputIt last,
OutputIt d_first, BinaryOperation op)
{
if (first == last)
return d_first;
typename std::iterator_traits<InputIt>::value_type sum = *first;
*d_first = sum;
while (++first != last)
{
sum = op(std::move(sum), *first); // std::move since C++20
*++d_first = sum;
}
return ++d_first;
}
template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{
......
......@@ -149,6 +149,9 @@ struct array
constexpr T* begin() { return d; }
constexpr const T* begin() const { return d; }
constexpr T* rbegin() { return d[N -1]; }
constexpr const T* rbegin() const { return d[N - 1]; }
constexpr T* end() { return d + size(); }
constexpr const T* end() const { return d + size(); }
......
......@@ -36,10 +36,19 @@ template <int axis, class T, class U, class V>
__device__ void gather(const T& data_t, const U& indices_t, const V& output_t)
{
auto ind = make_index();
auto axis_dim_size = data_t.get_shape().lens[axis];
auto lengths = data_t.get_shape().lens;
auto axis_dim_size = lengths[axis];
lengths[axis] = output_t.get_shape().elements();
auto out_comp = make_shape(lengths, output_t.get_shape().strides);
out_comp.calculate_strides();
//print_once("axis: ", axis, "\n");
//print_once("axis dim:", axis_dim_size, "\n");
ind.global_stride(output_t.get_shape().elements(), [&](auto i) {
/* Debug
/* Debug
print_once("Inputs: ");
for(auto& item : data_t)
{
......@@ -60,38 +69,43 @@ __device__ void gather(const T& data_t, const U& indices_t, const V& output_t)
print_once(item, " ");
}
print_once("\n"); */
auto idx = data_t.get_shape().multi(i);
auto idx = out_comp.multi(i);
if(indices_t.get_shape().elements() == 1)
{
idx = data_t.get_shape().multi_stride(i);
idx = out_comp.multi_stride(i);
}
/*print_once("idx: ");
auto in_index = indices_t[idx[axis]];
auto new_in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
print("idx: ");
for(auto& item : idx)
{
print_once(item, " ");
}
print_once("\n"); */
diff_int in_index = indices_t[idx[axis]];
// print_once("index ", in_index, "\n");
auto new_in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
// print_once("New index ", new_in_index, "\n");
print_once("\n");
//print("index ", in_index, "\n");
//print("New index ", new_in_index, "\n");
idx[axis] = new_in_index;
print("updated idx: ");
for(auto& item : idx)
{
print_once(item, " ");
}
print_once("\n");
output_t[i] = data_t[idx];
/* Debug
print_once("outputs after: ");
/* Debug
print("outputs after: ");
for(auto & item: output_t)
{
print_once(item, " ");
}
print_once("\n");
*/
print_once("\n"); */
});
}
......
......@@ -39,8 +39,23 @@ struct shape
constexpr shape() = default;
constexpr shape(Lens l) : lens(l) {shape{}.calculate_strides();}
constexpr shape(Lens l, Strides s) : lens(l), strides(s) {}
constexpr auto calculate_strides()
{
strides.resize(lens.size(), 0);
if(strides.empty())
return;
strides.back() = 1;
partial_sum(lens.rbegin(),
lens.rend() - 1,
strides.rbegin() + 1,
multiplies());
}
constexpr auto elements() const { return _c<Lens{}.product()>; }
constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
......@@ -122,6 +137,7 @@ struct shape
index_int tidx = idx;
for(diff_int is = result.size() - 1; is > 0; is--)
{
MIGRAPHX_ASSERT(lens[is] > 1);
result[is] = tidx % lens[is];
tidx = tidx / lens[is];
}
......@@ -136,6 +152,7 @@ struct shape
index_int tidx = idx;
for(diff_int is = result.size() - 1; is > 0; is--)
{
MIGRAPHX_ASSERT(lens[is] > 1);
result[is] = tidx % strides[is];
tidx = tidx / strides[is];
}
......@@ -158,6 +175,8 @@ struct shape
ss << "{" << s.lens << "}, {" << s.strides << "}";
return ss;
}
};
template <class Lens, class Strides>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment