Commit 93c89587 authored by Paul's avatar Paul
Browse files

Split onnx tests

parent d2532d0e
#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>
TEST_CASE(layer_norm_3d_test)
{
migraphx::program p = make_layer_norm({1, 4, 2}, {2}, {2}, 2);
auto prog = optimize_onnx("layer_norm_3d_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>
TEST_CASE(layer_norm_4d_half_test)
{
migraphx::program p =
make_layer_norm({3, 3, 3, 3}, {3}, {3}, 3, false, 1e-5f, migraphx::shape::half_type);
auto prog = optimize_onnx("layer_norm_4d_half_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>
TEST_CASE(layer_norm_4d_test)
{
migraphx::program p = make_layer_norm({3, 3, 3, 3}, {3}, {3}, 3);
auto prog = optimize_onnx("layer_norm_4d_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(layer_norm_invalid_axis_error_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("layer_norm_invalid_axis_error_test.onnx"); }));
}
#include <onnx_test.hpp>
TEST_CASE(layer_norm_invalid_input_count_error_test)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("layer_norm_invalid_input_count_error_test.onnx"); }));
}
#include <onnx_test.hpp>
TEST_CASE(layer_norm_invalid_minus_axis_error_test)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("layer_norm_invalid_minus_axis_error_test.onnx"); }));
}
#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>
TEST_CASE(layer_norm_small_eps_half_test)
{
migraphx::program p =
make_layer_norm({1, 2}, {2}, {1}, 1, true, 1e-7, migraphx::shape::half_type);
auto prog = optimize_onnx("layer_norm_small_eps_half_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>
TEST_CASE(layer_norm_without_bias_test)
{
migraphx::program p = make_layer_norm({1, 2}, {2}, {1}, 1, true);
auto prog = optimize_onnx("layer_norm_without_bias_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(leaky_relu_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
float alpha = 0.01f;
auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {3}});
mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", alpha}}), l0);
auto prog = optimize_onnx("leaky_relu_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(less_bool_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sf{migraphx::shape::float_type, {2, 3}};
migraphx::shape sb{migraphx::shape::bool_type, {2, 3}};
auto input1 = mm->add_parameter("x1", sf);
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
auto ret = mm->add_instruction(migraphx::make_op("less"), cin1, input2);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("less_bool_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(less_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
auto input1 = mm->add_literal(migraphx::literal(s, data));
auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto le = mm->add_instruction(migraphx::make_op("less"), input1, input2);
auto ret = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
le);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("less_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(lessorequal_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input1 = mm->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = mm->add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {3}});
auto temp = mm->add_instruction(migraphx::make_op("greater"), input1, input2);
auto bt = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), temp);
auto le = mm->add_instruction(migraphx::make_op("not"), bt);
mm->add_return({le});
auto prog = migraphx::parse_onnx("lessorequal_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(log_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
mm->add_instruction(migraphx::make_op("log"), input);
auto prog = optimize_onnx("log_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(logical_and_bcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 5}});
auto l2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("logical_and"), l0, l2);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("logical_and_bcast_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(logical_or_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto ret = mm->add_instruction(migraphx::make_op("logical_or"), l0, l1);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("logical_or_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(logical_xor_bcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::bool_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::bool_type, {4, 1}});
auto l2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("logical_xor"), l0, l2);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("logical_xor_bcast_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(logsoftmax_nonstd_input_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {6, 9}});
auto l1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {4, 4}}}), l0);
auto l2 = mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", -1}}), l1);
mm->add_return({l2});
auto prog = migraphx::parse_onnx("logsoftmax_nonstd_input_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(logsoftmax_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
int axis = 1;
mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", axis}}), l0);
auto prog = optimize_onnx("logsoftmax_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(loop_default_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape su{migraphx::shape::float_type};
auto a = mm->add_parameter("a", su);
auto b = mm->add_parameter("b", su);
migraphx::shape si{migraphx::shape::int64_type};
auto max_iter = mm->add_literal(migraphx::literal(si, {10}));
migraphx::shape sc{migraphx::shape::bool_type};
auto icond = mm->add_literal(migraphx::literal(sc, {1}));
mm->add_instruction(migraphx::make_op("undefined"));
auto* body = p.create_module("Loop_3_loop");
body->add_parameter("iteration_num", {migraphx::shape::int64_type});
body->add_parameter("keep_going_inp", {migraphx::shape::bool_type});
auto var = body->add_parameter("b_in", su);
auto ad = body->add_instruction(migraphx::make_op("add"), a, var);
auto sb = body->add_instruction(migraphx::make_op("sub"), a, var);
auto gt = body->add_instruction(migraphx::make_op("greater"), ad, sb);
auto cv = body->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), gt);
auto ad1 = body->add_instruction(migraphx::make_op("add"), sb, sb);
body->add_return({cv, sb, ad, ad1});
auto lp = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 10}}), {max_iter, icond, b}, {body});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), lp);
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), lp);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), lp);
mm->add_return({r0, r2});
auto prog = migraphx::parse_onnx("loop_default_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(loop_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape si{migraphx::shape::int64_type, {1}};
auto max_iter = mm->add_parameter("max_trip_count", si);
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto icond = mm->add_parameter("keep_going_cond", sc);
migraphx::shape su{migraphx::shape::float_type, {1}};
auto a = mm->add_parameter("a", su);
auto b = mm->add_parameter("b", su);
auto* body = p.create_module("Loop_4_loop");
body->add_parameter("iteration_num", si);
body->add_parameter("keep_going_inp", sc);
auto var = body->add_parameter("b_in", su);
auto ad = body->add_instruction(migraphx::make_op("add"), a, var);
auto sb = body->add_instruction(migraphx::make_op("sub"), a, var);
auto gt = body->add_instruction(migraphx::make_op("greater"), ad, sb);
auto cv = body->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), gt);
auto ad1 = body->add_instruction(migraphx::make_op("add"), sb, sb);
body->add_return({cv, sb, ad, ad1});
auto lp = mm->add_instruction(
migraphx::make_op("loop", {{"max_iterations", 10}}), {max_iter, icond, b}, {body});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), lp);
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), lp);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), lp);
mm->add_return({r0, r2});
auto prog = migraphx::parse_onnx("loop_test.onnx");
EXPECT(p == prog);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment