Unverified Commit 4ec35e5f authored by turneram's avatar turneram Committed by GitHub
Browse files

Add GatherND operator (#1089)

Add ref and gpu implementations for ONNX op GatherND

Resolves #1032
parent 4c72cc95
......@@ -109,6 +109,7 @@ register_migraphx_ops(
flatten
floor
gather
gathernd
get_tuple_elem
greater
gru
......
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct gathernd
{
int batch_dims = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.batch_dims, "batch_dims"));
}
std::string name() const { return "gathernd"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto r = inputs.front().lens().size();
auto q = inputs.back().lens().size();
auto k = inputs.back().lens().back();
if(k > r - batch_dims)
{
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " +
std::to_string(r - batch_dims));
}
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1;
std::vector<std::size_t> output_lens(output_lens_size);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
if(k < r - batch_dims)
{
auto data_lens = inputs.front().lens();
std::copy(
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
}
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape();
auto indices_shape_lens = indices_shape.lens();
auto data_shape = data.get_shape();
auto data_shape_lens = data_shape.lens();
auto k = indices_shape.lens().back();
const auto num_slice_dims = k;
std::size_t num_slices = std::accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
std::size_t num_batches = std::accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
std::size_t data_batch_stride =
std::accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
auto num_slices_per_batch = num_slices / num_batches;
std::vector<std::size_t> sizes_from_slice_dims(num_slice_dims);
{
auto running_product = slice_size;
for(std::size_t i = 0; i < num_slice_dims; ++i)
{
sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product;
running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i];
}
}
std::vector<std::size_t> input_slice_offsets(num_slices);
par_for(num_slices, [&](const auto i) {
std::size_t batch_idx = i / num_slices_per_batch;
auto slice_indices = indices.begin() + (i * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx)
{
int64_t index = *(slice_indices + dim_idx);
const std::size_t input_dim_idx = batch_dims + dim_idx;
const auto input_dim = data_shape_lens[input_dim_idx];
if(index < -static_cast<int64_t>(input_dim) or
index >= static_cast<int64_t>(input_dim))
MIGRAPHX_THROW("GatherND: index " + std::to_string(index) +
" is out of bounds for dim of len " +
std::to_string(input_dim));
if(index < 0)
index += input_dim;
relative_slice_offset += index * sizes_from_slice_dims[dim_idx];
}
input_slice_offsets[i] =
(batch_idx * data_batch_stride) + relative_slice_offset;
});
par_for(num_slices * slice_size, [&](const auto i) {
auto slice_offset = input_slice_offsets[i / slice_size];
output[i] = data[slice_offset + i % slice_size];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
......
......@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Flatten", "flatten"},
{"Floor", "floor"},
{"Gather", "gather"},
{"GatherND", "gathernd"},
{"Identity", "identity"},
{"IsNaN", "isnan"},
{"LeakyRelu", "leaky_relu"},
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const gathernd_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
gathernd(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__";
struct gathernd_compiler : compiler<gathernd_compiler>
{
std::vector<std::string> names() const { return {"gathernd"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "gathernd_kernel";
options.virtual_inputs = inputs;
// batch_dims
assert(v.contains("batch_dims"));
auto batch_dims = v.at("batch_dims").to<int64_t>();
options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims);
return compile_hip_code_object(gathernd_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -21,6 +21,16 @@ struct greater
}
};
template <class InputIt, class T, class BinaryOperation>
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
for(; first != last; ++first)
{
init = op(std::move(init), *first);
}
return init;
}
template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{
......
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T>
struct gathernd_settings
{
T batch_dims{};
};
template <class... Ts>
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs)
{
return {xs...};
}
template <class T, class U, class V, class Settings>
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s)
{
auto ind = make_index();
auto batch_dims = s.batch_dims;
auto output_shape = output_t.get_shape();
auto indices_shape = indices_t.get_shape();
auto data_shape = data_t.get_shape();
auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices = accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const std::size_t num_batches = accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
{
int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx];
assert(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
if(index < 0)
index += input_dim;
std::size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size,
std::multiplies<std::size_t>());
relative_slice_offset += index * size_from_slice_dims;
}
auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset;
output_t[i] = data_t[slice_offset + i % slice_size];
});
}
} // namespace migraphx
#endif
gathernd_batch_dims_test:
/
data
indicesy"GatherND*
batch_dimsgathernd_batch_dims_testZ
data



Z
indices


b
y


B
\ No newline at end of file
 gathernd_test:q

data
indicesy"GatherND gathernd_testZ
data


Z
indices


b
y

B
\ No newline at end of file
......@@ -1666,6 +1666,35 @@ def gather_elements_axis1_test():
return ([node], [x, i], [y])
@onnx_test
def gathernd_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2])
i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])
node = onnx.helper.make_node('GatherND',
inputs=['data', 'indices'],
outputs=['y'])
return ([node], [x, i], [y])
@onnx_test
def gathernd_batch_dims_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2])
node = onnx.helper.make_node(
'GatherND',
inputs=['data', 'indices'],
outputs=['y'],
batch_dims=1,
)
return ([node], [x, i], [y])
@onnx_test
def gemm_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 7])
......
......@@ -1582,6 +1582,31 @@ TEST_CASE(gather_elements_axis1_test)
EXPECT(p == prog);
}
TEST_CASE(gathernd_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 2}});
mm->add_instruction(migraphx::make_op("gathernd"), l0, l1);
auto prog = optimize_onnx("gathernd_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gathernd_batch_dims_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto l1 = mm->add_parameter("indices", migraphx::shape{migraphx::shape::int64_type, {2, 1}});
int batch_dims = 1;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), l0, l1);
auto prog = optimize_onnx("gathernd_batch_dims_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gemm_test)
{
migraphx::program p;
......
......@@ -268,9 +268,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_expand_shape_model2_cpu')
backend_test.exclude(r'test_expand_shape_model3_cpu')
backend_test.exclude(r'test_expand_shape_model4_cpu')
backend_test.exclude(r'test_gathernd_example_float32_cpu')
backend_test.exclude(r'test_gathernd_example_int32_batch_dim1_cpu')
backend_test.exclude(r'test_gathernd_example_int32_cpu')
backend_test.exclude(r'test_identity_sequence_cpu')
backend_test.exclude(r'test_maxpool_2d_uint8_cpu')
backend_test.exclude(r'test_negative_log_likelihood_loss_*')
......
......@@ -1653,6 +1653,203 @@ TEST_CASE(gather_test)
}
}
TEST_CASE(gathernd_test)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2}};
std::vector<float> data_vec(2 * 2);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{0, 0, 1, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, indices_vec});
mm->add_instruction(migraphx::make_op("gathernd"), data, indices);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> res_data{};
std::vector<float> gold{0, 3};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 1}};
std::vector<float> data_vec(2 * 2);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{1, 0};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, indices_vec});
mm->add_instruction(migraphx::make_op("gathernd"), data, indices);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> res_data{};
std::vector<float> gold{2, 3, 0, 1};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}};
std::vector<float> data_vec(2 * 3 * 1);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{1, 0, 0, 1};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, indices_vec});
mm->add_instruction(migraphx::make_op("gathernd"), data, indices);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> res_data{};
std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2, 2}};
std::vector<float> data_vec(2 * 3 * 2 * 3);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{0, 0, 0, 1, 0, 0, 0, 1};
const int batch_dims = 1;
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, indices_vec});
mm->add_instruction(
migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), data, indices);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> res_data{};
std::vector<float> gold{0, 1, 2, 3, 4, 5, 18, 19, 20, 21, 22, 23};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}};
migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}};
std::vector<float> data_vec(2 * 3 * 1 * 3);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{0, 0, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0};
const int batch_dims = 2;
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, indices_vec});
mm->add_instruction(
migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), data, indices);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> res_data{};
std::vector<float> gold{0, 4, 8, 11, 13, 15};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
{
// k > r - batch_dims
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}};
migraphx::shape is{migraphx::shape::int64_type, {2, 3, 3}};
std::vector<float> data_vec(2 * 3 * 1 * 3);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec(2 * 3 * 3, 0);
const int batch_dims = 2;
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, indices_vec});
EXPECT(test::throws([&] {
mm->add_instruction(
migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), data, indices);
}));
}
}
TEST_CASE(gathernd_negative_index_test)
{
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}};
std::vector<float> data_vec(2 * 2);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{-1, 0};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, indices_vec});
mm->add_instruction(migraphx::make_op("gathernd"), data, indices);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> res_data{};
std::vector<float> gold{2, 3, 0, 1};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, gold));
}
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}};
std::vector<float> data_vec(2 * 2);
std::iota(data_vec.begin(), data_vec.end(), 0);
std::vector<int64_t> indices_vec{-3, 0};
auto data = mm->add_literal(migraphx::literal{ds, data_vec});
auto indices = mm->add_literal(migraphx::literal{is, indices_vec});
mm->add_instruction(migraphx::make_op("gathernd"), data, indices);
p.compile(migraphx::ref::target{});
EXPECT(test::throws([&] { p.eval({}); }));
}
}
TEST_CASE(globalavgpool_test)
{
migraphx::program p;
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_batch_dims_1 : verify_program<test_gathernd_batch_dims_1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}};
std::vector<int64_t> indices{1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
int batch_dims = 1;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_batch_dims_2 : verify_program<test_gathernd_batch_dims_2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}};
migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}};
std::vector<int64_t> indices{0, 0, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
int batch_dims = 2;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_default : verify_program<test_gathernd_default>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2}};
std::vector<int64_t> indices{0, 0, 1, 1};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
mm->add_instruction(migraphx::make_op("gathernd"), a0, a1);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_negative_indices : verify_program<test_gathernd_negative_indices>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}};
std::vector<int64_t> indices{-1, 0};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
int batch_dims = 1;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment