Commit ab9a6dea authored by Paul's avatar Paul
Browse files

Use test driver to run verify programs

parent 84e7335e
......@@ -16,7 +16,7 @@
#include <future>
#include <thread>
#include "test.hpp"
#include <test.hpp>
#ifdef __clang__
#pragma clang diagnostic push
......@@ -127,7 +127,7 @@ migraphx::argument run_gpu(migraphx::program& p)
}
template <class V>
void verify_program()
void run_verify_program()
{
auto_print::set_terminate_handler(migraphx::get_type_name<V>());
// std::cout << migraphx::get_type_name<V>() << std::endl;
......@@ -149,7 +149,28 @@ void verify_program()
std::set_terminate(nullptr);
}
struct test_literals
template<class T>
int auto_register_verify_program()
{
test::add_test_case(migraphx::get_type_name<T>(), [] { run_verify_program<T>(); });
return 0;
}
template<class T>
struct verify_program
{
static int static_register;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using static_register_type = std::integral_constant<decltype(&static_register), &static_register>;
};
template<class T>
int verify_program<T>::static_register = auto_register_verify_program<T>();
struct test_literals : verify_program<test_literals>
{
migraphx::program create_program() const
{
......@@ -164,7 +185,7 @@ struct test_literals
}
};
struct test_add
struct test_add : verify_program<test_add>
{
migraphx::program create_program() const
{
......@@ -177,7 +198,7 @@ struct test_add
}
};
struct test_add_half
struct test_add_half : verify_program<test_add_half>
{
migraphx::program create_program() const
{
......@@ -190,7 +211,7 @@ struct test_add_half
}
};
struct test_mul
struct test_mul : verify_program<test_mul>
{
migraphx::program create_program() const
{
......@@ -203,7 +224,7 @@ struct test_mul
}
};
struct test_sin
struct test_sin : verify_program<test_sin>
{
migraphx::program create_program() const
{
......@@ -215,7 +236,7 @@ struct test_sin
}
};
struct test_scale
struct test_scale : verify_program<test_scale>
{
migraphx::program create_program() const
{
......@@ -229,7 +250,7 @@ struct test_scale
}
};
struct test_slice
struct test_slice : verify_program<test_slice>
{
migraphx::program create_program() const
{
......@@ -244,7 +265,7 @@ struct test_slice
}
};
struct test_triadd
struct test_triadd : verify_program<test_triadd>
{
migraphx::program create_program() const
{
......@@ -259,7 +280,7 @@ struct test_triadd
}
};
struct test_triadd2
struct test_triadd2 : verify_program<test_triadd2>
{
migraphx::program create_program() const
{
......@@ -276,7 +297,7 @@ struct test_triadd2
}
};
struct test_add_broadcast
struct test_add_broadcast : verify_program<test_add_broadcast>
{
migraphx::program create_program() const
{
......@@ -290,7 +311,7 @@ struct test_add_broadcast
}
};
struct test_add_broadcast2
struct test_add_broadcast2 : verify_program<test_add_broadcast2>
{
migraphx::program create_program() const
{
......@@ -304,7 +325,7 @@ struct test_add_broadcast2
}
};
struct test_add_broadcast3
struct test_add_broadcast3 : verify_program<test_add_broadcast3>
{
migraphx::program create_program() const
{
......@@ -318,7 +339,7 @@ struct test_add_broadcast3
}
};
struct test_add_broadcast4
struct test_add_broadcast4 : verify_program<test_add_broadcast4>
{
migraphx::program create_program() const
{
......@@ -332,7 +353,7 @@ struct test_add_broadcast4
}
};
struct test_add_broadcast5
struct test_add_broadcast5 : verify_program<test_add_broadcast5>
{
migraphx::program create_program() const
{
......@@ -346,7 +367,7 @@ struct test_add_broadcast5
}
};
struct test_triadd_broadcast
struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
{
migraphx::program create_program() const
{
......@@ -362,7 +383,7 @@ struct test_triadd_broadcast
}
};
struct test_softmax
struct test_softmax : verify_program<test_softmax>
{
migraphx::program create_program() const
{
......@@ -373,7 +394,7 @@ struct test_softmax
}
};
struct test_softmax2
struct test_softmax2 : verify_program<test_softmax2>
{
migraphx::program create_program() const
{
......@@ -385,7 +406,7 @@ struct test_softmax2
}
};
struct test_conv
struct test_conv : verify_program<test_conv>
{
migraphx::program create_program() const
{
......@@ -399,7 +420,7 @@ struct test_conv
}
};
struct test_conv2
struct test_conv2 : verify_program<test_conv2>
{
migraphx::program create_program() const
{
......@@ -413,7 +434,7 @@ struct test_conv2
}
};
struct test_conv_relu
struct test_conv_relu : verify_program<test_conv_relu>
{
migraphx::program create_program() const
{
......@@ -428,7 +449,7 @@ struct test_conv_relu
}
};
struct test_conv_relu_half
struct test_conv_relu_half : verify_program<test_conv_relu_half>
{
migraphx::program create_program() const
{
......@@ -443,7 +464,7 @@ struct test_conv_relu_half
}
};
struct test_add_relu
struct test_add_relu : verify_program<test_add_relu>
{
migraphx::program create_program() const
{
......@@ -456,7 +477,7 @@ struct test_add_relu
}
};
struct test_leaky_relu
struct test_leaky_relu : verify_program<test_leaky_relu>
{
migraphx::program create_program() const
{
......@@ -467,7 +488,7 @@ struct test_leaky_relu
}
};
struct test_conv_pooling
struct test_conv_pooling : verify_program<test_conv_pooling>
{
migraphx::program create_program() const
{
......@@ -483,7 +504,7 @@ struct test_conv_pooling
}
};
struct test_global_avg_pooling
struct test_global_avg_pooling : verify_program<test_global_avg_pooling>
{
migraphx::program create_program() const
{
......@@ -498,7 +519,7 @@ struct test_global_avg_pooling
}
};
struct test_global_max_pooling
struct test_global_max_pooling : verify_program<test_global_max_pooling>
{
migraphx::program create_program() const
{
......@@ -513,7 +534,7 @@ struct test_global_max_pooling
}
};
struct test_gemm
struct test_gemm : verify_program<test_gemm>
{
migraphx::program create_program() const
{
......@@ -525,7 +546,7 @@ struct test_gemm
}
};
struct test_gemm_half
struct test_gemm_half : verify_program<test_gemm_half>
{
migraphx::program create_program() const
{
......@@ -537,7 +558,7 @@ struct test_gemm_half
}
};
struct test_gemm_ld
struct test_gemm_ld // : verify_program<test_gemm_ld>
{
migraphx::program create_program() const
{
......@@ -551,7 +572,7 @@ struct test_gemm_ld
}
};
struct test_gemm_transposeb
struct test_gemm_transposeb : verify_program<test_gemm_transposeb>
{
migraphx::program create_program() const
{
......@@ -564,7 +585,7 @@ struct test_gemm_transposeb
}
};
struct test_gemm_transposea
struct test_gemm_transposea : verify_program<test_gemm_transposea>
{
migraphx::program create_program() const
{
......@@ -577,7 +598,7 @@ struct test_gemm_transposea
}
};
struct test_gemm_transposeab
struct test_gemm_transposeab : verify_program<test_gemm_transposeab>
{
migraphx::program create_program() const
{
......@@ -591,7 +612,7 @@ struct test_gemm_transposeab
}
};
struct test_contiguous
struct test_contiguous : verify_program<test_contiguous>
{
migraphx::program create_program() const
{
......@@ -604,7 +625,7 @@ struct test_contiguous
}
};
struct test_transpose
struct test_transpose : verify_program<test_transpose>
{
migraphx::program create_program() const
{
......@@ -618,7 +639,7 @@ struct test_transpose
}
};
struct test_batchnorm_inference_2
struct test_batchnorm_inference_2 : verify_program<test_batchnorm_inference_2>
{
const size_t width = 14;
const size_t height = 14;
......@@ -641,7 +662,7 @@ struct test_batchnorm_inference_2
}
};
struct test_batchnorm_inference
struct test_batchnorm_inference : verify_program<test_batchnorm_inference>
{
const size_t width = 3;
const size_t height = 3;
......@@ -664,7 +685,7 @@ struct test_batchnorm_inference
}
};
struct test_conv_bn
struct test_conv_bn : verify_program<test_conv_bn>
{
migraphx::program create_program() const
{
......@@ -685,7 +706,7 @@ struct test_conv_bn
}
};
struct test_conv_bn_relu_pooling
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
{
migraphx::program create_program() const
{
......@@ -709,7 +730,7 @@ struct test_conv_bn_relu_pooling
}
};
struct test_concat
struct test_concat : verify_program<test_concat>
{
migraphx::program create_program() const
{
......@@ -726,7 +747,7 @@ struct test_concat
}
};
struct test_concat2
struct test_concat2 : verify_program<test_concat2>
{
migraphx::program create_program() const
{
......@@ -743,7 +764,7 @@ struct test_concat2
}
};
struct test_concat_relu
struct test_concat_relu : verify_program<test_concat_relu>
{
migraphx::program create_program() const
{
......@@ -810,7 +831,7 @@ void manual_test_concat_relu()
std::cout << result << std::endl;
}
struct test_conv_bn_relu_pooling2
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
{
static migraphx::instruction_ref
add_bn(migraphx::program& p, migraphx::instruction_ref x, std::size_t channels)
......@@ -847,47 +868,4 @@ struct test_conv_bn_relu_pooling2
}
};
int main()
{
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_concat_relu>();
verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>();
verify_program<test_sin>();
verify_program<test_scale>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>();
verify_program<test_softmax>();
verify_program<test_softmax2>();
verify_program<test_conv>();
verify_program<test_conv2>();
verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>();
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_conv_pooling>();
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>();
verify_program<test_gemm_half>();
// verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>();
verify_program<test_gemm_transposea>();
verify_program<test_gemm_transposeab>();
verify_program<test_contiguous>();
verify_program<test_transpose>();
verify_program<test_batchnorm_inference>();
verify_program<test_batchnorm_inference_2>();
verify_program<test_conv_bn>();
verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -192,10 +192,10 @@ inline void add_test_case(std::string name, std::function<void()> f)
get_test_cases().emplace_back(name, f);
}
struct auto_register
struct auto_register_test_case
{
template <class F>
auto_register(const char* name, F f) noexcept
auto_register_test_case(const char* name, F f) noexcept
{
add_test_case(name, f);
}
......@@ -252,8 +252,8 @@ inline void run(int argc, const char* argv[])
// NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__);
static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE
#define TEST_CASE(...) \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment