Commit 4957715b authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into dev2

parents f99a3036 4ec8209f
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_batch_dims_2 : verify_program<test_gathernd_batch_dims_2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1, 3}};
migraphx::shape is{migraphx::shape::int64_type, {2, 3, 2}};
std::vector<int64_t> indices{0, 0, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
int batch_dims = 2;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_default : verify_program<test_gathernd_default>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2}};
std::vector<int64_t> indices{0, 0, 1, 1};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
mm->add_instruction(migraphx::make_op("gathernd"), a0, a1);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gathernd_negative_indices : verify_program<test_gathernd_negative_indices>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 1, 1}};
std::vector<int64_t> indices{-1, 0};
auto a0 = mm->add_parameter("data", ds);
auto a1 = mm->add_literal(migraphx::literal{is, indices});
int batch_dims = 1;
mm->add_instruction(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), a0, a1);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment