Commit 93c89587 authored by Paul's avatar Paul
Browse files

Split onnx tests

parent d2532d0e
#include <onnx_test.hpp>
TEST_CASE(split_minus_axis_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {0}}, {"ends", {5}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {5}}, {"ends", {10}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {10}}, {"ends", {15}}}), input);
mm->add_return({r1, r2, r3});
auto prog = migraphx::parse_onnx("split_minus_axis_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(split_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {7}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {7}}, {"ends", {11}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {11}}, {"ends", {15}}}), input);
mm->add_return({r1, r2, r3});
auto prog = migraphx::parse_onnx("split_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(split_test_default)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {5}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {5}}, {"ends", {10}}}), input);
mm->add_return({r1, r2});
auto prog = migraphx::parse_onnx("split_test_default.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(split_test_invalid_num_outputs)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("split_test_invalid_num_outputs.onnx"); }));
}
#include <onnx_test.hpp>
TEST_CASE(split_test_invalid_split)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("split_test_invalid_split.onnx"); }));
}
#include <onnx_test.hpp>
TEST_CASE(split_test_no_attribute)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape si{migraphx::shape::int64_type, {4}, {1}};
std::vector<int> ind = {75, 75, 75, 75};
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {300, 15}});
mm->add_literal(migraphx::literal(si, ind));
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {75}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {75}}, {"ends", {150}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {150}}, {"ends", {225}}}), input);
auto r4 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {225}}, {"ends", {300}}}), input);
mm->add_return({r1, r2, r3, r4});
auto prog = migraphx::parse_onnx("split_test_no_attribute.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(split_test_no_attribute_invalid_input_split)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("split_test_no_attribute_invalid_input_split.onnx"); }));
}
#include <onnx_test.hpp>
TEST_CASE(split_test_no_attribute_invalid_split)
{
EXPECT(
test::throws([&] { migraphx::parse_onnx("split_test_no_attribute_invalid_split.onnx"); }));
}
#include <onnx_test.hpp>
TEST_CASE(split_test_uneven)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {12, 15}});
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {3}}, {"ends", {6}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {6}}, {"ends", {9}}}), input);
auto r4 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {9}}, {"ends", {12}}}), input);
auto r5 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {12}}, {"ends", {12}}}), input);
mm->add_return({r1, r2, r3, r4, r5});
auto prog = migraphx::parse_onnx("split_test_uneven.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(split_test_uneven_num_outputs)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {11, 15}});
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {3}}, {"ends", {6}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {6}}, {"ends", {9}}}), input);
auto r4 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {9}}, {"ends", {11}}}), input);
mm->add_return({r1, r2, r3, r4});
auto prog = migraphx::parse_onnx("split_test_uneven_num_outputs.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(sqrt_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
mm->add_instruction(migraphx::make_op("sqrt"), input);
auto prog = optimize_onnx("sqrt_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(squeeze_axes_input_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal({migraphx::shape::int64_type, {2}}, {1, 3}));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 3}}}), l0);
mm->add_return({l1});
auto prog = migraphx::parse_onnx("squeeze_axes_input_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(squeeze_empty_axes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{migraphx::shape::int64_type});
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1, 5, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("squeeze"), l0);
mm->add_return({l1});
auto prog = migraphx::parse_onnx("squeeze_empty_axes_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(squeeze_unsqueeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 1}, {1, 4}, {1, 1}, {1, 1}, {1, 4}, {1, 1}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto ret = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), c1);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("squeeze_unsqueeze_dyn_test.onnx", options);
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(squeeze_unsqueeze_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), l1);
auto prog = optimize_onnx("squeeze_unsqueeze_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(sub_bcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", l0->get_shape().lens()}}), l1);
mm->add_instruction(migraphx::make_op("sub"), l0, l2);
auto prog = optimize_onnx("sub_bcast_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(sub_scalar_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1}});
auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 5}}}), l1);
mm->add_instruction(migraphx::make_op("sub"), l0, m1);
auto prog = optimize_onnx("sub_scalar_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(sum_int_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::int16_type, {3}});
auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::uint16_type, {3}});
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint32_type, {3}});
auto cin0 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}),
input0);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}),
input1);
auto l0 = mm->add_instruction(migraphx::make_op("add"), cin0, cin1);
mm->add_instruction(migraphx::make_op("add"), l0, input2);
auto prog = optimize_onnx("sum_int_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(sum_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
auto input1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3}});
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {3}});
auto l0 = mm->add_instruction(migraphx::make_op("add"), input0, input1);
mm->add_instruction(migraphx::make_op("add"), l0, input2);
auto prog = optimize_onnx("sum_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(sum_type_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l_bool = mm->add_literal({migraphx::shape{migraphx::shape::bool_type, {2}}, {1, 0}});
auto l_int8 = mm->add_literal({migraphx::shape{migraphx::shape::int8_type, {2}}, {1, 1}});
auto l_uint8 = mm->add_literal({migraphx::shape{migraphx::shape::uint8_type, {2}}, {1, 1}});
auto l_uint16 = mm->add_literal({migraphx::shape{migraphx::shape::uint16_type, {2}}, {1, 1}});
auto l_uint32 = mm->add_literal({migraphx::shape{migraphx::shape::uint32_type, {2}}, {1, 1}});
auto l_uint64 = mm->add_literal({migraphx::shape{migraphx::shape::uint64_type, {2}}, {1, 1}});
auto l_double = mm->add_literal({migraphx::shape{migraphx::shape::double_type, {2}}, {1, 1}});
auto l_raw = mm->add_literal({migraphx::shape{migraphx::shape::double_type, {2}}, {1.5, 2.0}});
auto o_bool = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::double_type)}}),
l_bool);
auto o_int8 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::double_type)}}),
l_int8);
auto o_uint8 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::double_type)}}),
l_uint8);
auto o_uint16 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::double_type)}}),
l_uint16);
auto o_uint32 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::double_type)}}),
l_uint32);
auto o_uint64 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::double_type)}}),
l_uint64);
auto s0 = mm->add_instruction(migraphx::make_op("add"), o_bool, o_int8);
auto s1 = mm->add_instruction(migraphx::make_op("add"), s0, o_uint8);
auto s2 = mm->add_instruction(migraphx::make_op("add"), s1, o_uint16);
auto s3 = mm->add_instruction(migraphx::make_op("add"), s2, o_uint32);
auto s4 = mm->add_instruction(migraphx::make_op("add"), s3, o_uint64);
auto s5 = mm->add_instruction(migraphx::make_op("add"), s4, l_double);
auto s6 = mm->add_instruction(migraphx::make_op("add"), s5, l_raw);
mm->add_return({s6});
auto prog = migraphx::parse_onnx("sum_type_test.onnx");
EXPECT(p == prog);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment