Commit 770c7d27 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge branch develop

parents bbd4e0c3 3499ec7d
...@@ -987,15 +987,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -987,15 +987,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi); pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho); ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf); pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
} }
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
......
...@@ -14,8 +14,9 @@ std::string target::name() const { return "cpu"; } ...@@ -14,8 +14,9 @@ std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&) const std::vector<pass> target::get_passes(migraphx::context&) const
{ {
return {auto_contiguous{}, return {rewrite_rnn{},
rewrite_rnn{}, dead_code_elimination{},
auto_contiguous{},
dead_code_elimination{}, dead_code_elimination{},
lowering{}, lowering{},
dead_code_elimination{}}; dead_code_elimination{}};
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <future> #include <future>
#include <thread> #include <thread>
#include "test.hpp" #include <test.hpp>
#ifdef __clang__ #ifdef __clang__
#pragma clang diagnostic push #pragma clang diagnostic push
...@@ -134,7 +134,7 @@ migraphx::argument run_gpu(migraphx::program& p) ...@@ -134,7 +134,7 @@ migraphx::argument run_gpu(migraphx::program& p)
} }
template <class V> template <class V>
void verify_program() void run_verify_program()
{ {
auto_print::set_terminate_handler(migraphx::get_type_name<V>()); auto_print::set_terminate_handler(migraphx::get_type_name<V>());
// std::cout << migraphx::get_type_name<V>() << std::endl; // std::cout << migraphx::get_type_name<V>() << std::endl;
...@@ -156,7 +156,27 @@ void verify_program() ...@@ -156,7 +156,27 @@ void verify_program()
std::set_terminate(nullptr); std::set_terminate(nullptr);
} }
struct test_literals template <class T>
int auto_register_verify_program()
{
test::add_test_case(migraphx::get_type_name<T>(), [] { run_verify_program<T>(); });
return 0;
}
template <class T>
struct verify_program
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type =
std::integral_constant<decltype(&static_register), &static_register>;
};
template <class T>
int verify_program<T>::static_register = auto_register_verify_program<T>(); // NOLINT
struct test_literals : verify_program<test_literals>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -171,7 +191,7 @@ struct test_literals ...@@ -171,7 +191,7 @@ struct test_literals
} }
}; };
struct test_add struct test_add : verify_program<test_add>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -184,7 +204,7 @@ struct test_add ...@@ -184,7 +204,7 @@ struct test_add
} }
}; };
struct test_add_half struct test_add_half : verify_program<test_add_half>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -197,7 +217,7 @@ struct test_add_half ...@@ -197,7 +217,7 @@ struct test_add_half
} }
}; };
struct test_mul struct test_mul : verify_program<test_mul>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -210,7 +230,7 @@ struct test_mul ...@@ -210,7 +230,7 @@ struct test_mul
} }
}; };
struct test_exp struct test_exp : verify_program<test_exp>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -223,7 +243,7 @@ struct test_exp ...@@ -223,7 +243,7 @@ struct test_exp
} }
}; };
struct test_log struct test_log : verify_program<test_log>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -236,7 +256,7 @@ struct test_log ...@@ -236,7 +256,7 @@ struct test_log
} }
}; };
struct test_sin struct test_sin : verify_program<test_sin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -248,7 +268,7 @@ struct test_sin ...@@ -248,7 +268,7 @@ struct test_sin
} }
}; };
struct test_cos struct test_cos : verify_program<test_cos>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -260,7 +280,7 @@ struct test_cos ...@@ -260,7 +280,7 @@ struct test_cos
} }
}; };
struct test_tan struct test_tan : verify_program<test_tan>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -272,7 +292,7 @@ struct test_tan ...@@ -272,7 +292,7 @@ struct test_tan
} }
}; };
struct test_sinh struct test_sinh : verify_program<test_sinh>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -284,7 +304,7 @@ struct test_sinh ...@@ -284,7 +304,7 @@ struct test_sinh
} }
}; };
struct test_cosh struct test_cosh : verify_program<test_cosh>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -296,7 +316,7 @@ struct test_cosh ...@@ -296,7 +316,7 @@ struct test_cosh
} }
}; };
struct test_tanh struct test_tanh : verify_program<test_tanh>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -307,7 +327,7 @@ struct test_tanh ...@@ -307,7 +327,7 @@ struct test_tanh
} }
}; };
struct test_asin struct test_asin : verify_program<test_asin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -319,7 +339,7 @@ struct test_asin ...@@ -319,7 +339,7 @@ struct test_asin
} }
}; };
struct test_acos struct test_acos : verify_program<test_acos>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -331,7 +351,7 @@ struct test_acos ...@@ -331,7 +351,7 @@ struct test_acos
} }
}; };
struct test_atan struct test_atan : verify_program<test_atan>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -343,7 +363,7 @@ struct test_atan ...@@ -343,7 +363,7 @@ struct test_atan
} }
}; };
struct test_scale struct test_scale : verify_program<test_scale>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -357,7 +377,7 @@ struct test_scale ...@@ -357,7 +377,7 @@ struct test_scale
} }
}; };
struct test_slice struct test_slice : verify_program<test_slice>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -372,7 +392,7 @@ struct test_slice ...@@ -372,7 +392,7 @@ struct test_slice
} }
}; };
struct test_triadd struct test_triadd : verify_program<test_triadd>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -387,7 +407,7 @@ struct test_triadd ...@@ -387,7 +407,7 @@ struct test_triadd
} }
}; };
struct test_triadd2 struct test_triadd2 : verify_program<test_triadd2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -404,7 +424,7 @@ struct test_triadd2 ...@@ -404,7 +424,7 @@ struct test_triadd2
} }
}; };
struct test_add_broadcast struct test_add_broadcast : verify_program<test_add_broadcast>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -418,7 +438,7 @@ struct test_add_broadcast ...@@ -418,7 +438,7 @@ struct test_add_broadcast
} }
}; };
struct test_add_broadcast2 struct test_add_broadcast2 : verify_program<test_add_broadcast2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -432,7 +452,7 @@ struct test_add_broadcast2 ...@@ -432,7 +452,7 @@ struct test_add_broadcast2
} }
}; };
struct test_add_broadcast3 struct test_add_broadcast3 : verify_program<test_add_broadcast3>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -446,7 +466,7 @@ struct test_add_broadcast3 ...@@ -446,7 +466,7 @@ struct test_add_broadcast3
} }
}; };
struct test_add_broadcast4 struct test_add_broadcast4 : verify_program<test_add_broadcast4>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -460,7 +480,7 @@ struct test_add_broadcast4 ...@@ -460,7 +480,7 @@ struct test_add_broadcast4
} }
}; };
struct test_add_broadcast5 struct test_add_broadcast5 : verify_program<test_add_broadcast5>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -474,7 +494,7 @@ struct test_add_broadcast5 ...@@ -474,7 +494,7 @@ struct test_add_broadcast5
} }
}; };
struct test_triadd_broadcast struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -490,7 +510,7 @@ struct test_triadd_broadcast ...@@ -490,7 +510,7 @@ struct test_triadd_broadcast
} }
}; };
struct test_sub struct test_sub : verify_program<test_sub>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -505,7 +525,7 @@ struct test_sub ...@@ -505,7 +525,7 @@ struct test_sub
} }
}; };
struct test_sub2 struct test_sub2 : verify_program<test_sub2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -522,7 +542,7 @@ struct test_sub2 ...@@ -522,7 +542,7 @@ struct test_sub2
} }
}; };
struct test_softmax struct test_softmax : verify_program<test_softmax>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -533,7 +553,7 @@ struct test_softmax ...@@ -533,7 +553,7 @@ struct test_softmax
} }
}; };
struct test_softmax2 struct test_softmax2 : verify_program<test_softmax2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -545,7 +565,7 @@ struct test_softmax2 ...@@ -545,7 +565,7 @@ struct test_softmax2
} }
}; };
struct test_conv struct test_conv : verify_program<test_conv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -559,7 +579,7 @@ struct test_conv ...@@ -559,7 +579,7 @@ struct test_conv
} }
}; };
struct test_conv2 struct test_conv2 : verify_program<test_conv2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -573,7 +593,7 @@ struct test_conv2 ...@@ -573,7 +593,7 @@ struct test_conv2
} }
}; };
struct test_group_conv struct test_group_conv : verify_program<test_group_conv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -589,7 +609,7 @@ struct test_group_conv ...@@ -589,7 +609,7 @@ struct test_group_conv
} }
}; };
struct test_conv_relu struct test_conv_relu : verify_program<test_conv_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -604,7 +624,7 @@ struct test_conv_relu ...@@ -604,7 +624,7 @@ struct test_conv_relu
} }
}; };
struct test_conv_relu_half struct test_conv_relu_half : verify_program<test_conv_relu_half>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -619,7 +639,7 @@ struct test_conv_relu_half ...@@ -619,7 +639,7 @@ struct test_conv_relu_half
} }
}; };
struct test_add_relu struct test_add_relu : verify_program<test_add_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -632,7 +652,7 @@ struct test_add_relu ...@@ -632,7 +652,7 @@ struct test_add_relu
} }
}; };
struct test_sigmoid struct test_sigmoid : verify_program<test_sigmoid>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -643,7 +663,7 @@ struct test_sigmoid ...@@ -643,7 +663,7 @@ struct test_sigmoid
} }
}; };
struct test_abs struct test_abs : verify_program<test_abs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -654,7 +674,7 @@ struct test_abs ...@@ -654,7 +674,7 @@ struct test_abs
} }
}; };
struct test_leaky_relu struct test_leaky_relu : verify_program<test_leaky_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -665,7 +685,7 @@ struct test_leaky_relu ...@@ -665,7 +685,7 @@ struct test_leaky_relu
} }
}; };
struct test_elu struct test_elu : verify_program<test_elu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -676,7 +696,7 @@ struct test_elu ...@@ -676,7 +696,7 @@ struct test_elu
} }
}; };
struct test_relu_lrn struct test_relu_lrn : verify_program<test_relu_lrn>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -688,7 +708,7 @@ struct test_relu_lrn ...@@ -688,7 +708,7 @@ struct test_relu_lrn
} }
}; };
struct test_conv_pooling struct test_conv_pooling : verify_program<test_conv_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -704,7 +724,7 @@ struct test_conv_pooling ...@@ -704,7 +724,7 @@ struct test_conv_pooling
} }
}; };
struct test_global_avg_pooling struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -719,7 +739,7 @@ struct test_global_avg_pooling ...@@ -719,7 +739,7 @@ struct test_global_avg_pooling
} }
}; };
struct test_global_max_pooling struct test_global_max_pooling : verify_program<test_global_max_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -734,7 +754,7 @@ struct test_global_max_pooling ...@@ -734,7 +754,7 @@ struct test_global_max_pooling
} }
}; };
struct test_gemm struct test_gemm : verify_program<test_gemm>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -746,7 +766,7 @@ struct test_gemm ...@@ -746,7 +766,7 @@ struct test_gemm
} }
}; };
struct test_gemm_ex struct test_gemm_ex : verify_program<test_gemm_ex>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -758,7 +778,7 @@ struct test_gemm_ex ...@@ -758,7 +778,7 @@ struct test_gemm_ex
} }
}; };
struct test_gemm_half struct test_gemm_half : verify_program<test_gemm_half>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -770,7 +790,7 @@ struct test_gemm_half ...@@ -770,7 +790,7 @@ struct test_gemm_half
} }
}; };
struct test_gemm_ld struct test_gemm_ld //: verify_program<test_gemm_ld>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -784,7 +804,7 @@ struct test_gemm_ld ...@@ -784,7 +804,7 @@ struct test_gemm_ld
} }
}; };
struct test_gemm_transposeb struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -797,7 +817,7 @@ struct test_gemm_transposeb ...@@ -797,7 +817,7 @@ struct test_gemm_transposeb
} }
}; };
struct test_gemm_transposeb_ex struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -810,7 +830,7 @@ struct test_gemm_transposeb_ex ...@@ -810,7 +830,7 @@ struct test_gemm_transposeb_ex
} }
}; };
struct test_gemm_transposea struct test_gemm_transposea : verify_program<test_gemm_transposea>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -823,7 +843,7 @@ struct test_gemm_transposea ...@@ -823,7 +843,7 @@ struct test_gemm_transposea
} }
}; };
struct test_gemm_transposea_ex struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -836,7 +856,7 @@ struct test_gemm_transposea_ex ...@@ -836,7 +856,7 @@ struct test_gemm_transposea_ex
} }
}; };
struct test_gemm_transposeab struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -850,7 +870,7 @@ struct test_gemm_transposeab ...@@ -850,7 +870,7 @@ struct test_gemm_transposeab
} }
}; };
struct gemm_mutli_dim_2 struct gemm_multi_dim_2 : verify_program<gemm_multi_dim_2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -866,7 +886,7 @@ struct gemm_mutli_dim_2 ...@@ -866,7 +886,7 @@ struct gemm_mutli_dim_2
} }
}; };
struct gemm_mutli_dim_2_3 struct gemm_multi_dim_2_3 : verify_program<gemm_multi_dim_2_3>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -882,7 +902,7 @@ struct gemm_mutli_dim_2_3 ...@@ -882,7 +902,7 @@ struct gemm_mutli_dim_2_3
} }
}; };
struct gemm_mutli_3args struct gemm_multi_3args : verify_program<gemm_multi_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -902,7 +922,7 @@ struct gemm_mutli_3args ...@@ -902,7 +922,7 @@ struct gemm_mutli_3args
} }
}; };
struct gemm_mutli_3args_beta0 struct gemm_multi_3args_beta0 : verify_program<gemm_multi_3args_beta0>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -922,7 +942,7 @@ struct gemm_mutli_3args_beta0 ...@@ -922,7 +942,7 @@ struct gemm_mutli_3args_beta0
} }
}; };
struct gemm_mutli_3args_alpha0 struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -942,7 +962,7 @@ struct gemm_mutli_3args_alpha0 ...@@ -942,7 +962,7 @@ struct gemm_mutli_3args_alpha0
} }
}; };
struct test_contiguous struct test_contiguous : verify_program<test_contiguous>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -955,7 +975,7 @@ struct test_contiguous ...@@ -955,7 +975,7 @@ struct test_contiguous
} }
}; };
struct test_eliminate_contiguous struct test_eliminate_contiguous : verify_program<test_eliminate_contiguous>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -971,7 +991,7 @@ struct test_eliminate_contiguous ...@@ -971,7 +991,7 @@ struct test_eliminate_contiguous
} }
}; };
struct test_transpose struct test_transpose : verify_program<test_transpose>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -985,7 +1005,7 @@ struct test_transpose ...@@ -985,7 +1005,7 @@ struct test_transpose
} }
}; };
struct test_batchnorm_inference_2 struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
{ {
const size_t width = 14; const size_t width = 14;
const size_t height = 14; const size_t height = 14;
...@@ -1008,7 +1028,7 @@ struct test_batchnorm_inference_2 ...@@ -1008,7 +1028,7 @@ struct test_batchnorm_inference_2
} }
}; };
struct test_batchnorm_inference struct test_batchnorm_inference : verify_program<test_batchnorm_inference>
{ {
const size_t width = 3; const size_t width = 3;
const size_t height = 3; const size_t height = 3;
...@@ -1031,7 +1051,7 @@ struct test_batchnorm_inference ...@@ -1031,7 +1051,7 @@ struct test_batchnorm_inference
} }
}; };
struct test_conv_bn struct test_conv_bn : verify_program<test_conv_bn>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1052,7 +1072,7 @@ struct test_conv_bn ...@@ -1052,7 +1072,7 @@ struct test_conv_bn
} }
}; };
struct test_conv_bn_relu_pooling struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1076,7 +1096,7 @@ struct test_conv_bn_relu_pooling ...@@ -1076,7 +1096,7 @@ struct test_conv_bn_relu_pooling
} }
}; };
struct test_concat struct test_concat : verify_program<test_concat>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1093,7 +1113,7 @@ struct test_concat ...@@ -1093,7 +1113,7 @@ struct test_concat
} }
}; };
struct test_concat2 struct test_concat2 : verify_program<test_concat2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1110,7 +1130,7 @@ struct test_concat2 ...@@ -1110,7 +1130,7 @@ struct test_concat2
} }
}; };
struct test_concat_relu struct test_concat_relu : verify_program<test_concat_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1131,7 +1151,7 @@ struct test_concat_relu ...@@ -1131,7 +1151,7 @@ struct test_concat_relu
} }
}; };
struct test_pad struct test_pad : verify_program<test_pad>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1150,7 +1170,7 @@ struct test_pad ...@@ -1150,7 +1170,7 @@ struct test_pad
} }
}; };
struct test_pooling_autopad struct test_pooling_autopad : verify_program<test_pooling_autopad>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1166,7 +1186,7 @@ struct test_pooling_autopad ...@@ -1166,7 +1186,7 @@ struct test_pooling_autopad
} }
}; };
struct test_gather struct test_gather : verify_program<test_gather>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1182,7 +1202,7 @@ struct test_gather ...@@ -1182,7 +1202,7 @@ struct test_gather
} }
}; };
struct test_gather_neg_axis struct test_gather_neg_axis : verify_program<test_gather_neg_axis>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1198,7 +1218,7 @@ struct test_gather_neg_axis ...@@ -1198,7 +1218,7 @@ struct test_gather_neg_axis
} }
}; };
struct test_gather_scalar_output struct test_gather_scalar_output : verify_program<test_gather_scalar_output>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1214,7 +1234,7 @@ struct test_gather_scalar_output ...@@ -1214,7 +1234,7 @@ struct test_gather_scalar_output
} }
}; };
struct test_gather_scalar_index struct test_gather_scalar_index : verify_program<test_gather_scalar_index>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1230,7 +1250,7 @@ struct test_gather_scalar_index ...@@ -1230,7 +1250,7 @@ struct test_gather_scalar_index
} }
}; };
struct test_gather_1d_index struct test_gather_1d_index : verify_program<test_gather_1d_index>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1292,7 +1312,7 @@ void manual_test_concat_relu() ...@@ -1292,7 +1312,7 @@ void manual_test_concat_relu()
std::cout << result << std::endl; std::cout << result << std::endl;
} }
struct test_conv_bn_relu_pooling2 struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{ {
static migraphx::instruction_ref static migraphx::instruction_ref
add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels) add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels)
...@@ -1329,7 +1349,7 @@ struct test_conv_bn_relu_pooling2 ...@@ -1329,7 +1349,7 @@ struct test_conv_bn_relu_pooling2
} }
}; };
struct test_rnn_forward struct test_rnn_forward : verify_program<test_rnn_forward>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1371,7 +1391,7 @@ struct test_rnn_forward ...@@ -1371,7 +1391,7 @@ struct test_rnn_forward
} }
}; };
struct test_rnn_forward10 struct test_rnn_forward10 : verify_program<test_rnn_forward10>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1413,7 +1433,7 @@ struct test_rnn_forward10 ...@@ -1413,7 +1433,7 @@ struct test_rnn_forward10
} }
}; };
struct test_rnn_reverse struct test_rnn_reverse : verify_program<test_rnn_reverse>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1453,7 +1473,7 @@ struct test_rnn_reverse ...@@ -1453,7 +1473,7 @@ struct test_rnn_reverse
} }
}; };
struct test_rnn_reverse2 struct test_rnn_reverse2 : verify_program<test_rnn_reverse2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1493,7 +1513,7 @@ struct test_rnn_reverse2 ...@@ -1493,7 +1513,7 @@ struct test_rnn_reverse2
} }
}; };
struct test_rnn_3args struct test_rnn_3args : verify_program<test_rnn_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1525,7 +1545,7 @@ struct test_rnn_3args ...@@ -1525,7 +1545,7 @@ struct test_rnn_3args
} }
}; };
struct test_rnn_4args struct test_rnn_4args : verify_program<test_rnn_4args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1560,7 +1580,7 @@ struct test_rnn_4args ...@@ -1560,7 +1580,7 @@ struct test_rnn_4args
} }
}; };
struct test_rnn_5args struct test_rnn_5args : verify_program<test_rnn_5args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1599,7 +1619,7 @@ struct test_rnn_5args ...@@ -1599,7 +1619,7 @@ struct test_rnn_5args
} }
}; };
struct test_rnn_bidirectional struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1641,7 +1661,7 @@ struct test_rnn_bidirectional ...@@ -1641,7 +1661,7 @@ struct test_rnn_bidirectional
} }
}; };
struct test_rnn_bidirectional10 struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1682,7 +1702,7 @@ struct test_rnn_bidirectional10 ...@@ -1682,7 +1702,7 @@ struct test_rnn_bidirectional10
} }
}; };
struct test_rnn_bi_3args struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1717,7 +1737,7 @@ struct test_rnn_bi_3args ...@@ -1717,7 +1737,7 @@ struct test_rnn_bi_3args
} }
}; };
struct test_gru_forward_last struct test_gru_forward_last : verify_program<test_gru_forward_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1761,7 +1781,7 @@ struct test_gru_forward_last ...@@ -1761,7 +1781,7 @@ struct test_gru_forward_last
} }
}; };
struct test_gru_forward_hs struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1803,7 +1823,7 @@ struct test_gru_forward_hs ...@@ -1803,7 +1823,7 @@ struct test_gru_forward_hs
} }
}; };
struct test_gru_forward_3args_und struct test_gru_forward_3args_und : verify_program<test_gru_forward_3args_und>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1839,7 +1859,7 @@ struct test_gru_forward_3args_und ...@@ -1839,7 +1859,7 @@ struct test_gru_forward_3args_und
} }
}; };
struct test_gru_forward_3args struct test_gru_forward_3args : verify_program<test_gru_forward_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1871,7 +1891,7 @@ struct test_gru_forward_3args ...@@ -1871,7 +1891,7 @@ struct test_gru_forward_3args
} }
}; };
struct test_gru_forward_seq1 struct test_gru_forward_seq1 : verify_program<test_gru_forward_seq1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1903,7 +1923,7 @@ struct test_gru_forward_seq1 ...@@ -1903,7 +1923,7 @@ struct test_gru_forward_seq1
} }
}; };
struct test_gru_forward_default_actv struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_actv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1933,7 +1953,7 @@ struct test_gru_forward_default_actv ...@@ -1933,7 +1953,7 @@ struct test_gru_forward_default_actv
} }
}; };
struct test_gru_forward_default_actv1 struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_actv1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -1974,7 +1994,7 @@ struct test_gru_forward_default_actv1 ...@@ -1974,7 +1994,7 @@ struct test_gru_forward_default_actv1
} }
}; };
struct test_gru_reverse_last struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2018,7 +2038,7 @@ struct test_gru_reverse_last ...@@ -2018,7 +2038,7 @@ struct test_gru_reverse_last
} }
}; };
struct test_gru_reverse_3args struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2050,7 +2070,7 @@ struct test_gru_reverse_3args ...@@ -2050,7 +2070,7 @@ struct test_gru_reverse_3args
} }
}; };
struct test_gru_bidirct_last struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2094,7 +2114,7 @@ struct test_gru_bidirct_last ...@@ -2094,7 +2114,7 @@ struct test_gru_bidirct_last
} }
}; };
struct test_gru_bidirct_hs struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2136,7 +2156,7 @@ struct test_gru_bidirct_hs ...@@ -2136,7 +2156,7 @@ struct test_gru_bidirct_hs
} }
}; };
struct test_gru_bidirct_3args_und struct test_gru_bidirct_3args_und : verify_program<test_gru_bidirct_3args_und>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2172,7 +2192,7 @@ struct test_gru_bidirct_3args_und ...@@ -2172,7 +2192,7 @@ struct test_gru_bidirct_3args_und
} }
}; };
struct test_gru_bidirct_3args struct test_gru_bidirct_3args : verify_program<test_gru_bidirct_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2204,7 +2224,7 @@ struct test_gru_bidirct_3args ...@@ -2204,7 +2224,7 @@ struct test_gru_bidirct_3args
} }
}; };
struct test_gru_bidirct_seq1 struct test_gru_bidirct_seq1 : verify_program<test_gru_bidirct_seq1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2236,7 +2256,7 @@ struct test_gru_bidirct_seq1 ...@@ -2236,7 +2256,7 @@ struct test_gru_bidirct_seq1
} }
}; };
struct test_gru_bidirct_default_actv struct test_gru_bidirct_default_actv : verify_program<test_gru_bidirct_default_actv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2266,7 +2286,7 @@ struct test_gru_bidirct_default_actv ...@@ -2266,7 +2286,7 @@ struct test_gru_bidirct_default_actv
} }
}; };
struct test_gru_bidirct_default_actv1 struct test_gru_bidirct_default_actv1 : verify_program<test_gru_bidirct_default_actv1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2308,7 +2328,7 @@ struct test_gru_bidirct_default_actv1 ...@@ -2308,7 +2328,7 @@ struct test_gru_bidirct_default_actv1
} }
}; };
struct test_lstm_forward_last struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2358,7 +2378,7 @@ struct test_lstm_forward_last ...@@ -2358,7 +2378,7 @@ struct test_lstm_forward_last
} }
}; };
struct test_lstm_forward_hs struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2408,7 +2428,7 @@ struct test_lstm_forward_hs ...@@ -2408,7 +2428,7 @@ struct test_lstm_forward_hs
} }
}; };
struct test_lstm_forward_3args_und struct test_lstm_forward_3args_und : verify_program<test_lstm_forward_3args_und>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2448,7 +2468,7 @@ struct test_lstm_forward_3args_und ...@@ -2448,7 +2468,7 @@ struct test_lstm_forward_3args_und
} }
}; };
struct test_lstm_forward_3args struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2482,7 +2502,7 @@ struct test_lstm_forward_3args ...@@ -2482,7 +2502,7 @@ struct test_lstm_forward_3args
} }
}; };
struct test_lstm_forward_seq1 struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2516,7 +2536,7 @@ struct test_lstm_forward_seq1 ...@@ -2516,7 +2536,7 @@ struct test_lstm_forward_seq1
} }
}; };
struct test_lstm_forward_default_actv struct test_lstm_forward_default_actv : verify_program<test_lstm_forward_default_actv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2546,7 +2566,7 @@ struct test_lstm_forward_default_actv ...@@ -2546,7 +2566,7 @@ struct test_lstm_forward_default_actv
} }
}; };
struct test_lstm_forward_default_actv1 struct test_lstm_forward_default_actv1 : verify_program<test_lstm_forward_default_actv1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2587,7 +2607,7 @@ struct test_lstm_forward_default_actv1 ...@@ -2587,7 +2607,7 @@ struct test_lstm_forward_default_actv1
} }
}; };
struct test_lstm_reverse_last struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2638,7 +2658,7 @@ struct test_lstm_reverse_last ...@@ -2638,7 +2658,7 @@ struct test_lstm_reverse_last
} }
}; };
struct test_lstm_reverse_3args struct test_lstm_reverse_3args : verify_program<test_lstm_reverse_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2672,7 +2692,7 @@ struct test_lstm_reverse_3args ...@@ -2672,7 +2692,7 @@ struct test_lstm_reverse_3args
} }
}; };
struct test_lstm_reverse_3args_cell_output struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3args_cell_output>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2707,7 +2727,7 @@ struct test_lstm_reverse_3args_cell_output ...@@ -2707,7 +2727,7 @@ struct test_lstm_reverse_3args_cell_output
} }
}; };
struct test_lstm_bidirct_last struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2758,7 +2778,7 @@ struct test_lstm_bidirct_last ...@@ -2758,7 +2778,7 @@ struct test_lstm_bidirct_last
} }
}; };
struct test_lstm_bidirct_hs struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2800,7 +2820,7 @@ struct test_lstm_bidirct_hs ...@@ -2800,7 +2820,7 @@ struct test_lstm_bidirct_hs
} }
}; };
struct test_lstm_bidirct_3args_und struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2839,7 +2859,7 @@ struct test_lstm_bidirct_3args_und ...@@ -2839,7 +2859,7 @@ struct test_lstm_bidirct_3args_und
} }
}; };
struct test_lstm_bidirct_3args struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2871,7 +2891,7 @@ struct test_lstm_bidirct_3args ...@@ -2871,7 +2891,7 @@ struct test_lstm_bidirct_3args
} }
}; };
struct test_lstm_bidirct_seq1 struct test_lstm_bidirct_seq1 : verify_program<test_lstm_bidirct_seq1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2903,7 +2923,7 @@ struct test_lstm_bidirct_seq1 ...@@ -2903,7 +2923,7 @@ struct test_lstm_bidirct_seq1
} }
}; };
struct test_lstm_bidirct_default_actv struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default_actv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2933,7 +2953,7 @@ struct test_lstm_bidirct_default_actv ...@@ -2933,7 +2953,7 @@ struct test_lstm_bidirct_default_actv
} }
}; };
struct test_lstm_bidirct_default_actv1 struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_default_actv1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -2975,7 +2995,7 @@ struct test_lstm_bidirct_default_actv1 ...@@ -2975,7 +2995,7 @@ struct test_lstm_bidirct_default_actv1
} }
}; };
struct test_lstm_bidirct_default_actv2 struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_default_actv2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -3018,7 +3038,7 @@ struct test_lstm_bidirct_default_actv2 ...@@ -3018,7 +3038,7 @@ struct test_lstm_bidirct_default_actv2
}; };
template <int Axis> template <int Axis>
struct test_logsoftmax struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -3031,8 +3051,14 @@ struct test_logsoftmax ...@@ -3031,8 +3051,14 @@ struct test_logsoftmax
} }
}; };
template struct test_logsoftmax<0>;
template struct test_logsoftmax<1>;
template struct test_logsoftmax<2>;
template struct test_logsoftmax<3>;
template struct test_logsoftmax<4>;
template <int Axis> template <int Axis>
struct test_logsoftmax_1 struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -3045,131 +3071,7 @@ struct test_logsoftmax_1 ...@@ -3045,131 +3071,7 @@ struct test_logsoftmax_1
} }
}; };
int main() template struct test_logsoftmax_1<0>;
{ template struct test_logsoftmax_1<1>;
verify_program<test_relu_lrn>();
verify_program<test_pooling_autopad>(); int main(int argc, const char* argv[]) { test::run(argc, argv); }
verify_program<test_abs>();
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_concat_relu>();
verify_program<test_pad>();
verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>();
verify_program<test_exp>();
verify_program<test_log>();
verify_program<test_sin>();
verify_program<test_cos>();
verify_program<test_tan>();
verify_program<test_sinh>();
verify_program<test_cosh>();
verify_program<test_tanh>();
verify_program<test_asin>();
verify_program<test_acos>();
verify_program<test_atan>();
verify_program<test_scale>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>();
verify_program<test_sub>();
verify_program<test_sub2>();
verify_program<test_softmax>();
verify_program<test_softmax2>();
verify_program<test_conv>();
verify_program<test_conv2>();
verify_program<test_group_conv>();
verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>();
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_sigmoid>();
verify_program<test_elu>();
verify_program<test_conv_pooling>();
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>();
verify_program<test_gemm_ex>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposeb_ex>();
verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposea_ex>();
verify_program<test_gemm_transposeab>();
verify_program<gemm_mutli_dim_2>();
verify_program<gemm_mutli_dim_2_3>();
verify_program<gemm_mutli_3args>();
verify_program<gemm_mutli_3args_beta0>();
verify_program<gemm_mutli_3args_alpha0>();
verify_program<test_contiguous>();
verify_program<test_eliminate_contiguous>();
verify_program<test_transpose>();
verify_program<test_batchnorm_inference>();
verify_program<test_batchnorm_inference_2>();
verify_program<test_conv_bn>();
verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
verify_program<test_gather_scalar_output>();
verify_program<test_gather_scalar_index>();
verify_program<test_gather_1d_index>();
verify_program<test_rnn_forward>();
verify_program<test_rnn_forward10>();
verify_program<test_rnn_reverse>();
verify_program<test_rnn_reverse2>();
verify_program<test_rnn_3args>();
verify_program<test_rnn_4args>();
verify_program<test_rnn_5args>();
verify_program<test_rnn_bidirectional>();
verify_program<test_rnn_bidirectional10>();
verify_program<test_rnn_bi_3args>();
verify_program<test_gru_forward_last>();
verify_program<test_gru_forward_hs>();
verify_program<test_gru_forward_3args_und>();
verify_program<test_gru_forward_3args>();
verify_program<test_gru_forward_seq1>();
verify_program<test_gru_forward_default_actv>();
verify_program<test_gru_forward_default_actv1>();
verify_program<test_gru_reverse_last>();
verify_program<test_gru_reverse_3args>();
verify_program<test_gru_bidirct_last>();
verify_program<test_gru_bidirct_hs>();
verify_program<test_gru_bidirct_3args_und>();
verify_program<test_gru_bidirct_3args>();
verify_program<test_gru_bidirct_seq1>();
verify_program<test_gru_bidirct_default_actv>();
verify_program<test_gru_bidirct_default_actv1>();
verify_program<test_lstm_forward_last>();
verify_program<test_lstm_forward_hs>();
verify_program<test_lstm_forward_3args_und>();
verify_program<test_lstm_forward_3args>();
verify_program<test_lstm_forward_seq1>();
verify_program<test_lstm_forward_default_actv>();
verify_program<test_lstm_forward_default_actv1>();
verify_program<test_lstm_reverse_last>();
verify_program<test_lstm_reverse_3args>();
verify_program<test_lstm_reverse_3args_cell_output>();
verify_program<test_lstm_bidirct_last>();
verify_program<test_lstm_bidirct_hs>();
verify_program<test_lstm_bidirct_3args_und>();
verify_program<test_lstm_bidirct_3args>();
verify_program<test_lstm_bidirct_seq1>();
verify_program<test_lstm_bidirct_default_actv>();
verify_program<test_lstm_bidirct_default_actv1>();
verify_program<test_lstm_bidirct_default_actv2>();
verify_program<test_logsoftmax<0>>();
verify_program<test_logsoftmax<1>>();
verify_program<test_logsoftmax<2>>();
verify_program<test_logsoftmax<3>>();
verify_program<test_logsoftmax<4>>();
verify_program<test_logsoftmax_1<0>>();
verify_program<test_logsoftmax_1<1>>();
}
...@@ -192,10 +192,10 @@ inline void add_test_case(std::string name, std::function<void()> f) ...@@ -192,10 +192,10 @@ inline void add_test_case(std::string name, std::function<void()> f)
get_test_cases().emplace_back(std::move(name), std::move(f)); get_test_cases().emplace_back(std::move(name), std::move(f));
} }
struct auto_register struct auto_register_test_case
{ {
template <class F> template <class F>
auto_register(const char* name, F f) noexcept auto_register_test_case(const char* name, F f) noexcept
{ {
add_test_case(name, f); add_test_case(name, f);
} }
...@@ -258,9 +258,9 @@ inline void run(int argc, const char* argv[]) ...@@ -258,9 +258,9 @@ inline void run(int argc, const char* argv[])
#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__ #define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \ #define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \ static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__); test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define TEST_CASE(...) \ #define TEST_CASE(...) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment