Commit 93c89587 authored by Paul's avatar Paul
Browse files

Split onnx tests

parent d2532d0e
#include <onnx_test.hpp>
TEST_CASE(onehot_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_ind{migraphx::shape::int32_type, {5, 2}};
migraphx::shape s_val{migraphx::shape::half_type, {2}};
mm->add_literal(3);
auto l_ind = mm->add_parameter("indices", s_ind);
auto l_val = mm->add_parameter("values", s_val);
migraphx::shape s_dep{migraphx::shape::half_type, {3, 3}};
std::vector<float> data_dep{1, 0, 0, 0, 1, 0, 0, 0, 1};
auto l_dep = mm->add_literal(migraphx::literal(s_dep, data_dep));
auto gather_out = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), l_dep, l_ind);
auto tr_out = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}),
gather_out);
auto off_val = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), l_val);
auto on_val = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), l_val);
auto diff = mm->add_instruction(migraphx::make_op("sub"), on_val, off_val);
auto mb_off_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 2}}}), off_val);
auto mb_diff =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 2}}}), diff);
auto mul = mm->add_instruction(migraphx::make_op("mul"), tr_out, mb_diff);
auto r = mm->add_instruction(migraphx::make_op("add"), mul, mb_off_val);
mm->add_return({r});
auto prog = migraphx::parse_onnx("onehot_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_3arg_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 1, 2, 2}});
auto r = mm->add_instruction(
migraphx::make_op("pad", {{"pads", {1, 1, 2, 2}}, {"value", 1.0f}}), l0);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_3arg_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_4arg_axes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}});
// axes=[1,3]
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 3}});
// constant_value=1
mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}});
// pads=[1,3,2,4]
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 2, 4}});
auto r = mm->add_instruction(
migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_4arg_axes_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_4arg_invalid_axes_error_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("pad_4arg_invalid_axes_error_test.onnx"); }));
}
#include <onnx_test.hpp>
TEST_CASE(pad_4arg_neg_axes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}});
// axes=[-3,-1]
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {-3, -1}});
// constant_value=1
mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}});
// pads=[1,3,2,4]
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 2, 4}});
auto r = mm->add_instruction(
migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_4arg_neg_axes_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_asym_invalid_pads_error_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("pad_asym_invalid_pads_error_test.onnx"); }));
}
#include <onnx_test.hpp>
TEST_CASE(pad_asym_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}});
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}}), l0);
auto prog = optimize_onnx("pad_asym_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_attr_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{2, 4, {2}}, {2, 4, {2}}}});
auto ret = mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), x);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 4, {2}}, {2, 4, {2}}};
auto prog = parse_onnx("pad_attr_dyn_test.onnx", options);
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_cnst_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{2, 4, {2}}, {2, 4, {2}}}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 0, 1}});
auto ret = mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 2, 0, 1}}}), x);
mm->add_return({ret});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 4, {2}}, {2, 4, {2}}};
auto prog = parse_onnx("pad_cnst_dyn_test.onnx", options);
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_dyn_reflect_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 4, {2}};
EXPECT(test::throws([&] { migraphx::parse_onnx("pad_dyn_reflect_error.onnx", options); }));
}
#include <onnx_test.hpp>
TEST_CASE(pad_reflect_multiaxis_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 2, 0}});
auto l1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 1}}, {"ends", {2, 2}}}), l0);
auto l2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 2}}, {"ends", {2, 3}}}), l0);
auto l3 = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0);
auto l4 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {1, 5}}}), l3);
auto l5 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {2, 5}}}), l3);
auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), l3, l4, l5);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_reflect_multiaxis_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_reflect_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {0, 2, 0, 1}});
auto l1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 1}}, {"ends", {2, 2}}}), l0);
auto l2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0);
auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0);
auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0, l3);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_reflect_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_reflect_with_axes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {1}}, {1}});
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 1}});
auto l1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 1}}, {"ends", {2, 2}}}), l0);
auto l2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0);
auto l3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0);
auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0, l3);
mm->add_return({r});
auto prog = migraphx::parse_onnx("pad_reflect_with_axes_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pad_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}});
mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), l0);
auto prog = optimize_onnx("pad_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pow_fp32_i64_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 5}});
auto l1f = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("pow"), l0, l1f);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("pow_fp32_i64_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pow_i64_fp32_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l0f = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l0);
auto fr = mm->add_instruction(migraphx::make_op("pow"), l0f, l1);
auto ir = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), fr);
mm->add_return({ir});
auto prog = migraphx::parse_onnx("pow_i64_fp32_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(pow_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
mm->add_instruction(migraphx::make_op("pow"), l0, l1);
auto prog = optimize_onnx("pow_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(prefix_scan_sum)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {1}, {1}}, {0}});
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}});
auto ret = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", true}, {"reverse", true}}),
l0);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("prefix_scan_sum_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(prelu_brcst_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto bl1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l0->get_shape().lens()}}), l1);
auto ret = mm->add_instruction(migraphx::make_op("prelu"), l0, bl1);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("prelu_brcst_test.onnx");
EXPECT(p == prog);
}
#include <onnx_test.hpp>
TEST_CASE(qlinearadd_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto a = mm->add_parameter("A", {migraphx::shape::uint8_type, {64}});
auto b = mm->add_parameter("B", {migraphx::shape::uint8_type, {64}});
auto sc_a = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
auto z_pt_a = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {0}});
auto sc_b = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
auto z_pt_b = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {128}});
auto sc_c = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
auto z_pt_c = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {64}});
auto scale_a_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_a);
auto z_pt_a_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_a);
auto fp_a =
mm->add_instruction(migraphx::make_op("dequantizelinear"), a, scale_a_bcast, z_pt_a_bcast);
auto scale_b_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_b);
auto z_pt_b_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_b);
auto fp_b =
mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scale_b_bcast, z_pt_b_bcast);
auto fp_c = mm->add_instruction(migraphx::make_op("add"), fp_a, fp_b);
auto scale_c_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_c);
auto z_pt_c_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_c);
auto c =
mm->add_instruction(migraphx::make_op("quantizelinear"), fp_c, scale_c_bcast, z_pt_c_bcast);
mm->add_return({c});
auto prog = migraphx::parse_onnx("qlinearadd_test.onnx");
EXPECT(p.sort() == prog.sort());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment