Commit 40b6e561 authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Backup of gather changes

Currently failing negative indices and negative axis tests.

All others "seem" to work

Noticed an oddball case that the cases that fail pass, if the sizes of a dimension
of a container is even instead of odd...
parent 3dac460c
......@@ -34,7 +34,7 @@ namespace gpu {
// NOLINTNEXTLINE
static const char* const gather_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/gather.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
......
......@@ -25,7 +25,10 @@
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/algorithm.hpp>
// debugging use MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG) for assertions
#include <migraphx/kernels/print.hpp>
namespace migraphx {
......@@ -33,15 +36,62 @@ template <int axis, class T, class U, class V>
__device__ void gather(const T& data_t, const U& indices_t, const V& output_t)
{
auto ind = make_index();
auto output_shape = output_t.get_shape();
auto axis_dim_size = data_t.get_shape().lens[axis];
ind.global_stride(output_shape.elements(), [&](auto i) {
auto idx = output_shape.multi(i);
auto in_index = indices_t[idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[axis] = in_index;
output_t[i] = indices_t[idx];
ind.global_stride(output_t.get_shape().elements(), [&](auto i) {
/* Debug
print_once("Inputs: ");
for(auto& item : data_t)
{
print_once(item, " ");
}
print_once("\n");
print_once("indices: ");
for(auto& item : indices_t)
{
print_once(item, " ");
}
print_once("\n");
print_once("outputs before: ");
for(auto& item : output_t)
{
print_once(item, " ");
}
print_once("\n"); */
auto idx = data_t.get_shape().multi(i);
if(indices_t.get_shape().elements() == 1)
{
idx = data_t.get_shape().multi_stride(i);
}
/*print_once("idx: ");
for(auto& item : idx)
{
print_once(item, " ");
}
print_once("\n"); */
diff_int in_index = indices_t[idx[axis]];
// print_once("index ", in_index, "\n");
auto new_in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
// print_once("New index ", new_in_index, "\n");
idx[axis] = new_in_index;
output_t[i] = data_t[idx];
/* Debug
print_once("outputs after: ");
for(auto & item: output_t)
{
print_once(item, " ");
}
print_once("\n");
*/
});
}
......
......@@ -90,7 +90,7 @@ struct miopen_apply
add_extend_op("argmax");
add_extend_op("argmin");
add_extend_op("gather");
// add_extend_op("gather");
add_extend_op("logsoftmax");
add_extend_op("lrn");
add_extend_op("multinomial");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment