Commit 16fc0314 authored by Khalique's avatar Khalique
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into multibcast_check

parents 39d4398f 3499ec7d
......@@ -987,15 +987,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
}
for(long i = 0; i < seq_len; ++i)
......
......@@ -14,8 +14,9 @@ std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&) const
{
return {auto_contiguous{},
rewrite_rnn{},
return {rewrite_rnn{},
dead_code_elimination{},
auto_contiguous{},
dead_code_elimination{},
lowering{},
dead_code_elimination{}};
......
......@@ -16,7 +16,7 @@
#include <future>
#include <thread>
#include "test.hpp"
#include <test.hpp>
#ifdef __clang__
#pragma clang diagnostic push
......@@ -134,7 +134,7 @@ migraphx::argument run_gpu(migraphx::program& p)
}
template <class V>
void verify_program()
void run_verify_program()
{
auto_print::set_terminate_handler(migraphx::get_type_name<V>());
// std::cout << migraphx::get_type_name<V>() << std::endl;
......@@ -156,7 +156,27 @@ void verify_program()
std::set_terminate(nullptr);
}
struct test_literals
template <class T>
int auto_register_verify_program()
{
test::add_test_case(migraphx::get_type_name<T>(), [] { run_verify_program<T>(); });
return 0;
}
template <class T>
struct verify_program
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type =
std::integral_constant<decltype(&static_register), &static_register>;
};
template <class T>
int verify_program<T>::static_register = auto_register_verify_program<T>(); // NOLINT
struct test_literals : verify_program<test_literals>
{
migraphx::program create_program() const
{
......@@ -171,7 +191,7 @@ struct test_literals
}
};
struct test_add
struct test_add : verify_program<test_add>
{
migraphx::program create_program() const
{
......@@ -184,7 +204,7 @@ struct test_add
}
};
struct test_add_half
struct test_add_half : verify_program<test_add_half>
{
migraphx::program create_program() const
{
......@@ -197,7 +217,7 @@ struct test_add_half
}
};
struct test_mul
struct test_mul : verify_program<test_mul>
{
migraphx::program create_program() const
{
......@@ -210,7 +230,7 @@ struct test_mul
}
};
struct test_exp
struct test_exp : verify_program<test_exp>
{
migraphx::program create_program() const
{
......@@ -223,7 +243,7 @@ struct test_exp
}
};
struct test_log
struct test_log : verify_program<test_log>
{
migraphx::program create_program() const
{
......@@ -236,7 +256,7 @@ struct test_log
}
};
struct test_sin
struct test_sin : verify_program<test_sin>
{
migraphx::program create_program() const
{
......@@ -248,7 +268,7 @@ struct test_sin
}
};
struct test_cos
struct test_cos : verify_program<test_cos>
{
migraphx::program create_program() const
{
......@@ -260,7 +280,7 @@ struct test_cos
}
};
struct test_tan
struct test_tan : verify_program<test_tan>
{
migraphx::program create_program() const
{
......@@ -272,7 +292,7 @@ struct test_tan
}
};
struct test_sinh
struct test_sinh : verify_program<test_sinh>
{
migraphx::program create_program() const
{
......@@ -284,7 +304,7 @@ struct test_sinh
}
};
struct test_cosh
struct test_cosh : verify_program<test_cosh>
{
migraphx::program create_program() const
{
......@@ -296,7 +316,7 @@ struct test_cosh
}
};
struct test_tanh
struct test_tanh : verify_program<test_tanh>
{
migraphx::program create_program() const
{
......@@ -307,7 +327,7 @@ struct test_tanh
}
};
struct test_asin
struct test_asin : verify_program<test_asin>
{
migraphx::program create_program() const
{
......@@ -319,7 +339,7 @@ struct test_asin
}
};
struct test_acos
struct test_acos : verify_program<test_acos>
{
migraphx::program create_program() const
{
......@@ -331,7 +351,7 @@ struct test_acos
}
};
struct test_atan
struct test_atan : verify_program<test_atan>
{
migraphx::program create_program() const
{
......@@ -343,7 +363,7 @@ struct test_atan
}
};
struct test_scale
struct test_scale : verify_program<test_scale>
{
migraphx::program create_program() const
{
......@@ -357,7 +377,7 @@ struct test_scale
}
};
struct test_slice
struct test_slice : verify_program<test_slice>
{
migraphx::program create_program() const
{
......@@ -372,7 +392,7 @@ struct test_slice
}
};
struct test_triadd
struct test_triadd : verify_program<test_triadd>
{
migraphx::program create_program() const
{
......@@ -387,7 +407,7 @@ struct test_triadd
}
};
struct test_triadd2
struct test_triadd2 : verify_program<test_triadd2>
{
migraphx::program create_program() const
{
......@@ -404,7 +424,7 @@ struct test_triadd2
}
};
struct test_add_broadcast
struct test_add_broadcast : verify_program<test_add_broadcast>
{
migraphx::program create_program() const
{
......@@ -418,7 +438,7 @@ struct test_add_broadcast
}
};
struct test_add_broadcast2
struct test_add_broadcast2 : verify_program<test_add_broadcast2>
{
migraphx::program create_program() const
{
......@@ -432,7 +452,7 @@ struct test_add_broadcast2
}
};
struct test_add_broadcast3
struct test_add_broadcast3 : verify_program<test_add_broadcast3>
{
migraphx::program create_program() const
{
......@@ -446,7 +466,7 @@ struct test_add_broadcast3
}
};
struct test_add_broadcast4
struct test_add_broadcast4 : verify_program<test_add_broadcast4>
{
migraphx::program create_program() const
{
......@@ -460,7 +480,7 @@ struct test_add_broadcast4
}
};
struct test_add_broadcast5
struct test_add_broadcast5 : verify_program<test_add_broadcast5>
{
migraphx::program create_program() const
{
......@@ -474,7 +494,7 @@ struct test_add_broadcast5
}
};
struct test_triadd_broadcast
struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
{
migraphx::program create_program() const
{
......@@ -490,7 +510,7 @@ struct test_triadd_broadcast
}
};
struct test_sub
struct test_sub : verify_program<test_sub>
{
migraphx::program create_program() const
{
......@@ -505,7 +525,7 @@ struct test_sub
}
};
struct test_sub2
struct test_sub2 : verify_program<test_sub2>
{
migraphx::program create_program() const
{
......@@ -522,7 +542,7 @@ struct test_sub2
}
};
struct test_softmax
struct test_softmax : verify_program<test_softmax>
{
migraphx::program create_program() const
{
......@@ -533,7 +553,7 @@ struct test_softmax
}
};
struct test_softmax2
struct test_softmax2 : verify_program<test_softmax2>
{
migraphx::program create_program() const
{
......@@ -545,7 +565,7 @@ struct test_softmax2
}
};
struct test_conv
struct test_conv : verify_program<test_conv>
{
migraphx::program create_program() const
{
......@@ -559,7 +579,7 @@ struct test_conv
}
};
struct test_conv2
struct test_conv2 : verify_program<test_conv2>
{
migraphx::program create_program() const
{
......@@ -573,7 +593,7 @@ struct test_conv2
}
};
struct test_group_conv
struct test_group_conv : verify_program<test_group_conv>
{
migraphx::program create_program() const
{
......@@ -589,7 +609,7 @@ struct test_group_conv
}
};
struct test_conv_relu
struct test_conv_relu : verify_program<test_conv_relu>
{
migraphx::program create_program() const
{
......@@ -604,7 +624,7 @@ struct test_conv_relu
}
};
struct test_conv_relu_half
struct test_conv_relu_half : verify_program<test_conv_relu_half>
{
migraphx::program create_program() const
{
......@@ -619,7 +639,7 @@ struct test_conv_relu_half
}
};
struct test_add_relu
struct test_add_relu : verify_program<test_add_relu>
{
migraphx::program create_program() const
{
......@@ -632,7 +652,7 @@ struct test_add_relu
}
};
struct test_sigmoid
struct test_sigmoid : verify_program<test_sigmoid>
{
migraphx::program create_program() const
{
......@@ -643,7 +663,7 @@ struct test_sigmoid
}
};
struct test_abs
struct test_abs : verify_program<test_abs>
{
migraphx::program create_program() const
{
......@@ -654,7 +674,7 @@ struct test_abs
}
};
struct test_leaky_relu
struct test_leaky_relu : verify_program<test_leaky_relu>
{
migraphx::program create_program() const
{
......@@ -665,7 +685,7 @@ struct test_leaky_relu
}
};
struct test_elu
struct test_elu : verify_program<test_elu>
{
migraphx::program create_program() const
{
......@@ -676,7 +696,7 @@ struct test_elu
}
};
struct test_relu_lrn
struct test_relu_lrn : verify_program<test_relu_lrn>
{
migraphx::program create_program() const
{
......@@ -688,7 +708,7 @@ struct test_relu_lrn
}
};
struct test_conv_pooling
struct test_conv_pooling : verify_program<test_conv_pooling>
{
migraphx::program create_program() const
{
......@@ -704,7 +724,7 @@ struct test_conv_pooling
}
};
struct test_global_avg_pooling
struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
{
migraphx::program create_program() const
{
......@@ -719,7 +739,7 @@ struct test_global_avg_pooling
}
};
struct test_global_max_pooling
struct test_global_max_pooling : verify_program<test_global_max_pooling>
{
migraphx::program create_program() const
{
......@@ -734,7 +754,7 @@ struct test_global_max_pooling
}
};
struct test_gemm
struct test_gemm : verify_program<test_gemm>
{
migraphx::program create_program() const
{
......@@ -746,7 +766,7 @@ struct test_gemm
}
};
struct test_gemm_ex
struct test_gemm_ex : verify_program<test_gemm_ex>
{
migraphx::program create_program() const
{
......@@ -758,7 +778,7 @@ struct test_gemm_ex
}
};
struct test_gemm_half
struct test_gemm_half : verify_program<test_gemm_half>
{
migraphx::program create_program() const
{
......@@ -770,7 +790,7 @@ struct test_gemm_half
}
};
struct test_gemm_ld
struct test_gemm_ld //: verify_program<test_gemm_ld>
{
migraphx::program create_program() const
{
......@@ -784,7 +804,7 @@ struct test_gemm_ld
}
};
struct test_gemm_transposeb
struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
{
migraphx::program create_program() const
{
......@@ -797,7 +817,7 @@ struct test_gemm_transposeb
}
};
struct test_gemm_transposeb_ex
struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex>
{
migraphx::program create_program() const
{
......@@ -810,7 +830,7 @@ struct test_gemm_transposeb_ex
}
};
struct test_gemm_transposea
struct test_gemm_transposea : verify_program<test_gemm_transposea>
{
migraphx::program create_program() const
{
......@@ -823,7 +843,7 @@ struct test_gemm_transposea
}
};
struct test_gemm_transposea_ex
struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex>
{
migraphx::program create_program() const
{
......@@ -836,7 +856,7 @@ struct test_gemm_transposea_ex
}
};
struct test_gemm_transposeab
struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
{
migraphx::program create_program() const
{
......@@ -882,7 +902,7 @@ struct gemm_mutli_dim_2_3
}
};
struct test_contiguous
struct test_contiguous : verify_program<test_contiguous>
{
migraphx::program create_program() const
{
......@@ -895,7 +915,7 @@ struct test_contiguous
}
};
struct test_eliminate_contiguous
struct test_eliminate_contiguous : verify_program<test_eliminate_contiguous>
{
migraphx::program create_program() const
{
......@@ -911,7 +931,7 @@ struct test_eliminate_contiguous
}
};
struct test_transpose
struct test_transpose : verify_program<test_transpose>
{
migraphx::program create_program() const
{
......@@ -925,7 +945,7 @@ struct test_transpose
}
};
struct test_batchnorm_inference_2
struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
{
const size_t width = 14;
const size_t height = 14;
......@@ -948,7 +968,7 @@ struct test_batchnorm_inference_2
}
};
struct test_batchnorm_inference
struct test_batchnorm_inference : verify_program<test_batchnorm_inference>
{
const size_t width = 3;
const size_t height = 3;
......@@ -971,7 +991,7 @@ struct test_batchnorm_inference
}
};
struct test_conv_bn
struct test_conv_bn : verify_program<test_conv_bn>
{
migraphx::program create_program() const
{
......@@ -992,7 +1012,7 @@ struct test_conv_bn
}
};
struct test_conv_bn_relu_pooling
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{
migraphx::program create_program() const
{
......@@ -1016,7 +1036,7 @@ struct test_conv_bn_relu_pooling
}
};
struct test_concat
struct test_concat : verify_program<test_concat>
{
migraphx::program create_program() const
{
......@@ -1033,7 +1053,7 @@ struct test_concat
}
};
struct test_concat2
struct test_concat2 : verify_program<test_concat2>
{
migraphx::program create_program() const
{
......@@ -1050,7 +1070,7 @@ struct test_concat2
}
};
struct test_concat_relu
struct test_concat_relu : verify_program<test_concat_relu>
{
migraphx::program create_program() const
{
......@@ -1071,7 +1091,7 @@ struct test_concat_relu
}
};
struct test_pad
struct test_pad : verify_program<test_pad>
{
migraphx::program create_program() const
{
......@@ -1090,7 +1110,7 @@ struct test_pad
}
};
struct test_pooling_autopad
struct test_pooling_autopad : verify_program<test_pooling_autopad>
{
migraphx::program create_program() const
{
......@@ -1106,7 +1126,7 @@ struct test_pooling_autopad
}
};
struct test_gather
struct test_gather : verify_program<test_gather>
{
migraphx::program create_program() const
{
......@@ -1122,7 +1142,7 @@ struct test_gather
}
};
struct test_gather_neg_axis
struct test_gather_neg_axis : verify_program<test_gather_neg_axis>
{
migraphx::program create_program() const
{
......@@ -1138,7 +1158,7 @@ struct test_gather_neg_axis
}
};
struct test_gather_scalar_output
struct test_gather_scalar_output : verify_program<test_gather_scalar_output>
{
migraphx::program create_program() const
{
......@@ -1154,7 +1174,7 @@ struct test_gather_scalar_output
}
};
struct test_gather_scalar_index
struct test_gather_scalar_index : verify_program<test_gather_scalar_index>
{
migraphx::program create_program() const
{
......@@ -1170,7 +1190,7 @@ struct test_gather_scalar_index
}
};
struct test_gather_1d_index
struct test_gather_1d_index : verify_program<test_gather_1d_index>
{
migraphx::program create_program() const
{
......@@ -1232,7 +1252,7 @@ void manual_test_concat_relu()
std::cout << result << std::endl;
}
struct test_conv_bn_relu_pooling2
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{
static migraphx::instruction_ref
add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels)
......@@ -1269,7 +1289,7 @@ struct test_conv_bn_relu_pooling2
}
};
struct test_rnn_forward
struct test_rnn_forward : verify_program<test_rnn_forward>
{
migraphx::program create_program() const
{
......@@ -1311,7 +1331,7 @@ struct test_rnn_forward
}
};
struct test_rnn_forward10
struct test_rnn_forward10 : verify_program<test_rnn_forward10>
{
migraphx::program create_program() const
{
......@@ -1353,7 +1373,7 @@ struct test_rnn_forward10
}
};
struct test_rnn_reverse
struct test_rnn_reverse : verify_program<test_rnn_reverse>
{
migraphx::program create_program() const
{
......@@ -1393,7 +1413,7 @@ struct test_rnn_reverse
}
};
struct test_rnn_reverse2
struct test_rnn_reverse2 : verify_program<test_rnn_reverse2>
{
migraphx::program create_program() const
{
......@@ -1433,7 +1453,7 @@ struct test_rnn_reverse2
}
};
struct test_rnn_3args
struct test_rnn_3args : verify_program<test_rnn_3args>
{
migraphx::program create_program() const
{
......@@ -1465,7 +1485,7 @@ struct test_rnn_3args
}
};
struct test_rnn_4args
struct test_rnn_4args : verify_program<test_rnn_4args>
{
migraphx::program create_program() const
{
......@@ -1500,7 +1520,7 @@ struct test_rnn_4args
}
};
struct test_rnn_5args
struct test_rnn_5args : verify_program<test_rnn_5args>
{
migraphx::program create_program() const
{
......@@ -1539,7 +1559,7 @@ struct test_rnn_5args
}
};
struct test_rnn_bidirectional
struct test_rnn_bidirectional : verify_program<test_rnn_bidirectional>
{
migraphx::program create_program() const
{
......@@ -1581,7 +1601,7 @@ struct test_rnn_bidirectional
}
};
struct test_rnn_bidirectional10
struct test_rnn_bidirectional10 : verify_program<test_rnn_bidirectional10>
{
migraphx::program create_program() const
{
......@@ -1622,7 +1642,7 @@ struct test_rnn_bidirectional10
}
};
struct test_rnn_bi_3args
struct test_rnn_bi_3args : verify_program<test_rnn_bi_3args>
{
migraphx::program create_program() const
{
......@@ -1657,7 +1677,7 @@ struct test_rnn_bi_3args
}
};
struct test_gru_forward_last
struct test_gru_forward_last : verify_program<test_gru_forward_last>
{
migraphx::program create_program() const
{
......@@ -1701,7 +1721,7 @@ struct test_gru_forward_last
}
};
struct test_gru_forward_hs
struct test_gru_forward_hs : verify_program<test_gru_forward_hs>
{
migraphx::program create_program() const
{
......@@ -1743,7 +1763,7 @@ struct test_gru_forward_hs
}
};
struct test_gru_forward_3args_und
struct test_gru_forward_3args_und : verify_program<test_gru_forward_3args_und>
{
migraphx::program create_program() const
{
......@@ -1779,7 +1799,7 @@ struct test_gru_forward_3args_und
}
};
struct test_gru_forward_3args
struct test_gru_forward_3args : verify_program<test_gru_forward_3args>
{
migraphx::program create_program() const
{
......@@ -1811,7 +1831,7 @@ struct test_gru_forward_3args
}
};
struct test_gru_forward_seq1
struct test_gru_forward_seq1 : verify_program<test_gru_forward_seq1>
{
migraphx::program create_program() const
{
......@@ -1843,7 +1863,7 @@ struct test_gru_forward_seq1
}
};
struct test_gru_forward_default_actv
struct test_gru_forward_default_actv : verify_program<test_gru_forward_default_actv>
{
migraphx::program create_program() const
{
......@@ -1873,7 +1893,7 @@ struct test_gru_forward_default_actv
}
};
struct test_gru_forward_default_actv1
struct test_gru_forward_default_actv1 : verify_program<test_gru_forward_default_actv1>
{
migraphx::program create_program() const
{
......@@ -1914,7 +1934,7 @@ struct test_gru_forward_default_actv1
}
};
struct test_gru_reverse_last
struct test_gru_reverse_last : verify_program<test_gru_reverse_last>
{
migraphx::program create_program() const
{
......@@ -1958,7 +1978,7 @@ struct test_gru_reverse_last
}
};
struct test_gru_reverse_3args
struct test_gru_reverse_3args : verify_program<test_gru_reverse_3args>
{
migraphx::program create_program() const
{
......@@ -1990,7 +2010,7 @@ struct test_gru_reverse_3args
}
};
struct test_gru_bidirct_last
struct test_gru_bidirct_last : verify_program<test_gru_bidirct_last>
{
migraphx::program create_program() const
{
......@@ -2034,7 +2054,7 @@ struct test_gru_bidirct_last
}
};
struct test_gru_bidirct_hs
struct test_gru_bidirct_hs : verify_program<test_gru_bidirct_hs>
{
migraphx::program create_program() const
{
......@@ -2076,7 +2096,7 @@ struct test_gru_bidirct_hs
}
};
struct test_gru_bidirct_3args_und
struct test_gru_bidirct_3args_und : verify_program<test_gru_bidirct_3args_und>
{
migraphx::program create_program() const
{
......@@ -2112,7 +2132,7 @@ struct test_gru_bidirct_3args_und
}
};
struct test_gru_bidirct_3args
struct test_gru_bidirct_3args : verify_program<test_gru_bidirct_3args>
{
migraphx::program create_program() const
{
......@@ -2144,7 +2164,7 @@ struct test_gru_bidirct_3args
}
};
struct test_gru_bidirct_seq1
struct test_gru_bidirct_seq1 : verify_program<test_gru_bidirct_seq1>
{
migraphx::program create_program() const
{
......@@ -2176,7 +2196,7 @@ struct test_gru_bidirct_seq1
}
};
struct test_gru_bidirct_default_actv
struct test_gru_bidirct_default_actv : verify_program<test_gru_bidirct_default_actv>
{
migraphx::program create_program() const
{
......@@ -2206,7 +2226,7 @@ struct test_gru_bidirct_default_actv
}
};
struct test_gru_bidirct_default_actv1
struct test_gru_bidirct_default_actv1 : verify_program<test_gru_bidirct_default_actv1>
{
migraphx::program create_program() const
{
......@@ -2248,7 +2268,7 @@ struct test_gru_bidirct_default_actv1
}
};
struct test_lstm_forward_last
struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
{
migraphx::program create_program() const
{
......@@ -2298,7 +2318,7 @@ struct test_lstm_forward_last
}
};
struct test_lstm_forward_hs
struct test_lstm_forward_hs : verify_program<test_lstm_forward_hs>
{
migraphx::program create_program() const
{
......@@ -2348,7 +2368,7 @@ struct test_lstm_forward_hs
}
};
struct test_lstm_forward_3args_und
struct test_lstm_forward_3args_und : verify_program<test_lstm_forward_3args_und>
{
migraphx::program create_program() const
{
......@@ -2388,7 +2408,7 @@ struct test_lstm_forward_3args_und
}
};
struct test_lstm_forward_3args
struct test_lstm_forward_3args : verify_program<test_lstm_forward_3args>
{
migraphx::program create_program() const
{
......@@ -2422,7 +2442,7 @@ struct test_lstm_forward_3args
}
};
struct test_lstm_forward_seq1
struct test_lstm_forward_seq1 : verify_program<test_lstm_forward_seq1>
{
migraphx::program create_program() const
{
......@@ -2456,7 +2476,7 @@ struct test_lstm_forward_seq1
}
};
struct test_lstm_forward_default_actv
struct test_lstm_forward_default_actv : verify_program<test_lstm_forward_default_actv>
{
migraphx::program create_program() const
{
......@@ -2486,7 +2506,7 @@ struct test_lstm_forward_default_actv
}
};
struct test_lstm_forward_default_actv1
struct test_lstm_forward_default_actv1 : verify_program<test_lstm_forward_default_actv1>
{
migraphx::program create_program() const
{
......@@ -2527,7 +2547,7 @@ struct test_lstm_forward_default_actv1
}
};
struct test_lstm_reverse_last
struct test_lstm_reverse_last : verify_program<test_lstm_reverse_last>
{
migraphx::program create_program() const
{
......@@ -2578,7 +2598,7 @@ struct test_lstm_reverse_last
}
};
struct test_lstm_reverse_3args
struct test_lstm_reverse_3args : verify_program<test_lstm_reverse_3args>
{
migraphx::program create_program() const
{
......@@ -2612,7 +2632,7 @@ struct test_lstm_reverse_3args
}
};
struct test_lstm_reverse_3args_cell_output
struct test_lstm_reverse_3args_cell_output : verify_program<test_lstm_reverse_3args_cell_output>
{
migraphx::program create_program() const
{
......@@ -2647,7 +2667,7 @@ struct test_lstm_reverse_3args_cell_output
}
};
struct test_lstm_bidirct_last
struct test_lstm_bidirct_last : verify_program<test_lstm_bidirct_last>
{
migraphx::program create_program() const
{
......@@ -2698,7 +2718,7 @@ struct test_lstm_bidirct_last
}
};
struct test_lstm_bidirct_hs
struct test_lstm_bidirct_hs : verify_program<test_lstm_bidirct_hs>
{
migraphx::program create_program() const
{
......@@ -2740,7 +2760,7 @@ struct test_lstm_bidirct_hs
}
};
struct test_lstm_bidirct_3args_und
struct test_lstm_bidirct_3args_und : verify_program<test_lstm_bidirct_3args_und>
{
migraphx::program create_program() const
{
......@@ -2779,7 +2799,7 @@ struct test_lstm_bidirct_3args_und
}
};
struct test_lstm_bidirct_3args
struct test_lstm_bidirct_3args : verify_program<test_lstm_bidirct_3args>
{
migraphx::program create_program() const
{
......@@ -2811,7 +2831,7 @@ struct test_lstm_bidirct_3args
}
};
struct test_lstm_bidirct_seq1
struct test_lstm_bidirct_seq1 : verify_program<test_lstm_bidirct_seq1>
{
migraphx::program create_program() const
{
......@@ -2843,7 +2863,7 @@ struct test_lstm_bidirct_seq1
}
};
struct test_lstm_bidirct_default_actv
struct test_lstm_bidirct_default_actv : verify_program<test_lstm_bidirct_default_actv>
{
migraphx::program create_program() const
{
......@@ -2873,7 +2893,7 @@ struct test_lstm_bidirct_default_actv
}
};
struct test_lstm_bidirct_default_actv1
struct test_lstm_bidirct_default_actv1 : verify_program<test_lstm_bidirct_default_actv1>
{
migraphx::program create_program() const
{
......@@ -2915,7 +2935,7 @@ struct test_lstm_bidirct_default_actv1
}
};
struct test_lstm_bidirct_default_actv2
struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_default_actv2>
{
migraphx::program create_program() const
{
......@@ -2958,7 +2978,7 @@ struct test_lstm_bidirct_default_actv2
};
template <int Axis>
struct test_logsoftmax
struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
{
migraphx::program create_program() const
{
......@@ -2971,8 +2991,14 @@ struct test_logsoftmax
}
};
template struct test_logsoftmax<0>;
template struct test_logsoftmax<1>;
template struct test_logsoftmax<2>;
template struct test_logsoftmax<3>;
template struct test_logsoftmax<4>;
template <int Axis>
struct test_logsoftmax_1
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
{
migraphx::program create_program() const
{
......@@ -2985,128 +3011,7 @@ struct test_logsoftmax_1
}
};
int main()
{
verify_program<test_relu_lrn>();
verify_program<test_pooling_autopad>();
verify_program<test_abs>();
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_concat_relu>();
verify_program<test_pad>();
verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>();
verify_program<test_exp>();
verify_program<test_log>();
verify_program<test_sin>();
verify_program<test_cos>();
verify_program<test_tan>();
verify_program<test_sinh>();
verify_program<test_cosh>();
verify_program<test_tanh>();
verify_program<test_asin>();
verify_program<test_acos>();
verify_program<test_atan>();
verify_program<test_scale>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>();
verify_program<test_sub>();
verify_program<test_sub2>();
verify_program<test_softmax>();
verify_program<test_softmax2>();
verify_program<test_conv>();
verify_program<test_conv2>();
verify_program<test_group_conv>();
verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>();
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_sigmoid>();
verify_program<test_elu>();
verify_program<test_conv_pooling>();
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>();
verify_program<test_gemm_ex>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposeb_ex>();
verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposea_ex>();
verify_program<test_gemm_transposeab>();
verify_program<gemm_mutli_dim_2>();
verify_program<gemm_mutli_dim_2_3>();
verify_program<test_contiguous>();
verify_program<test_eliminate_contiguous>();
verify_program<test_transpose>();
verify_program<test_batchnorm_inference>();
verify_program<test_batchnorm_inference_2>();
verify_program<test_conv_bn>();
verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
verify_program<test_gather_scalar_output>();
verify_program<test_gather_scalar_index>();
verify_program<test_gather_1d_index>();
verify_program<test_rnn_forward>();
verify_program<test_rnn_forward10>();
verify_program<test_rnn_reverse>();
verify_program<test_rnn_reverse2>();
verify_program<test_rnn_3args>();
verify_program<test_rnn_4args>();
verify_program<test_rnn_5args>();
verify_program<test_rnn_bidirectional>();
verify_program<test_rnn_bidirectional10>();
verify_program<test_rnn_bi_3args>();
verify_program<test_gru_forward_last>();
verify_program<test_gru_forward_hs>();
verify_program<test_gru_forward_3args_und>();
verify_program<test_gru_forward_3args>();
verify_program<test_gru_forward_seq1>();
verify_program<test_gru_forward_default_actv>();
verify_program<test_gru_forward_default_actv1>();
verify_program<test_gru_reverse_last>();
verify_program<test_gru_reverse_3args>();
verify_program<test_gru_bidirct_last>();
verify_program<test_gru_bidirct_hs>();
verify_program<test_gru_bidirct_3args_und>();
verify_program<test_gru_bidirct_3args>();
verify_program<test_gru_bidirct_seq1>();
verify_program<test_gru_bidirct_default_actv>();
verify_program<test_gru_bidirct_default_actv1>();
verify_program<test_lstm_forward_last>();
verify_program<test_lstm_forward_hs>();
verify_program<test_lstm_forward_3args_und>();
verify_program<test_lstm_forward_3args>();
verify_program<test_lstm_forward_seq1>();
verify_program<test_lstm_forward_default_actv>();
verify_program<test_lstm_forward_default_actv1>();
verify_program<test_lstm_reverse_last>();
verify_program<test_lstm_reverse_3args>();
verify_program<test_lstm_reverse_3args_cell_output>();
verify_program<test_lstm_bidirct_last>();
verify_program<test_lstm_bidirct_hs>();
verify_program<test_lstm_bidirct_3args_und>();
verify_program<test_lstm_bidirct_3args>();
verify_program<test_lstm_bidirct_seq1>();
verify_program<test_lstm_bidirct_default_actv>();
verify_program<test_lstm_bidirct_default_actv1>();
verify_program<test_lstm_bidirct_default_actv2>();
verify_program<test_logsoftmax<0>>();
verify_program<test_logsoftmax<1>>();
verify_program<test_logsoftmax<2>>();
verify_program<test_logsoftmax<3>>();
verify_program<test_logsoftmax<4>>();
verify_program<test_logsoftmax_1<0>>();
verify_program<test_logsoftmax_1<1>>();
}
template struct test_logsoftmax_1<0>;
template struct test_logsoftmax_1<1>;
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -192,10 +192,10 @@ inline void add_test_case(std::string name, std::function<void()> f)
get_test_cases().emplace_back(std::move(name), std::move(f));
}
struct auto_register
struct auto_register_test_case
{
template <class F>
auto_register(const char* name, F f) noexcept
auto_register_test_case(const char* name, F f) noexcept
{
add_test_case(name, f);
}
......@@ -259,8 +259,8 @@ inline void run(int argc, const char* argv[])
// NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__);
static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE
#define TEST_CASE(...) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment