Commit 69140d27 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Cleanup dead and debug code from gather jit implimentation

Was debugging/ trying to figure out why indexing was incorrect. Used a bunch of
prints and such.
parent a402c83f
......@@ -27,8 +27,6 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/algorithm.hpp>
// debugging use MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG) for assertions
#include <migraphx/kernels/print.hpp>
namespace migraphx {
......@@ -39,73 +37,19 @@ __device__ void gather(const T& data_t, const U& indices_t, const V& output_t)
auto lengths = data_t.get_shape().lens;
auto axis_dim_size = lengths[axis];
lengths[axis] = indices_t.get_shape().elements();
lengths[axis] = indices_t.get_shape().elements();
auto out_comp = make_shape(lengths, output_t.get_shape().strides);
out_comp.calculate_strides();
//print_once("axis: ", axis, "\n");
//print_once("axis dim:", axis_dim_size, "\n");
auto out_comp = make_shape(lengths, output_t.get_shape().strides);
ind.global_stride(output_t.get_shape().elements(), [&](auto i) {
/* Debug
print_once("Inputs: ");
for(auto& item : data_t)
{
print_once(item, " ");
}
print_once("\n");
print_once("indices: ");
for(auto& item : indices_t)
{
print_once(item, " ");
}
print_once("\n");
print_once("outputs before: ");
for(auto& item : output_t)
{
print_once(item, " ");
}
print_once("\n"); */
auto idx = out_comp.multi(i);
if(indices_t.get_shape().elements() == 1)
{
idx = out_comp.multi_stride(i);
}
auto idx = out_comp.multi(i);
auto in_index = indices_t[idx[axis]];
auto new_in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
print("idx: ");
for(auto& item : idx)
{
print_once(item, " ");
}
print_once("\n");
//print("index ", in_index, "\n");
//print("New index ", new_in_index, "\n");
idx[axis] = new_in_index;
print("updated idx: ");
for(auto& item : idx)
{
print_once(item, " ");
}
print_once("\n");
output_t[i] = data_t[idx];
/* Debug
print("outputs after: ");
for(auto & item: output_t)
{
print_once(item, " ");
}
print_once("\n"); */
});
}
......
......@@ -122,7 +122,6 @@ struct shape
index_int tidx = idx;
for(diff_int is = result.size() - 1; is > 0; is--)
{
MIGRAPHX_ASSERT(lens[is] > 1);
result[is] = tidx % lens[is];
tidx = tidx / lens[is];
}
......@@ -130,20 +129,6 @@ struct shape
return result;
}
/// Convert single index into a multi-index
constexpr index_array multi_stride(index_int idx) const
{
index_array result;
index_int tidx = idx;
for(diff_int is = result.size() - 1; is > 0; is--)
{
MIGRAPHX_ASSERT(lens[is] > 1);
result[is] = tidx % strides[is];
tidx = tidx / strides[is];
}
result[0] = tidx;
return result;
}
/// Convert multi-index into a single index
constexpr index_int single(index_array idx) const
{
......@@ -160,8 +145,6 @@ struct shape
ss << "{" << s.lens << "}, {" << s.strides << "}";
return ss;
}
};
template <class Lens, class Strides>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment