Commit 93c89587 authored by Paul's avatar Paul
Browse files

Split onnx tests

parent d2532d0e
#include <onnx_test.hpp>
TEST_CASE(tan_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::make_op("tan"), input);
auto prog = optimize_onnx("tan_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(tanh_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
mm->add_instruction(migraphx::make_op("tanh"), input);
auto prog = optimize_onnx("tanh_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(thresholdedrelu_default_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 3}});
auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}});
auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {1.0f}});
auto mbz = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz);
auto mba = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la);
auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba);
mm->add_instruction(migraphx::make_op("where"), condition, x, mbz);
auto prog = optimize_onnx("thresholdedrelu_default_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(thresholdedrelu_int_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2, 3}});
auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}});
auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {3}});
auto mbz = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz);
auto mba = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la);
auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba);
mm->add_instruction(migraphx::make_op("where"), condition, x, mbz);
auto prog = optimize_onnx("thresholdedrelu_int_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(thresholdedrelu_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 3}});
auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}});
auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {3.0f}});
auto mbz = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz);
auto mba = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la);
auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba);
mm->add_instruction(migraphx::make_op("where"), condition, x, mbz);
auto prog = optimize_onnx("thresholdedrelu_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(tile_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {1, 2}});
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), input, input);
auto prog = optimize_onnx("tile_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(tile_test_3x2)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, {3, 2}});
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}});
auto l0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), input, input);
auto l1 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), l0, input);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l1, l1);
auto prog = optimize_onnx("tile_test_3x2.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(topk_attrk_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 5, 3, 2}};
auto data = mm->add_parameter("data", s);
auto out = mm->add_instruction(migraphx::make_op("topk", {{"k", 2}, {"axis", -1}}), data);
auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
mm->add_return({val, ind});
auto prog = migraphx::parse_onnx("topk_attrk_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(topk_neg_axis_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sk{migraphx::shape::int64_type, {1}};
mm->add_literal(migraphx::literal(sk, {3}));
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto data = mm->add_parameter("data", s);
auto out = mm->add_instruction(
migraphx::make_op("topk", {{"k", 3}, {"axis", -2}, {"largest", 1}}), data);
auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
mm->add_return({val, ind});
auto prog = migraphx::parse_onnx("topk_neg_axis_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(topk_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sk{migraphx::shape::int64_type, {1}};
mm->add_literal(migraphx::literal(sk, {4}));
migraphx::shape s{migraphx::shape::float_type, {2, 5, 3, 2}};
auto data = mm->add_parameter("data", s);
auto out = mm->add_instruction(
migraphx::make_op("topk", {{"k", 4}, {"axis", 1}, {"largest", 0}}), data);
auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
mm->add_return({val, ind});
auto prog = migraphx::parse_onnx("topk_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(transpose_default_perm_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 5, 2, 3}});
std::vector<int64_t> perm{3, 2, 1, 0};
auto r = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
mm->add_return({r});
auto prog = migraphx::parse_onnx("transpose_default_perm_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(transpose_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}, {3, 3}}});
std::vector<int64_t> perm{0, 3, 1, 2};
auto t0 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
mm->add_return({t0});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4};
auto prog = migraphx::parse_onnx("transpose_dyn_test.onnx", options);
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(transpose_gather_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto make_contiguous = [&mm](migraphx::instruction_ref ins) {
if(ins->get_shape().standard())
{
return ins;
}
return mm->add_instruction(migraphx::make_op("contiguous"), ins);
};
auto data =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}});
auto ind =
mm->add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
auto tr_data =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), data);
auto tr_ind =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), ind);
int axis = 1;
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}),
make_contiguous(tr_data),
make_contiguous(tr_ind));
auto prog = optimize_onnx("transpose_gather_test.onnx");
EXPECT(p.sort() == prog.sort());
}
#include <onnx_test.hpp>
TEST_CASE(transpose_invalid_perm_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("transpose_invalid_perm_test.onnx"); }));
}
#include <onnx_test.hpp>
TEST_CASE(transpose_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
std::vector<int64_t> perm{0, 3, 1, 2};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
auto prog = optimize_onnx("transpose_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(undefined_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_instruction(migraphx::make_op("undefined"));
auto l2 = mm->add_instruction(migraphx::make_op("identity"), l1);
mm->add_return({l2});
auto prog = migraphx::parse_onnx("undefined_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(unique_dynamic_sorted_3D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int64_type, {4, 4, 4}};
auto x = mm->add_parameter("X", s);
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_ind, x_ind, count});
auto prog = migraphx::parse_onnx("unique_dynamic_sorted_3D_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(unique_dynamic_sorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}};
auto x = mm->add_parameter("X", s);
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_ind, x_ind, count});
auto prog = migraphx::parse_onnx("unique_dynamic_sorted_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(unique_sorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_x{migraphx::shape::float_type, {6}};
std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
auto x = mm->add_literal(migraphx::literal(s_x, x_data));
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 1}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_idx, x_idx, count});
auto prog = migraphx::parse_onnx("unique_sorted_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(unique_unsorted_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_x{migraphx::shape::float_type, {6}};
std::vector<float> x_data = {2, 1, 1, 3, 4, 3};
auto x = mm->add_literal(migraphx::literal(s_x, x_data));
auto out = mm->add_instruction(migraphx::make_op("unique", {{"sorted", 0}, {"axis", 0}}), x);
auto y = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
auto y_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
auto x_idx = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), out);
auto count = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 3}}), out);
mm->add_return({y, y_idx, x_idx, count});
auto prog = migraphx::parse_onnx("unique_unsorted_test.onnx");
EXPECT(p == prog);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment