Commit 40b6e561 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Backup of gather changes

Currently failing the negative-indices and negative-axis tests.

All other tests "seem" to work.

Noticed an oddball case: the tests that currently fail pass if the size of a
dimension of a container is even instead of odd...
parent 3dac460c
...@@ -34,7 +34,7 @@ namespace gpu { ...@@ -34,7 +34,7 @@ namespace gpu {
// NOLINTNEXTLINE // NOLINTNEXTLINE
static const char* const gather_kernel = R"__migraphx__( static const char* const gather_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp> #include <migraphx/kernels/gather.hpp>
#include <migraphx/kernels/ops.hpp> #include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp> #include <migraphx/kernels/generic_constant.hpp>
......
...@@ -25,7 +25,10 @@ ...@@ -25,7 +25,10 @@
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP #define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/algorithm.hpp> #include <migraphx/kernels/algorithm.hpp>
// debugging use MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG) for assertions
#include <migraphx/kernels/print.hpp>
namespace migraphx { namespace migraphx {
...@@ -33,15 +36,62 @@ template <int axis, class T, class U, class V> ...@@ -33,15 +36,62 @@ template <int axis, class T, class U, class V>
__device__ void gather(const T& data_t, const U& indices_t, const V& output_t) __device__ void gather(const T& data_t, const U& indices_t, const V& output_t)
{ {
auto ind = make_index(); auto ind = make_index();
auto output_shape = output_t.get_shape();
auto axis_dim_size = data_t.get_shape().lens[axis]; auto axis_dim_size = data_t.get_shape().lens[axis];
ind.global_stride(output_shape.elements(), [&](auto i) { ind.global_stride(output_t.get_shape().elements(), [&](auto i) {
auto idx = output_shape.multi(i); /* Debug
auto in_index = indices_t[idx[axis]]; print_once("Inputs: ");
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; for(auto& item : data_t)
idx[axis] = in_index; {
output_t[i] = indices_t[idx]; print_once(item, " ");
}
print_once("\n");
print_once("indices: ");
for(auto& item : indices_t)
{
print_once(item, " ");
}
print_once("\n");
print_once("outputs before: ");
for(auto& item : output_t)
{
print_once(item, " ");
}
print_once("\n"); */
auto idx = data_t.get_shape().multi(i);
if(indices_t.get_shape().elements() == 1)
{
idx = data_t.get_shape().multi_stride(i);
}
/*print_once("idx: ");
for(auto& item : idx)
{
print_once(item, " ");
}
print_once("\n"); */
diff_int in_index = indices_t[idx[axis]];
// print_once("index ", in_index, "\n");
auto new_in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
// print_once("New index ", new_in_index, "\n");
idx[axis] = new_in_index;
output_t[i] = data_t[idx];
/* Debug
print_once("outputs after: ");
for(auto & item: output_t)
{
print_once(item, " ");
}
print_once("\n");
*/
}); });
} }
......
...@@ -90,7 +90,7 @@ struct miopen_apply ...@@ -90,7 +90,7 @@ struct miopen_apply
add_extend_op("argmax"); add_extend_op("argmax");
add_extend_op("argmin"); add_extend_op("argmin");
add_extend_op("gather"); // add_extend_op("gather");
add_extend_op("logsoftmax"); add_extend_op("logsoftmax");
add_extend_op("lrn"); add_extend_op("lrn");
add_extend_op("multinomial"); add_extend_op("multinomial");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment