Commit ab9a6dea authored by Paul's avatar Paul
Browse files

Use test driver to run verify programs

parent 84e7335e
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <future> #include <future>
#include <thread> #include <thread>
#include "test.hpp" #include <test.hpp>
#ifdef __clang__ #ifdef __clang__
#pragma clang diagnostic push #pragma clang diagnostic push
...@@ -127,7 +127,7 @@ migraphx::argument run_gpu(migraphx::program& p) ...@@ -127,7 +127,7 @@ migraphx::argument run_gpu(migraphx::program& p)
} }
template <class V> template <class V>
void verify_program() void run_verify_program()
{ {
auto_print::set_terminate_handler(migraphx::get_type_name<V>()); auto_print::set_terminate_handler(migraphx::get_type_name<V>());
// std::cout << migraphx::get_type_name<V>() << std::endl; // std::cout << migraphx::get_type_name<V>() << std::endl;
...@@ -149,7 +149,28 @@ void verify_program() ...@@ -149,7 +149,28 @@ void verify_program()
std::set_terminate(nullptr); std::set_terminate(nullptr);
} }
struct test_literals template<class T>
int auto_register_verify_program()
{
test::add_test_case(migraphx::get_type_name<T>(), [] { run_verify_program<T>(); });
return 0;
}
template<class T>
struct verify_program
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type = std::integral_constant<decltype(&static_register), &static_register>;
};
template<class T>
int verify_program<T>::static_register = auto_register_verify_program<T>();
struct test_literals : verify_program<test_literals>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -164,7 +185,7 @@ struct test_literals ...@@ -164,7 +185,7 @@ struct test_literals
} }
}; };
struct test_add struct test_add : verify_program<test_add>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -177,7 +198,7 @@ struct test_add ...@@ -177,7 +198,7 @@ struct test_add
} }
}; };
struct test_add_half struct test_add_half : verify_program<test_add_half>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -190,7 +211,7 @@ struct test_add_half ...@@ -190,7 +211,7 @@ struct test_add_half
} }
}; };
struct test_mul struct test_mul : verify_program<test_mul>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -203,7 +224,7 @@ struct test_mul ...@@ -203,7 +224,7 @@ struct test_mul
} }
}; };
struct test_sin struct test_sin : verify_program<test_sin>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -215,7 +236,7 @@ struct test_sin ...@@ -215,7 +236,7 @@ struct test_sin
} }
}; };
struct test_scale struct test_scale : verify_program<test_scale>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -229,7 +250,7 @@ struct test_scale ...@@ -229,7 +250,7 @@ struct test_scale
} }
}; };
struct test_slice struct test_slice : verify_program<test_slice>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -244,7 +265,7 @@ struct test_slice ...@@ -244,7 +265,7 @@ struct test_slice
} }
}; };
struct test_triadd struct test_triadd : verify_program<test_triadd>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -259,7 +280,7 @@ struct test_triadd ...@@ -259,7 +280,7 @@ struct test_triadd
} }
}; };
struct test_triadd2 struct test_triadd2 : verify_program<test_triadd2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -276,7 +297,7 @@ struct test_triadd2 ...@@ -276,7 +297,7 @@ struct test_triadd2
} }
}; };
struct test_add_broadcast struct test_add_broadcast : verify_program<test_add_broadcast>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -290,7 +311,7 @@ struct test_add_broadcast ...@@ -290,7 +311,7 @@ struct test_add_broadcast
} }
}; };
struct test_add_broadcast2 struct test_add_broadcast2 : verify_program<test_add_broadcast2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -304,7 +325,7 @@ struct test_add_broadcast2 ...@@ -304,7 +325,7 @@ struct test_add_broadcast2
} }
}; };
struct test_add_broadcast3 struct test_add_broadcast3 : verify_program<test_add_broadcast3>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -318,7 +339,7 @@ struct test_add_broadcast3 ...@@ -318,7 +339,7 @@ struct test_add_broadcast3
} }
}; };
struct test_add_broadcast4 struct test_add_broadcast4 : verify_program<test_add_broadcast4>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -332,7 +353,7 @@ struct test_add_broadcast4 ...@@ -332,7 +353,7 @@ struct test_add_broadcast4
} }
}; };
struct test_add_broadcast5 struct test_add_broadcast5 : verify_program<test_add_broadcast5>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -346,7 +367,7 @@ struct test_add_broadcast5 ...@@ -346,7 +367,7 @@ struct test_add_broadcast5
} }
}; };
struct test_triadd_broadcast struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -362,7 +383,7 @@ struct test_triadd_broadcast ...@@ -362,7 +383,7 @@ struct test_triadd_broadcast
} }
}; };
struct test_softmax struct test_softmax : verify_program<test_softmax>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -373,7 +394,7 @@ struct test_softmax ...@@ -373,7 +394,7 @@ struct test_softmax
} }
}; };
struct test_softmax2 struct test_softmax2 : verify_program<test_softmax2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -385,7 +406,7 @@ struct test_softmax2 ...@@ -385,7 +406,7 @@ struct test_softmax2
} }
}; };
struct test_conv struct test_conv : verify_program<test_conv>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -399,7 +420,7 @@ struct test_conv ...@@ -399,7 +420,7 @@ struct test_conv
} }
}; };
struct test_conv2 struct test_conv2 : verify_program<test_conv2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -413,7 +434,7 @@ struct test_conv2 ...@@ -413,7 +434,7 @@ struct test_conv2
} }
}; };
struct test_conv_relu struct test_conv_relu : verify_program<test_conv_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -428,7 +449,7 @@ struct test_conv_relu ...@@ -428,7 +449,7 @@ struct test_conv_relu
} }
}; };
struct test_conv_relu_half struct test_conv_relu_half : verify_program<test_conv_relu_half>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -443,7 +464,7 @@ struct test_conv_relu_half ...@@ -443,7 +464,7 @@ struct test_conv_relu_half
} }
}; };
struct test_add_relu struct test_add_relu : verify_program<test_add_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -456,7 +477,7 @@ struct test_add_relu ...@@ -456,7 +477,7 @@ struct test_add_relu
} }
}; };
struct test_leaky_relu struct test_leaky_relu : verify_program<test_leaky_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -467,7 +488,7 @@ struct test_leaky_relu ...@@ -467,7 +488,7 @@ struct test_leaky_relu
} }
}; };
struct test_conv_pooling struct test_conv_pooling : verify_program<test_conv_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -483,7 +504,7 @@ struct test_conv_pooling ...@@ -483,7 +504,7 @@ struct test_conv_pooling
} }
}; };
struct test_global_avg_pooling struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -498,7 +519,7 @@ struct test_global_avg_pooling ...@@ -498,7 +519,7 @@ struct test_global_avg_pooling
} }
}; };
struct test_global_max_pooling struct test_global_max_pooling : verify_program<test_global_max_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -513,7 +534,7 @@ struct test_global_max_pooling ...@@ -513,7 +534,7 @@ struct test_global_max_pooling
} }
}; };
struct test_gemm struct test_gemm : verify_program<test_gemm>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -525,7 +546,7 @@ struct test_gemm ...@@ -525,7 +546,7 @@ struct test_gemm
} }
}; };
struct test_gemm_half struct test_gemm_half : verify_program<test_gemm_half>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -537,7 +558,7 @@ struct test_gemm_half ...@@ -537,7 +558,7 @@ struct test_gemm_half
} }
}; };
struct test_gemm_ld struct test_gemm_ld // : verify_program<test_gemm_ld>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -551,7 +572,7 @@ struct test_gemm_ld ...@@ -551,7 +572,7 @@ struct test_gemm_ld
} }
}; };
struct test_gemm_transposeb struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -564,7 +585,7 @@ struct test_gemm_transposeb ...@@ -564,7 +585,7 @@ struct test_gemm_transposeb
} }
}; };
struct test_gemm_transposea struct test_gemm_transposea : verify_program<test_gemm_transposea>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -577,7 +598,7 @@ struct test_gemm_transposea ...@@ -577,7 +598,7 @@ struct test_gemm_transposea
} }
}; };
struct test_gemm_transposeab struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -591,7 +612,7 @@ struct test_gemm_transposeab ...@@ -591,7 +612,7 @@ struct test_gemm_transposeab
} }
}; };
struct test_contiguous struct test_contiguous : verify_program<test_contiguous>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -604,7 +625,7 @@ struct test_contiguous ...@@ -604,7 +625,7 @@ struct test_contiguous
} }
}; };
struct test_transpose struct test_transpose : verify_program<test_transpose>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -618,7 +639,7 @@ struct test_transpose ...@@ -618,7 +639,7 @@ struct test_transpose
} }
}; };
struct test_batchnorm_inference_2 struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
{ {
const size_t width = 14; const size_t width = 14;
const size_t height = 14; const size_t height = 14;
...@@ -641,7 +662,7 @@ struct test_batchnorm_inference_2 ...@@ -641,7 +662,7 @@ struct test_batchnorm_inference_2
} }
}; };
struct test_batchnorm_inference struct test_batchnorm_inference : verify_program<test_batchnorm_inference>
{ {
const size_t width = 3; const size_t width = 3;
const size_t height = 3; const size_t height = 3;
...@@ -664,7 +685,7 @@ struct test_batchnorm_inference ...@@ -664,7 +685,7 @@ struct test_batchnorm_inference
} }
}; };
struct test_conv_bn struct test_conv_bn : verify_program<test_conv_bn>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -685,7 +706,7 @@ struct test_conv_bn ...@@ -685,7 +706,7 @@ struct test_conv_bn
} }
}; };
struct test_conv_bn_relu_pooling struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -709,7 +730,7 @@ struct test_conv_bn_relu_pooling ...@@ -709,7 +730,7 @@ struct test_conv_bn_relu_pooling
} }
}; };
struct test_concat struct test_concat : verify_program<test_concat>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -726,7 +747,7 @@ struct test_concat ...@@ -726,7 +747,7 @@ struct test_concat
} }
}; };
struct test_concat2 struct test_concat2 : verify_program<test_concat2>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -743,7 +764,7 @@ struct test_concat2 ...@@ -743,7 +764,7 @@ struct test_concat2
} }
}; };
struct test_concat_relu struct test_concat_relu : verify_program<test_concat_relu>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -810,7 +831,7 @@ void manual_test_concat_relu() ...@@ -810,7 +831,7 @@ void manual_test_concat_relu()
std::cout << result << std::endl; std::cout << result << std::endl;
} }
struct test_conv_bn_relu_pooling2 struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{ {
static migraphx::instruction_ref static migraphx::instruction_ref
add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels) add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels)
...@@ -847,47 +868,4 @@ struct test_conv_bn_relu_pooling2 ...@@ -847,47 +868,4 @@ struct test_conv_bn_relu_pooling2
} }
}; };
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_concat_relu>();
verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>();
verify_program<test_sin>();
verify_program<test_scale>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>();
verify_program<test_softmax>();
verify_program<test_softmax2>();
verify_program<test_conv>();
verify_program<test_conv2>();
verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>();
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_conv_pooling>();
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposeab>();
verify_program<test_contiguous>();
verify_program<test_transpose>();
verify_program<test_batchnorm_inference>();
verify_program<test_batchnorm_inference_2>();
verify_program<test_conv_bn>();
verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>();
}
...@@ -192,10 +192,10 @@ inline void add_test_case(std::string name, std::function<void()> f) ...@@ -192,10 +192,10 @@ inline void add_test_case(std::string name, std::function<void()> f)
get_test_cases().emplace_back(name, f); get_test_cases().emplace_back(name, f);
} }
struct auto_register struct auto_register_test_case
{ {
template <class F> template <class F>
auto_register(const char* name, F f) noexcept auto_register_test_case(const char* name, F f) noexcept
{ {
add_test_case(name, f); add_test_case(name, f);
} }
...@@ -252,8 +252,8 @@ inline void run(int argc, const char* argv[]) ...@@ -252,8 +252,8 @@ inline void run(int argc, const char* argv[])
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \ #define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \ static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__); test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define TEST_CASE(...) \ #define TEST_CASE(...) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment