Commit aea65ed4 authored by Paul's avatar Paul
Browse files

Merge branch 'gpu-test-select' into scheduler

parents 5440a9b8 e05c1915
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <future> #include <future>
#include <thread> #include <thread>
#include "test.hpp" #include <test.hpp>
#ifdef __clang__ #ifdef __clang__
#pragma clang diagnostic push #pragma clang diagnostic push
...@@ -134,7 +134,7 @@ migraphx::argument run_gpu(migraphx::program& p) ...@@ -134,7 +134,7 @@ migraphx::argument run_gpu(migraphx::program& p)
} }
template <class V> template <class V>
void verify_program() void run_verify_program()
{ {
auto_print::set_terminate_handler(migraphx::get_type_name<V>()); auto_print::set_terminate_handler(migraphx::get_type_name<V>());
// std::cout << migraphx::get_type_name<V>() << std::endl; // std::cout << migraphx::get_type_name<V>() << std::endl;
...@@ -156,7 +156,27 @@ void verify_program() ...@@ -156,7 +156,27 @@ void verify_program()
std::set_terminate(nullptr); std::set_terminate(nullptr);
} }
struct test_literals template <class T>
int auto_register_verify_program()
{
test::add_test_case(migraphx::get_type_name<T>(), [] { run_verify_program<T>(); });
return 0;
}
template <class T>
struct verify_program
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type =
std::integral_constant<decltype(&static_register), &static_register>;
};
template <class T>
int verify_program<T>::static_register = auto_register_verify_program<T>(); // NOLINT
struct test_literals : verify_program<test_literals>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -171,7 +191,7 @@ struct test_literals ...@@ -171,7 +191,7 @@ struct test_literals
} }
}; };
struct test_add struct test_add : verify_program<test_add>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -184,7 +204,7 @@ struct test_add ...@@ -184,7 +204,7 @@ struct test_add
} }
}; };
struct test_add_half struct test_add_half : verify_program<test_add_half>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -197,7 +217,7 @@ struct test_add_half ...@@ -197,7 +217,7 @@ struct test_add_half
} }
}; };
struct test_mul struct test_mul : verify_program<test_mul>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -210,7 +230,7 @@ struct test_mul ...@@ -210,7 +230,7 @@ struct test_mul
} }
}; };
struct test_exp struct test_exp : verify_program<test_exp>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -223,7 +243,7 @@ struct test_exp ...@@ -223,7 +243,7 @@ struct test_exp
} }
}; };
struct test_log struct test_log : verify_program<test_log>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -236,7 +256,7 @@ struct test_log ...@@ -236,7 +256,7 @@ struct test_log
} }
}; };
struct test_sin struct test_sin : verify_program<test_sin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -248,7 +268,7 @@ struct test_sin ...@@ -248,7 +268,7 @@ struct test_sin
} }
}; };
struct test_cos struct test_cos : verify_program<test_cos>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -260,7 +280,7 @@ struct test_cos ...@@ -260,7 +280,7 @@ struct test_cos
} }
}; };
struct test_tan struct test_tan : verify_program<test_tan>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -272,7 +292,7 @@ struct test_tan ...@@ -272,7 +292,7 @@ struct test_tan
} }
}; };
struct test_sinh struct test_sinh : verify_program<test_sinh>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -284,7 +304,7 @@ struct test_sinh ...@@ -284,7 +304,7 @@ struct test_sinh
} }
}; };
struct test_cosh struct test_cosh : verify_program<test_cosh>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -296,7 +316,7 @@ struct test_cosh ...@@ -296,7 +316,7 @@ struct test_cosh
} }
}; };
struct test_tanh struct test_tanh : verify_program<test_tanh>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -307,7 +327,7 @@ struct test_tanh ...@@ -307,7 +327,7 @@ struct test_tanh
} }
}; };
struct test_asin struct test_asin : verify_program<test_asin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -319,7 +339,7 @@ struct test_asin ...@@ -319,7 +339,7 @@ struct test_asin
} }
}; };
struct test_acos struct test_acos : verify_program<test_acos>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -331,7 +351,7 @@ struct test_acos ...@@ -331,7 +351,7 @@ struct test_acos
} }
}; };
struct test_atan struct test_atan : verify_program<test_atan>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -343,7 +363,7 @@ struct test_atan ...@@ -343,7 +363,7 @@ struct test_atan
} }
}; };
struct test_scale struct test_scale : verify_program<test_scale>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -357,7 +377,7 @@ struct test_scale ...@@ -357,7 +377,7 @@ struct test_scale
} }
}; };
struct test_slice struct test_slice : verify_program<test_slice>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -372,7 +392,7 @@ struct test_slice ...@@ -372,7 +392,7 @@ struct test_slice
} }
}; };
struct test_triadd struct test_triadd : verify_program<test_triadd>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -387,7 +407,7 @@ struct test_triadd ...@@ -387,7 +407,7 @@ struct test_triadd
} }
}; };
struct test_triadd2 struct test_triadd2 : verify_program<test_triadd2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -404,7 +424,7 @@ struct test_triadd2 ...@@ -404,7 +424,7 @@ struct test_triadd2
} }
}; };
struct test_add_broadcast struct test_add_broadcast : verify_program<test_add_broadcast>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -418,7 +438,7 @@ struct test_add_broadcast ...@@ -418,7 +438,7 @@ struct test_add_broadcast
} }
}; };
struct test_add_broadcast2 struct test_add_broadcast2 : verify_program<test_add_broadcast2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -432,7 +452,7 @@ struct test_add_broadcast2 ...@@ -432,7 +452,7 @@ struct test_add_broadcast2
} }
}; };
struct test_add_broadcast3 struct test_add_broadcast3 : verify_program<test_add_broadcast3>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -446,7 +466,7 @@ struct test_add_broadcast3 ...@@ -446,7 +466,7 @@ struct test_add_broadcast3
} }
}; };
struct test_add_broadcast4 struct test_add_broadcast4 : verify_program<test_add_broadcast4>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -460,7 +480,7 @@ struct test_add_broadcast4 ...@@ -460,7 +480,7 @@ struct test_add_broadcast4
} }
}; };
struct test_add_broadcast5 struct test_add_broadcast5 : verify_program<test_add_broadcast5>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -474,7 +494,7 @@ struct test_add_broadcast5 ...@@ -474,7 +494,7 @@ struct test_add_broadcast5
} }
}; };
struct test_triadd_broadcast struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -490,7 +510,7 @@ struct test_triadd_broadcast ...@@ -490,7 +510,7 @@ struct test_triadd_broadcast
} }
}; };
struct test_sub struct test_sub : verify_program<test_sub>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -505,7 +525,7 @@ struct test_sub ...@@ -505,7 +525,7 @@ struct test_sub
} }
}; };
struct test_sub2 struct test_sub2 : verify_program<test_sub2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -522,7 +542,7 @@ struct test_sub2 ...@@ -522,7 +542,7 @@ struct test_sub2
} }
}; };
struct test_softmax struct test_softmax : verify_program<test_softmax>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -533,7 +553,7 @@ struct test_softmax ...@@ -533,7 +553,7 @@ struct test_softmax
} }
}; };
struct test_softmax2 struct test_softmax2 : verify_program<test_softmax2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -545,7 +565,7 @@ struct test_softmax2 ...@@ -545,7 +565,7 @@ struct test_softmax2
} }
}; };
struct test_conv struct test_conv : verify_program<test_conv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -559,7 +579,7 @@ struct test_conv ...@@ -559,7 +579,7 @@ struct test_conv
} }
}; };
struct test_conv2 struct test_conv2 : verify_program<test_conv2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -573,7 +593,7 @@ struct test_conv2 ...@@ -573,7 +593,7 @@ struct test_conv2
} }
}; };
struct test_group_conv struct test_group_conv : verify_program<test_group_conv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -589,7 +609,7 @@ struct test_group_conv ...@@ -589,7 +609,7 @@ struct test_group_conv
} }
}; };
struct test_conv_relu struct test_conv_relu : verify_program<test_conv_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -604,7 +624,7 @@ struct test_conv_relu ...@@ -604,7 +624,7 @@ struct test_conv_relu
} }
}; };
struct test_conv_relu_half struct test_conv_relu_half : verify_program<test_conv_relu_half>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -619,7 +639,7 @@ struct test_conv_relu_half ...@@ -619,7 +639,7 @@ struct test_conv_relu_half
} }
}; };
struct test_add_relu struct test_add_relu : verify_program<test_add_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -632,7 +652,7 @@ struct test_add_relu ...@@ -632,7 +652,7 @@ struct test_add_relu
} }
}; };
struct test_sigmoid struct test_sigmoid : verify_program<test_sigmoid>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -643,7 +663,7 @@ struct test_sigmoid ...@@ -643,7 +663,7 @@ struct test_sigmoid
} }
}; };
struct test_abs struct test_abs : verify_program<test_abs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -654,7 +674,7 @@ struct test_abs ...@@ -654,7 +674,7 @@ struct test_abs
} }
}; };
struct test_leaky_relu struct test_leaky_relu : verify_program<test_leaky_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -665,7 +685,7 @@ struct test_leaky_relu ...@@ -665,7 +685,7 @@ struct test_leaky_relu
} }
}; };
struct test_elu struct test_elu : verify_program<test_elu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -676,7 +696,7 @@ struct test_elu ...@@ -676,7 +696,7 @@ struct test_elu
} }
}; };
struct test_relu_lrn struct test_relu_lrn : verify_program<test_relu_lrn>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -688,7 +708,7 @@ struct test_relu_lrn ...@@ -688,7 +708,7 @@ struct test_relu_lrn
} }
}; };
struct test_conv_pooling struct test_conv_pooling : verify_program<test_conv_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -704,7 +724,7 @@ struct test_conv_pooling ...@@ -704,7 +724,7 @@ struct test_conv_pooling
} }
}; };
struct test_global_avg_pooling struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -719,7 +739,7 @@ struct test_global_avg_pooling ...@@ -719,7 +739,7 @@ struct test_global_avg_pooling
} }
}; };
struct test_global_max_pooling struct test_global_max_pooling : verify_program<test_global_max_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -734,7 +754,7 @@ struct test_global_max_pooling ...@@ -734,7 +754,7 @@ struct test_global_max_pooling
} }
}; };
struct test_gemm struct test_gemm : verify_program<test_gemm>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -746,7 +766,7 @@ struct test_gemm ...@@ -746,7 +766,7 @@ struct test_gemm
} }
}; };
struct test_gemm_half struct test_gemm_half : verify_program<test_gemm_half>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -758,7 +778,7 @@ struct test_gemm_half ...@@ -758,7 +778,7 @@ struct test_gemm_half
} }
}; };
struct test_gemm_ld struct test_gemm_ld //: verify_program<test_gemm_ld>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -772,7 +792,7 @@ struct test_gemm_ld ...@@ -772,7 +792,7 @@ struct test_gemm_ld
} }
}; };
struct test_gemm_transposeb struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -785,7 +805,7 @@ struct test_gemm_transposeb ...@@ -785,7 +805,7 @@ struct test_gemm_transposeb
} }
}; };
struct test_gemm_transposea struct test_gemm_transposea : verify_program<test_gemm_transposea>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -798,7 +818,7 @@ struct test_gemm_transposea ...@@ -798,7 +818,7 @@ struct test_gemm_transposea
} }
}; };
struct test_gemm_transposeab struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -812,7 +832,7 @@ struct test_gemm_transposeab ...@@ -812,7 +832,7 @@ struct test_gemm_transposeab
} }
}; };
struct test_contiguous struct test_contiguous : verify_program<test_contiguous>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -825,7 +845,7 @@ struct test_contiguous ...@@ -825,7 +845,7 @@ struct test_contiguous
} }
}; };
struct test_eliminate_contiguous struct test_eliminate_contiguous : verify_program<test_eliminate_contiguous>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -841,7 +861,7 @@ struct test_eliminate_contiguous ...@@ -841,7 +861,7 @@ struct test_eliminate_contiguous
} }
}; };
struct test_transpose struct test_transpose : verify_program<test_transpose>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -855,7 +875,7 @@ struct test_transpose ...@@ -855,7 +875,7 @@ struct test_transpose
} }
}; };
struct test_batchnorm_inference_2 struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
{ {
const size_t width = 14; const size_t width = 14;
const size_t height = 14; const size_t height = 14;
...@@ -878,7 +898,7 @@ struct test_batchnorm_inference_2 ...@@ -878,7 +898,7 @@ struct test_batchnorm_inference_2
} }
}; };
struct test_batchnorm_inference struct test_batchnorm_inference : verify_program<test_batchnorm_inference>
{ {
const size_t width = 3; const size_t width = 3;
const size_t height = 3; const size_t height = 3;
...@@ -901,7 +921,7 @@ struct test_batchnorm_inference ...@@ -901,7 +921,7 @@ struct test_batchnorm_inference
} }
}; };
struct test_conv_bn struct test_conv_bn : verify_program<test_conv_bn>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -922,7 +942,7 @@ struct test_conv_bn ...@@ -922,7 +942,7 @@ struct test_conv_bn
} }
}; };
struct test_conv_bn_relu_pooling struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -946,7 +966,7 @@ struct test_conv_bn_relu_pooling ...@@ -946,7 +966,7 @@ struct test_conv_bn_relu_pooling
} }
}; };
struct test_concat struct test_concat : verify_program<test_concat>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -963,7 +983,7 @@ struct test_concat ...@@ -963,7 +983,7 @@ struct test_concat
} }
}; };
struct test_concat2 struct test_concat2 : verify_program<test_concat2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -980,7 +1000,7 @@ struct test_concat2 ...@@ -980,7 +1000,7 @@ struct test_concat2
} }
}; };
struct test_concat_relu struct test_concat_relu : verify_program<test_concat_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1001,7 +1021,7 @@ struct test_concat_relu ...@@ -1001,7 +1021,7 @@ struct test_concat_relu
} }
}; };
struct test_pad struct test_pad : verify_program<test_pad>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1020,7 +1040,7 @@ struct test_pad ...@@ -1020,7 +1040,7 @@ struct test_pad
} }
}; };
struct test_pooling_autopad struct test_pooling_autopad : verify_program<test_pooling_autopad>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1036,7 +1056,7 @@ struct test_pooling_autopad ...@@ -1036,7 +1056,7 @@ struct test_pooling_autopad
} }
}; };
struct test_gather struct test_gather : verify_program<test_gather>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1052,7 +1072,7 @@ struct test_gather ...@@ -1052,7 +1072,7 @@ struct test_gather
} }
}; };
struct test_gather_neg_axis struct test_gather_neg_axis : verify_program<test_gather_neg_axis>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1068,7 +1088,7 @@ struct test_gather_neg_axis ...@@ -1068,7 +1088,7 @@ struct test_gather_neg_axis
} }
}; };
struct test_gather_scalar_output struct test_gather_scalar_output : verify_program<test_gather_scalar_output>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1084,7 +1104,7 @@ struct test_gather_scalar_output ...@@ -1084,7 +1104,7 @@ struct test_gather_scalar_output
} }
}; };
struct test_gather_scalar_index struct test_gather_scalar_index : verify_program<test_gather_scalar_index>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1100,7 +1120,7 @@ struct test_gather_scalar_index ...@@ -1100,7 +1120,7 @@ struct test_gather_scalar_index
} }
}; };
struct test_gather_1d_index struct test_gather_1d_index : verify_program<test_gather_1d_index>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1162,7 +1182,7 @@ void manual_test_concat_relu() ...@@ -1162,7 +1182,7 @@ void manual_test_concat_relu()
std::cout << result << std::endl; std::cout << result << std::endl;
} }
struct test_conv_bn_relu_pooling2 struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{ {
static migraphx::instruction_ref static migraphx::instruction_ref
add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels) add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels)
...@@ -1199,7 +1219,7 @@ struct test_conv_bn_relu_pooling2 ...@@ -1199,7 +1219,7 @@ struct test_conv_bn_relu_pooling2
} }
}; };
struct test_rnn_forward struct test_rnn_forward : verify_program<test_rnn_forward>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1241,7 +1261,7 @@ struct test_rnn_forward ...@@ -1241,7 +1261,7 @@ struct test_rnn_forward
} }
}; };
struct test_rnn_forward10 struct test_rnn_forward10 : verify_program<test_rnn_forward10>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1283,7 +1303,7 @@ struct test_rnn_forward10 ...@@ -1283,7 +1303,7 @@ struct test_rnn_forward10
} }
}; };
struct test_rnn_reverse struct test_rnn_reverse : verify_program<test_rnn_reverse>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1323,7 +1343,7 @@ struct test_rnn_reverse ...@@ -1323,7 +1343,7 @@ struct test_rnn_reverse
} }
}; };
struct test_rnn_reverse2 struct test_rnn_reverse2 : verify_program<test_rnn_reverse2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1363,7 +1383,7 @@ struct test_rnn_reverse2 ...@@ -1363,7 +1383,7 @@ struct test_rnn_reverse2
} }
}; };
struct test_rnn_3args struct test_rnn_3args : verify_program<test_rnn_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1395,7 +1415,7 @@ struct test_rnn_3args ...@@ -1395,7 +1415,7 @@ struct test_rnn_3args
} }
}; };
struct test_rnn_4args struct test_rnn_4args : verify_program<test_rnn_4args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1430,7 +1450,7 @@ struct test_rnn_4args ...@@ -1430,7 +1450,7 @@ struct test_rnn_4args
} }
}; };
struct test_rnn_5args struct test_rnn_5args : verify_program<test_rnn_5args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1469,7 +1489,7 @@ struct test_rnn_5args ...@@ -1469,7 +1489,7 @@ struct test_rnn_5args
} }
}; };
struct test_rnn_bidirectional struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1511,7 +1531,7 @@ struct test_rnn_bidirectional ...@@ -1511,7 +1531,7 @@ struct test_rnn_bidirectional
} }
}; };
struct test_rnn_bidirectional10 struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1552,7 +1572,7 @@ struct test_rnn_bidirectional10 ...@@ -1552,7 +1572,7 @@ struct test_rnn_bidirectional10
} }
}; };
struct test_rnn_bi_3args struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1587,7 +1607,7 @@ struct test_rnn_bi_3args ...@@ -1587,7 +1607,7 @@ struct test_rnn_bi_3args
} }
}; };
struct test_gru_forward_last struct test_gru_forward_last : verify_program<test_gru_forward_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1631,7 +1651,7 @@ struct test_gru_forward_last ...@@ -1631,7 +1651,7 @@ struct test_gru_forward_last
} }
}; };
struct test_gru_forward_hs struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1673,7 +1693,7 @@ struct test_gru_forward_hs ...@@ -1673,7 +1693,7 @@ struct test_gru_forward_hs
} }
}; };
struct test_gru_forward_3args_und struct test_gru_forward_3args_und : verify_program<test_gru_forward_3args_und>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1709,7 +1729,7 @@ struct test_gru_forward_3args_und ...@@ -1709,7 +1729,7 @@ struct test_gru_forward_3args_und
} }
}; };
struct test_gru_forward_3args struct test_gru_forward_3args : verify_program<test_gru_forward_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1741,7 +1761,7 @@ struct test_gru_forward_3args ...@@ -1741,7 +1761,7 @@ struct test_gru_forward_3args
} }
}; };
struct test_gru_forward_seq1 struct test_gru_forward_seq1 : verify_program<test_gru_forward_seq1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1773,7 +1793,7 @@ struct test_gru_forward_seq1 ...@@ -1773,7 +1793,7 @@ struct test_gru_forward_seq1
} }
}; };
struct test_gru_forward_default_actv struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_actv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1803,7 +1823,7 @@ struct test_gru_forward_default_actv ...@@ -1803,7 +1823,7 @@ struct test_gru_forward_default_actv
} }
}; };
struct test_gru_forward_default_actv1 struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_actv1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1844,7 +1864,7 @@ struct test_gru_forward_default_actv1 ...@@ -1844,7 +1864,7 @@ struct test_gru_forward_default_actv1
} }
}; };
struct test_gru_reverse_last struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1888,7 +1908,7 @@ struct test_gru_reverse_last ...@@ -1888,7 +1908,7 @@ struct test_gru_reverse_last
} }
}; };
struct test_gru_reverse_3args struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1920,7 +1940,7 @@ struct test_gru_reverse_3args ...@@ -1920,7 +1940,7 @@ struct test_gru_reverse_3args
} }
}; };
struct test_gru_bidirct_last struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1964,7 +1984,7 @@ struct test_gru_bidirct_last ...@@ -1964,7 +1984,7 @@ struct test_gru_bidirct_last
} }
}; };
struct test_gru_bidirct_hs struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2006,7 +2026,7 @@ struct test_gru_bidirct_hs ...@@ -2006,7 +2026,7 @@ struct test_gru_bidirct_hs
} }
}; };
struct test_gru_bidirct_3args_und struct test_gru_bidirct_3args_und : verify_program<test_gru_bidirct_3args_und>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2042,7 +2062,7 @@ struct test_gru_bidirct_3args_und ...@@ -2042,7 +2062,7 @@ struct test_gru_bidirct_3args_und
} }
}; };
struct test_gru_bidirct_3args struct test_gru_bidirct_3args : verify_program<test_gru_bidirct_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2074,7 +2094,7 @@ struct test_gru_bidirct_3args ...@@ -2074,7 +2094,7 @@ struct test_gru_bidirct_3args
} }
}; };
struct test_gru_bidirct_seq1 struct test_gru_bidirct_seq1 : verify_program<test_gru_bidirct_seq1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2106,7 +2126,7 @@ struct test_gru_bidirct_seq1 ...@@ -2106,7 +2126,7 @@ struct test_gru_bidirct_seq1
} }
}; };
struct test_gru_bidirct_default_actv struct test_gru_bidirct_default_actv : verify_program<test_gru_bidirct_default_actv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2136,7 +2156,7 @@ struct test_gru_bidirct_default_actv ...@@ -2136,7 +2156,7 @@ struct test_gru_bidirct_default_actv
} }
}; };
struct test_gru_bidirct_default_actv1 struct test_gru_bidirct_default_actv1 : verify_program<test_gru_bidirct_default_actv1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2178,7 +2198,7 @@ struct test_gru_bidirct_default_actv1 ...@@ -2178,7 +2198,7 @@ struct test_gru_bidirct_default_actv1
} }
}; };
struct test_lstm_forward_last struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2228,7 +2248,7 @@ struct test_lstm_forward_last ...@@ -2228,7 +2248,7 @@ struct test_lstm_forward_last
} }
}; };
struct test_lstm_forward_hs struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2278,7 +2298,7 @@ struct test_lstm_forward_hs ...@@ -2278,7 +2298,7 @@ struct test_lstm_forward_hs
} }
}; };
struct test_lstm_forward_3args_und struct test_lstm_forward_3args_und : verify_program<test_lstm_forward_3args_und>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2318,7 +2338,7 @@ struct test_lstm_forward_3args_und ...@@ -2318,7 +2338,7 @@ struct test_lstm_forward_3args_und
} }
}; };
struct test_lstm_forward_3args struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2352,7 +2372,7 @@ struct test_lstm_forward_3args ...@@ -2352,7 +2372,7 @@ struct test_lstm_forward_3args
} }
}; };
struct test_lstm_forward_seq1 struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2386,7 +2406,7 @@ struct test_lstm_forward_seq1 ...@@ -2386,7 +2406,7 @@ struct test_lstm_forward_seq1
} }
}; };
struct test_lstm_forward_default_actv struct test_lstm_forward_default_actv : verify_program<test_lstm_forward_default_actv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2416,7 +2436,7 @@ struct test_lstm_forward_default_actv ...@@ -2416,7 +2436,7 @@ struct test_lstm_forward_default_actv
} }
}; };
struct test_lstm_forward_default_actv1 struct test_lstm_forward_default_actv1 : verify_program<test_lstm_forward_default_actv1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2457,7 +2477,7 @@ struct test_lstm_forward_default_actv1 ...@@ -2457,7 +2477,7 @@ struct test_lstm_forward_default_actv1
} }
}; };
struct test_lstm_reverse_last struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2508,7 +2528,7 @@ struct test_lstm_reverse_last ...@@ -2508,7 +2528,7 @@ struct test_lstm_reverse_last
} }
}; };
struct test_lstm_reverse_3args struct test_lstm_reverse_3args : verify_program<test_lstm_reverse_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2542,7 +2562,7 @@ struct test_lstm_reverse_3args ...@@ -2542,7 +2562,7 @@ struct test_lstm_reverse_3args
} }
}; };
struct test_lstm_reverse_3args_cell_output struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3args_cell_output>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2577,7 +2597,7 @@ struct test_lstm_reverse_3args_cell_output ...@@ -2577,7 +2597,7 @@ struct test_lstm_reverse_3args_cell_output
} }
}; };
struct test_lstm_bidirct_last struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2628,7 +2648,7 @@ struct test_lstm_bidirct_last ...@@ -2628,7 +2648,7 @@ struct test_lstm_bidirct_last
} }
}; };
struct test_lstm_bidirct_hs struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2670,7 +2690,7 @@ struct test_lstm_bidirct_hs ...@@ -2670,7 +2690,7 @@ struct test_lstm_bidirct_hs
} }
}; };
struct test_lstm_bidirct_3args_und struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2709,7 +2729,7 @@ struct test_lstm_bidirct_3args_und ...@@ -2709,7 +2729,7 @@ struct test_lstm_bidirct_3args_und
} }
}; };
struct test_lstm_bidirct_3args struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2741,7 +2761,7 @@ struct test_lstm_bidirct_3args ...@@ -2741,7 +2761,7 @@ struct test_lstm_bidirct_3args
} }
}; };
struct test_lstm_bidirct_seq1 struct test_lstm_bidirct_seq1 : verify_program<test_lstm_bidirct_seq1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2773,7 +2793,7 @@ struct test_lstm_bidirct_seq1 ...@@ -2773,7 +2793,7 @@ struct test_lstm_bidirct_seq1
} }
}; };
struct test_lstm_bidirct_default_actv struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default_actv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2803,7 +2823,7 @@ struct test_lstm_bidirct_default_actv ...@@ -2803,7 +2823,7 @@ struct test_lstm_bidirct_default_actv
} }
}; };
struct test_lstm_bidirct_default_actv1 struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_default_actv1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2845,7 +2865,7 @@ struct test_lstm_bidirct_default_actv1 ...@@ -2845,7 +2865,7 @@ struct test_lstm_bidirct_default_actv1
} }
}; };
struct test_lstm_bidirct_default_actv2 struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_default_actv2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2887,116 +2907,4 @@ struct test_lstm_bidirct_default_actv2 ...@@ -2887,116 +2907,4 @@ struct test_lstm_bidirct_default_actv2
} }
}; };
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
verify_program<test_relu_lrn>();
verify_program<test_pooling_autopad>();
verify_program<test_abs>();
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_concat_relu>();
verify_program<test_pad>();
verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>();
verify_program<test_exp>();
verify_program<test_log>();
verify_program<test_sin>();
verify_program<test_cos>();
verify_program<test_tan>();
verify_program<test_sinh>();
verify_program<test_cosh>();
verify_program<test_tanh>();
verify_program<test_asin>();
verify_program<test_acos>();
verify_program<test_atan>();
verify_program<test_scale>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>();
verify_program<test_sub>();
verify_program<test_sub2>();
verify_program<test_softmax>();
verify_program<test_softmax2>();
verify_program<test_conv>();
verify_program<test_conv2>();
verify_program<test_group_conv>();
verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>();
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_sigmoid>();
verify_program<test_elu>();
verify_program<test_conv_pooling>();
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposeab>();
verify_program<test_contiguous>();
verify_program<test_eliminate_contiguous>();
verify_program<test_transpose>();
verify_program<test_batchnorm_inference>();
verify_program<test_batchnorm_inference_2>();
verify_program<test_conv_bn>();
verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
verify_program<test_gather_scalar_output>();
verify_program<test_gather_scalar_index>();
verify_program<test_gather_1d_index>();
verify_program<test_rnn_forward>();
verify_program<test_rnn_forward10>();
verify_program<test_rnn_reverse>();
verify_program<test_rnn_reverse2>();
verify_program<test_rnn_3args>();
verify_program<test_rnn_4args>();
verify_program<test_rnn_5args>();
verify_program<test_rnn_bidirectional>();
verify_program<test_rnn_bidirectional10>();
verify_program<test_rnn_bi_3args>();
verify_program<test_gru_forward_last>();
verify_program<test_gru_forward_hs>();
verify_program<test_gru_forward_3args_und>();
verify_program<test_gru_forward_3args>();
verify_program<test_gru_forward_seq1>();
verify_program<test_gru_forward_default_actv>();
verify_program<test_gru_forward_default_actv1>();
verify_program<test_gru_reverse_last>();
verify_program<test_gru_reverse_3args>();
verify_program<test_gru_bidirct_last>();
verify_program<test_gru_bidirct_hs>();
verify_program<test_gru_bidirct_3args_und>();
verify_program<test_gru_bidirct_3args>();
verify_program<test_gru_bidirct_seq1>();
verify_program<test_gru_bidirct_default_actv>();
verify_program<test_gru_bidirct_default_actv1>();
verify_program<test_lstm_forward_last>();
verify_program<test_lstm_forward_hs>();
verify_program<test_lstm_forward_3args_und>();
verify_program<test_lstm_forward_3args>();
verify_program<test_lstm_forward_seq1>();
verify_program<test_lstm_forward_default_actv>();
verify_program<test_lstm_forward_default_actv1>();
verify_program<test_lstm_reverse_last>();
verify_program<test_lstm_reverse_3args>();
verify_program<test_lstm_reverse_3args_cell_output>();
verify_program<test_lstm_bidirct_last>();
verify_program<test_lstm_bidirct_hs>();
verify_program<test_lstm_bidirct_3args_und>();
verify_program<test_lstm_bidirct_3args>();
verify_program<test_lstm_bidirct_seq1>();
verify_program<test_lstm_bidirct_default_actv>();
verify_program<test_lstm_bidirct_default_actv1>();
verify_program<test_lstm_bidirct_default_actv2>();
}
...@@ -205,10 +205,10 @@ inline void add_test_case(std::string name, std::function<void()> f) ...@@ -205,10 +205,10 @@ inline void add_test_case(std::string name, std::function<void()> f)
get_test_cases().emplace_back(std::move(name), std::move(f)); get_test_cases().emplace_back(std::move(name), std::move(f));
} }
struct auto_register struct auto_register_test_case
{ {
template <class F> template <class F>
auto_register(const char* name, F f) noexcept auto_register_test_case(const char* name, F f) noexcept
{ {
add_test_case(name, f); add_test_case(name, f);
} }
...@@ -271,9 +271,9 @@ inline void run(int argc, const char* argv[]) ...@@ -271,9 +271,9 @@ inline void run(int argc, const char* argv[])
#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__ #define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \ #define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \ static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__); test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define TEST_CASE(...) \ #define TEST_CASE(...) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment